Unverified Commit 4d643f6c authored by HandH1998, committed by GitHub

[1/2] Support Qserve (#6457)


Co-authored-by: yych0745 <1398089567@qq.com>
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
parent 6ce0ed07
@@ -203,6 +203,8 @@ set(SOURCES
"csrc/gemm/per_tensor_quant_fp8.cu"
"csrc/gemm/per_token_group_quant_8bit.cu"
"csrc/gemm/per_token_quant_fp8.cu"
"csrc/gemm/qserve_w4a8_per_chn_gemm.cu"
"csrc/gemm/qserve_w4a8_per_group_gemm.cu"
"csrc/moe/moe_align_kernel.cu"
"csrc/moe/moe_fused_gate.cu"
"csrc/moe/moe_topk_softmax_kernels.cu"
import argparse
import copy
import itertools
import torch
import triton
from sgl_kernel import (
int8_scaled_mm,
qserve_w4a8_per_chn_gemm,
qserve_w4a8_per_group_gemm,
)
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
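# Each entry maps a model to ([K, N], tp_split_dim) pairs for its linear layers;
# prepare_shapes() divides the dimension at index tp_split_dim by the tensor-parallel size.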
WEIGHT_SHAPES = {
"meta-llama/Llama-3.1-8B-Instruct": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-3.3-70B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
"mistralai/Mistral-Large-Instruct-2407": [
([12288, 14336], 1),
([12288, 12288], 0),
([12288, 57344], 1),
([28672, 12288], 0),
],
"Qwen/Qwen2.5-7B-Instruct": [
([3584, 4608], 1),
([3584, 3584], 0),
([3584, 37888], 1),
([18944, 3584], 0),
],
"Qwen/Qwen2.5-32B-Instruct": [
([5120, 7168], 1),
([5120, 5120], 0),
([5120, 55296], 1),
([27648, 5120], 0),
],
"Qwen/Qwen2.5-72B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 59136], 1),
([29568, 8192], 0),
],
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
([2048, 3072], 1),
([2048, 4096], 1),
([2048, 2048], 0),
([2048, 576], 0),
([2048, 21888], 1),
([10944, 2048], 0),
([2048, 2816], 1),
([1408, 2048], 0),
],
}
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
x_log=False,
line_arg="provider",
line_vals=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
line_names=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
ylabel="ms",
plot_name="FP16_vs_W8A8_vs_Qserve_W4A8_GEMM",
args={},
)
)
def benchmark(batch_size, provider, N, K):
M = batch_size
# For W8A8
a = to_int8(torch.randn((M, K), device="cuda") * 5)
b = to_int8(torch.randn((N, K), device="cuda").t() * 5)
a_fp16 = a.to(torch.float16)
b_fp16 = b.to(torch.float16)
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
# For Qserve W4A8 per channel
a_qserve_chn = a
    # two int4 values are packed into one int8, so the packed weight has K // 2 columns
b_qserve_chn = to_int8(torch.randn((N, K // 2), device="cuda") * 5)
# b_qserve_chn = b.t().contiguous()
scale_a_qserve_chn = scale_a.to(torch.float16)
scale_b_qserve_chn = scale_b.to(torch.float16)
szero_b_qserve_chn = torch.randn((N,), device="cuda", dtype=torch.float16)
a_sum_qserve_chn = torch.randn((M,), device="cuda", dtype=torch.float16)
# For Qserve W4A8 per group
group_size = 128
assert K % group_size == 0, "K must be divisible by group_size"
a_qserve_group = a
    # two int4 values are packed into one int8, so the packed weight has K // 2 columns
b_qserve_group = to_int8(torch.randn((N, K // 2), device="cuda") * 5)
# b_qserve_group = b.t().contiguous()
scale_a_qserve_group = scale_a.to(torch.float16)
scale_b_qserve_group = scale_b.to(torch.float16)
scale_i8_b_qserve_group = to_int8(
torch.randn((K // group_size, N), device="cuda", dtype=torch.float16)
)
zero_i8_b_qserve_group = to_int8(
torch.randn((K // group_size, N), device="cuda", dtype=torch.float16)
)
quantiles = [0.5, 0.2, 0.8]
if provider == "FP16":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch.matmul(a_fp16, b_fp16),
quantiles=quantiles,
)
if provider == "W8A8":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16),
quantiles=quantiles,
)
if provider == "Qserve_W4A8_Per_Channel":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: qserve_w4a8_per_chn_gemm(
a_qserve_chn,
b_qserve_chn,
scale_b_qserve_chn,
scale_a_qserve_chn,
szero_b_qserve_chn,
a_sum_qserve_chn,
),
quantiles=quantiles,
)
if provider == "Qserve_W4A8_Per_Group":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: qserve_w4a8_per_group_gemm(
a_qserve_group,
b_qserve_group,
zero_i8_b_qserve_group,
scale_i8_b_qserve_group,
scale_b_qserve_group,
scale_a_qserve_group,
),
quantiles=quantiles,
)
return ms, max_ms, min_ms
def prepare_shapes(args):
KN_model_names = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
assert model in WEIGHT_SHAPES
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KN.append(model)
KN_model_names.append(KN)
return KN_model_names
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.1-8B-Instruct"],
help="List of models to benchmark",
)
parser.add_argument(
"--tp-sizes",
nargs="+",
type=int,
default=[1],
help="List of tensor parallel sizes",
)
args = parser.parse_args()
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_qserve_w4a8_gemm_res",
N=N,
K=K,
)
print("Benchmark finished!")
@@ -265,6 +265,19 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
*/
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
/*
* From QServe
*/
m.def(
"qserve_w4a8_per_chn_gemm(Tensor _in_feats, Tensor _kernel, Tensor _wscales, Tensor _ascales, Tensor _w_szs, "
"Tensor _a_ssums, Tensor! _out_feats) -> ()");
m.impl("qserve_w4a8_per_chn_gemm", torch::kCUDA, &qserve_w4a8_per_chn_gemm);
m.def(
"qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
"Tensor _ascales, Tensor! _out_feats) -> ()");
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
}
REGISTER_EXTENSION(common_ops)
@@ -404,3 +404,24 @@ void convert_vertical_slash_indexes_mergehead(
* From XGrammar
*/
void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt);
/*
* From QServe
*/
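// Shapes and dtypes below are inferred from the Python wrappers and unit tests in this PR:
//   _in_feats  : int8 activations, [M, K]
//   _kernel    : packed 4-bit weights, two values per int8 byte, [N, K / 2]
//   _wscales   : fp16 per-output-channel weight scales
//   _ascales   : fp16 per-token activation scales
//   _w_szs     : fp16 per-channel scale * zero-point terms (per-channel kernel only)
//   _a_ssums   : fp16 per-row activation sums (per-channel kernel only)
//   _zeros, _scales_i8 : int8 per-group zero-point and scale tensors in the repacked
//                        QServe layout, [K / group_size, N] (per-group kernel only)
//   _out_feats : fp16 output, [M, N]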
void qserve_w4a8_per_chn_gemm(
const torch::Tensor& _in_feats,
const torch::Tensor& _kernel,
const torch::Tensor& _wscales,
const torch::Tensor& _ascales,
const torch::Tensor& _w_szs,
const torch::Tensor& _a_ssums,
torch::Tensor& _out_feats);
void qserve_w4a8_per_group_gemm(
const torch::Tensor& _in_feats,
const torch::Tensor& _kernel,
const torch::Tensor& _zeros,
const torch::Tensor& _scales_i8,
const torch::Tensor& _wscales,
const torch::Tensor& _ascales,
torch::Tensor& _out_feats);
@@ -36,6 +36,8 @@ from sgl_kernel.gemm import (
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
int8_scaled_mm,
qserve_w4a8_per_chn_gemm,
qserve_w4a8_per_group_gemm,
scaled_fp4_quant,
sgl_per_tensor_quant_fp8,
sgl_per_token_group_quant_fp8,
@@ -197,3 +197,47 @@ def scaled_fp4_quant(
)
output_scale = output_scale.view(torch.float8_e4m3fn)
return output, output_scale
def qserve_w4a8_per_chn_gemm(
in_feats: torch.Tensor,
kernel: torch.Tensor,
wscales: torch.Tensor,
ascales: torch.Tensor,
w_szs: torch.Tensor,
a_ssums: torch.Tensor,
out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out_feats is None:
# NOTE(HandH1998): qserve_w4a8_per_chn_gemm only supports out dtype=torch.float16 now
out_feats = torch.empty(
(in_feats.shape[0], kernel.shape[0]),
device=in_feats.device,
dtype=torch.float16,
)
torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm.default(
in_feats, kernel, wscales, ascales, w_szs, a_ssums, out_feats
)
return out_feats
def qserve_w4a8_per_group_gemm(
in_feats: torch.Tensor,
kernel: torch.Tensor,
zeros: torch.Tensor,
scales_i8: torch.Tensor,
wscales: torch.Tensor,
ascales: torch.Tensor,
out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out_feats is None:
# NOTE(HandH1998): qserve_w4a8_per_group_gemm only supports out dtype=torch.float16 now
out_feats = torch.empty(
(in_feats.shape[0], kernel.shape[0]),
device=in_feats.device,
dtype=torch.float16,
)
torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm.default(
in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats
)
return out_feats
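A minimal call sketch for the two wrappers (not part of the diff), using random packed tensors purely to illustrate argument order and shapes; real weights must first be repacked into the QServe layout, as shown in the unit tests below:

import torch
from sgl_kernel import qserve_w4a8_per_chn_gemm, qserve_w4a8_per_group_gemm

M, N, K, group_size = 16, 4096, 4096, 128
a = torch.randint(-128, 128, (M, K), device="cuda", dtype=torch.int8)       # int8 activations
b = torch.randint(-128, 128, (N, K // 2), device="cuda", dtype=torch.int8)  # two 4-bit weights per byte
wscales = torch.rand((N,), device="cuda", dtype=torch.float16)              # per-channel weight scales
ascales = torch.rand((M,), device="cuda", dtype=torch.float16)              # per-token activation scales

# Per-channel variant: also needs the fused scale*zero term and the per-row activation sums.
w_szs = torch.rand((N,), device="cuda", dtype=torch.float16)
a_ssums = torch.rand((M,), device="cuda", dtype=torch.float16)
out_chn = qserve_w4a8_per_chn_gemm(a, b, wscales, ascales, w_szs, a_ssums)  # fp16, [M, N]

# Per-group variant: also needs int8 per-group zeros and scales laid out as [K // group_size, N].
zeros = torch.randint(-8, 8, (K // group_size, N), device="cuda", dtype=torch.int8)
scales_i8 = torch.randint(1, 8, (K // group_size, N), device="cuda", dtype=torch.int8)
out_group = qserve_w4a8_per_group_gemm(a, b, zeros, scales_i8, wscales, ascales)  # fp16, [M, N]

With random inputs the numerical output is meaningless; the accuracy tests below construct properly quantized and repacked operands.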
import pytest
import torch
from sgl_kernel import qserve_w4a8_per_chn_gemm
# Adapted from https://github.com/mit-han-lab/omniserve/blob/main/omniserve/modeling/layers/quantized_linear/w4a8_linear.py
def convert_to_qserve_format(qweight, scale, zero):
assert qweight.min() >= 0 and qweight.max() <= 15, "Quantized weight out of range"
in_features = qweight.shape[1]
out_features = qweight.shape[0]
assert in_features % 32 == 0, "Input features must be divisible by 32"
assert out_features % 32 == 0, "Output features must be divisible by 32"
# ---- Repack the weight ---- #
# pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
qweight_unpack_reorder = (
qweight.reshape(
out_features // 32,
2,
2,
8,
in_features // 32,
2,
4,
4,
)
.permute(0, 4, 3, 6, 1, 5, 2, 7)
.contiguous()
)
qweight_unpack_reorder = (
qweight_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7, 4)
.contiguous()
.to(torch.int8)
)
# B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
# [16, 0, 17, 1, ...]
qweight_unpack_repacked = (
qweight_unpack_reorder[..., 1] << 4
) + qweight_unpack_reorder[..., 0]
qweight_unpack_repacked = qweight_unpack_repacked.reshape(
out_features // 32, in_features // 32, 32, 16
)
qweight_unpack_repacked = qweight_unpack_repacked.reshape(
out_features, in_features // 2
).contiguous()
# ---- Pack the scales ---- #
scale = scale.reshape(out_features).to(torch.float16).contiguous()
szero = zero.reshape(out_features).to(torch.float16).contiguous() * scale
return qweight_unpack_repacked, scale, szero
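# Illustration (not part of the test): the final packing step above stores two unsigned
# 4-bit values in one int8 byte via (hi << 4) + lo; the tile permutation before it is
# specific to the QServe kernel layout and is not reproduced here.
def _pack_unpack_nibbles_demo():
    lo = torch.randint(0, 16, (4,), dtype=torch.int8)  # even-position 4-bit values
    hi = torch.randint(0, 16, (4,), dtype=torch.int8)  # odd-position 4-bit values
    packed = (hi << 4) + lo  # low nibble of (hi << 4) is zero, so + behaves like bitwise OR
    lo_back = packed & 0x0F  # recover the low nibble
    hi_back = (packed >> 4) & 0x0F  # arithmetic shift, then mask off the sign-extended bits
    assert torch.equal(lo_back, lo) and torch.equal(hi_back, hi)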
# INT4 Quantization
def asym_quantize_tensor(tensor):
tensor_min = tensor.min(dim=-1, keepdim=True)[0]
tensor_max = tensor.max(dim=-1, keepdim=True)[0]
q_min = 0
q_max = 15
tensor_scale = (tensor_max - tensor_min) / (q_max - q_min)
tensor_zero = q_min - torch.round(tensor_min / tensor_scale)
tensor_q = torch.clamp(
torch.round(tensor / tensor_scale) + tensor_zero, q_min, q_max
).to(torch.int8)
return tensor_q, tensor_scale.to(torch.float16), tensor_zero.to(torch.int8)
# INT8 Quantization
def sym_quantize_tensor(tensor):
tensor_scale = tensor.abs().max(dim=-1, keepdim=True)[0] / 127
tensor_q = torch.clamp(torch.round(tensor / tensor_scale), -128, 127).to(torch.int8)
return tensor_q, tensor_scale.to(torch.float16)
def torch_w4a8_per_chn_gemm(a, b, a_scale, b_scale, b_zero, out_dtype):
o = torch.matmul(
a.to(torch.float16), (b.to(torch.float16) - b_zero.to(torch.float16)).t()
)
o = o * a_scale.view(-1, 1) * b_scale.view(1, -1)
return o.to(out_dtype)
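# Note: a @ (b - z).T == a @ b.T - a.sum(-1, keepdim=True) * z.view(1, -1), which is
# presumably why the CUDA kernel takes per-row activation sums (a_ssums) and the fused
# scale * zero term (w_szs) instead of the raw zero points.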
def _test_accuracy_once(M, N, K, out_dtype, device):
    # to avoid overflow, multiply by 0.01
a = torch.randn((M, K), device=device, dtype=torch.float32) * 0.01
b = torch.randn((N, K), device=device, dtype=torch.float32) * 0.01
# symmetric quantize a
a_q, a_scale = sym_quantize_tensor(a)
# asymmetric quantize b
b_q, b_scale, b_zero = asym_quantize_tensor(b)
# convert to qserve format
b_q_format, b_scale_format, b_szero_format = convert_to_qserve_format(
b_q, b_scale, b_zero
)
    # compute the sum of each row of a (passed to the kernel as a_ssums)
a_sum = a.sum(dim=-1, keepdim=True).to(torch.float16)
out = qserve_w4a8_per_chn_gemm(
a_q, b_q_format, b_scale_format, a_scale, b_szero_format, a_sum
)
ref_out = torch_w4a8_per_chn_gemm(a_q, b_q, a_scale, b_scale, b_zero, out_dtype)
torch.testing.assert_close(out, ref_out, rtol=1e-3, atol=1e-2)
@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("out_dtype", [torch.float16])
def test_accuracy(M, N, K, out_dtype):
_test_accuracy_once(M, N, K, out_dtype, "cuda")
if __name__ == "__main__":
pytest.main([__file__])
import pytest
import torch
from sgl_kernel import qserve_w4a8_per_group_gemm
# Adapted from https://github.com/mit-han-lab/omniserve/blob/main/omniserve/modeling/layers/quantized_linear/w4a8_linear.py
def convert_to_qserve_format(qweight, chn_scale, scale_i8, zero_i8, group_size):
assert qweight.min() >= 0 and qweight.max() <= 15, "Quantized weight out of range"
in_features = qweight.shape[1]
out_features = qweight.shape[0]
assert in_features % 32 == 0, "Input features must be divisible by 32"
assert out_features % 32 == 0, "Output features must be divisible by 32"
assert group_size == 128, "Group size must be 128"
assert (
in_features % group_size == 0
), "Input features must be divisible by group_size"
# ---- Repack the weight ---- #
# pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
qweight_unpack_reorder = (
qweight.reshape(
out_features // 32,
2,
2,
8,
in_features // 32,
2,
4,
4,
)
.permute(0, 4, 3, 6, 1, 5, 2, 7)
.contiguous()
)
qweight_unpack_reorder = (
qweight_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7, 4)
.contiguous()
.to(torch.int8)
)
# B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
# [16, 0, 17, 1, ...]
    qweight_unpack_repacked = (
        qweight_unpack_reorder[..., 1] << 4
    ) + qweight_unpack_reorder[..., 0]
    qweight_unpack_repacked = qweight_unpack_repacked.reshape(
        out_features // 32, in_features // 32, 32, 16
    )
    qweight_unpack_repacked = qweight_unpack_repacked.reshape(
        out_features, in_features // 2
    )
# ---- Pack the scales ---- #
chn_scale = chn_scale.reshape(out_features)
scale_i8 = (
scale_i8.reshape(out_features, in_features // group_size)
.transpose(0, 1)
.contiguous()
)
scale_i8 = scale_i8.reshape(in_features // group_size, out_features // 32, 32)
scale_i8 = (
scale_i8.reshape(in_features // group_size, out_features // 32, 4, 8)
.transpose(-2, -1)
.contiguous()
)
scale_i8 = scale_i8.reshape(in_features // group_size, out_features).contiguous()
# ---- Pack the zeros ---- #
    zero_i8 = -zero_i8
    # zero_i8 = zero_i8.int()  # convert to two's complement
zero_i8 = (
zero_i8.reshape(out_features, in_features // group_size)
.transpose(0, 1)
.contiguous()
)
zero_i8 = zero_i8.reshape(in_features // group_size, out_features // 32, 32)
# for the last dimension, organize as 0, 8, 16, 24, 1, 9, 17, 25, ... following the requirement of tensor core gemm
zero_i8 = (
zero_i8.reshape(in_features // group_size, out_features // 32, 4, 8)
.transpose(-2, -1)
.contiguous()
)
zero_i8 = (
zero_i8.reshape(in_features // group_size, out_features).contiguous() * scale_i8
)
    return qweight_unpack_repacked, chn_scale, scale_i8, zero_i8
# Progressive Group INT4 Quantization
def progressive_group_quantize_tensor(tensor, group_size):
assert group_size == 128, "Group size must be 128"
assert (
tensor.shape[-1] % group_size == 0
), "Input features must be divisible by group_size"
# Channel scale
# NOTE(HandH1998): use protective quantization range
chn_scale = tensor.abs().max(dim=-1, keepdim=True)[0] / 119
tensor_i8 = torch.clamp(torch.round(tensor / chn_scale), -119, 119)
# Group scale
tensor_i8 = tensor_i8.reshape(-1, group_size)
tensor_i8_min = tensor_i8.min(dim=-1, keepdim=True)[0]
tensor_i8_max = tensor_i8.max(dim=-1, keepdim=True)[0]
q_min = 0
q_max = 15
scale_i8 = torch.round((tensor_i8_max - tensor_i8_min) / (q_max - q_min))
zero_i8 = q_min - torch.round(tensor_i8_min / scale_i8)
tensor_q = (
torch.clamp(torch.round(tensor_i8 / scale_i8) + zero_i8, q_min, q_max)
.reshape(tensor.shape[0], -1)
.to(torch.int8)
)
return (
tensor_q,
chn_scale.to(torch.float16),
scale_i8.reshape(tensor.shape[0], -1).to(torch.int8),
zero_i8.reshape(tensor.shape[0], -1).to(torch.int8),
)
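# Sanity-check sketch (not part of the test): dequantize in the reverse order of the two
# quantization levels above and compare against the original weights.
def _progressive_dequant_demo():
    w = torch.randn(256, 512) * 0.01
    q, chn_scale, scale_i8, zero_i8 = progressive_group_quantize_tensor(w, group_size=128)
    # undo the group level first (int8 domain), then the per-channel scale (fp domain)
    w_i8 = (
        q.float().reshape(-1, 128) - zero_i8.float().reshape(-1, 1)
    ) * scale_i8.float().reshape(-1, 1)
    w_hat = w_i8.reshape(w.shape) * chn_scale.float()
    print("max reconstruction error:", (w - w_hat).abs().max().item())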
# INT8 Quantization
def sym_quantize_tensor(tensor):
tensor_scale = tensor.abs().max(dim=-1, keepdim=True)[0] / 127
tensor_q = torch.clamp(torch.round(tensor / tensor_scale), -128, 127).to(torch.int8)
return tensor_q, tensor_scale.to(torch.float16)
def torch_w4a8_per_group_gemm(
a, b, a_scale, b_chn_scale, b_scale_i8, b_zero_i8, group_size, out_dtype
):
assert group_size == 128, "Group size must be 128"
b_dq = (
b.reshape(-1, group_size).to(torch.float32)
- b_zero_i8.reshape(-1, 1).to(torch.float32)
) * b_scale_i8.reshape(-1, 1).to(torch.float32)
b_dq = b_dq.reshape(b.shape[0], b.shape[1])
o = torch.matmul(a.to(torch.float32), b_dq.t())
o = o * a_scale.view(-1, 1) * b_chn_scale.view(1, -1)
return o.to(out_dtype)
def _test_accuracy_once(M, N, K, group_size, out_dtype, device):
    # to avoid overflow, multiply by 0.01
a = torch.randn((M, K), device=device, dtype=torch.float32) * 0.01
b = torch.randn((N, K), device=device, dtype=torch.float32) * 0.01
# symmetric quantize a
a_q, a_scale = sym_quantize_tensor(a)
# asymmetric quantize b
b_q, b_chn_scale, b_scale_i8, b_zero_i8 = progressive_group_quantize_tensor(
b, group_size
)
# convert to qserve format
b_q_format, b_chn_scale_format, b_scale_i8_format, b_zero_i8_format = (
convert_to_qserve_format(b_q, b_chn_scale, b_scale_i8, b_zero_i8, group_size)
)
out = qserve_w4a8_per_group_gemm(
a_q,
b_q_format,
b_zero_i8_format,
b_scale_i8_format,
b_chn_scale_format,
a_scale,
)
ref_out = torch_w4a8_per_group_gemm(
a_q, b_q, a_scale, b_chn_scale, b_scale_i8, b_zero_i8, group_size, out_dtype
)
torch.testing.assert_close(out, ref_out, rtol=1e-3, atol=1e-5)
@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("out_dtype", [torch.float16])
def test_accuracy(M, N, K, group_size, out_dtype):
_test_accuracy_once(M, N, K, group_size, out_dtype, "cuda")
if __name__ == "__main__":
pytest.main([__file__])