Unverified Commit 07f94463 authored by Rex, committed by GitHub

Add awq dequantize kernel to sgl with 1x to 3x speedup (#4104)

parent e0917e6b
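For quick orientation, here is a minimal usage sketch of the op this commit adds. The call signature and output shape follow the kernel source below; the concrete sizes and the group size of 128 are illustrative assumptions, not part of the commit.

import torch
from sgl_kernel import awq_dequantize

# Illustrative AWQ shapes (assumed): each int32 column of qweight packs eight
# 4-bit weights, so the dequantized output is 8x wider than the packed tensor.
K, packed_N, group_size = 3584, 448, 128
qweight = torch.randint(
    0, torch.iinfo(torch.int32).max, (K, packed_N), dtype=torch.int32, device="cuda"
)
scales = torch.rand(K // group_size, packed_N * 8, dtype=torch.float16, device="cuda")
qzeros = torch.randint(
    0,
    torch.iinfo(torch.int32).max,
    (K // group_size, packed_N),
    dtype=torch.int32,
    device="cuda",
)
weights = awq_dequantize(qweight, scales, qzeros)  # float16, shape [K, packed_N * 8]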
import itertools
import torch
import triton
import triton.testing
from sgl_kernel import awq_dequantize
from vllm import _custom_ops as ops
def vllm_awq_dequantize(
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
def sglang_awq_dequantize(
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    return awq_dequantize(qweight, scales, qzeros)
def calculate_diff(qweight_row: int, qweight_col: int):
"""Calculate difference between VLLM and SGLang implementations."""
device = torch.device("cuda")
qweight = torch.randint(
0,
torch.iinfo(torch.int32).max,
(qweight_row, qweight_col),
dtype=torch.int32,
device=device,
)
group_size = qweight_row
scales_row = qweight_row // group_size
scales_col = qweight_col * 8
scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
qzeros = torch.randint(
0,
torch.iinfo(torch.int32).max,
(scales_row, qweight_col),
dtype=torch.int32,
device=device,
)
vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
    output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
    print(f"Mean absolute difference: {output_diff:.6f}")
    if torch.allclose(
        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    ):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]
configs = list(itertools.product(qweight_row_range, qweight_cols_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["qweight_row", "qweight_col"],
x_vals=configs,
line_arg="provider",
line_vals=["vllm", "sglang"],
line_names=["VLLM", "SGL Kernel"],
styles=[("blue", "-"), ("green", "-")],
ylabel="us",
plot_name="awq-dequantize-performance",
args={},
)
)
def benchmark(qweight_row, qweight_col, provider):
dtype = torch.float16
device = torch.device("cuda")
qweight = torch.randint(
0,
torch.iinfo(torch.int32).max,
(qweight_row, qweight_col),
dtype=torch.int32,
device=device,
)
group_size = qweight_row
scales_row = qweight_row // group_size
scales_col = qweight_col * 8
scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
qzeros = torch.randint(
0,
torch.iinfo(torch.int32).max,
(scales_row, qweight_col),
dtype=torch.int32,
device=device,
)
quantiles = [0.5, 0.2, 0.8]
if provider == "vllm":
fn = lambda: vllm_awq_dequantize(
qweight.clone(), scales.clone(), qzeros.clone()
)
elif provider == "sglang":
fn = lambda: sglang_awq_dequantize(
qweight.clone(), scales.clone(), qzeros.clone()
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
calculate_diff(qweight_row=3584, qweight_col=448)
benchmark.run(print_data=True)
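If you want to keep the benchmark output around, a small variation of the entry point above writes the results to disk. This assumes the installed Triton's perf_report runner accepts save_path, which recent Triton releases do.

if __name__ == "__main__":
    calculate_diff(qweight_row=3584, qweight_col=448)
    # save_path (assumed available) persists the CSV table and plot in
    # addition to printing the table.
    benchmark.run(print_data=True, save_path="./awq_dequantize_results")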
// Adapted from
// https://github.com/vllm-project/vllm/blob/eb59b5a6cba6727d3727c0372258db9002f687c1/csrc/quantization/awq/gemm_kernels.cu#L350
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <torch/all.h>
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is
// thanks to the register packing format and the fact that we force our
// integers to be unsigned, and account for this in the fp16 subtractions. In
// addition, I exploit the fact that sub and fma have the same throughput in
// order to convert elt_23 and elt_67 to fp16 without having to shift them to
// the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// This is the half2 {1024, 1024} represented as an integer.
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-64, -64} represented as an integer.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result;
#else
assert(false);
return {};
#endif
}
__global__ void __launch_bounds__(256) dequantize_weights(
int* __restrict__ qweight,
half* __restrict__ scales,
int* __restrict__ qzeros,
half* __restrict__ output,
int group_size,
int qweight_cols) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
int row = blockIdx.y * blockDim.y + threadIdx.y;
uint4 zeros = dequantize_s4_to_fp16x2(qzeros[col + (row / group_size) * qweight_cols]);
uint4 loaded_scale = *(uint4*)(scales + 8 * col + (row / group_size) * qweight_cols * 8);
uint4 weight_fp16 = dequantize_s4_to_fp16x2(qweight[col + row * qweight_cols]);
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(loaded_scale.x));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(zeros.y));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(loaded_scale.y));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(zeros.z));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(loaded_scale.z));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(zeros.w));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(loaded_scale.w));
half* output_ptr = output + 8 * col + 8 * row * qweight_cols;
*(uint4*)output_ptr = weight_fp16;
}
torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros) {
int qweight_rows = qweight.size(0);
int qweight_cols = qweight.size(1);
int group_size = qweight_rows / scales.size(0);
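  // The grid math below assumes qweight_cols and qweight_rows are multiples of
  // 16; the integer division silently drops any remainder.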
int x_num_threads = 16;
int y_num_threads = 16;
int x_blocks = qweight_cols / x_num_threads;
int y_blocks = qweight_rows / y_num_threads;
const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
auto output_tensor_options = torch::TensorOptions().dtype(scales.dtype()).device(scales.device());
at::Tensor output = torch::empty({qweight_rows, qweight_cols * 8}, output_tensor_options);
auto _qweight = reinterpret_cast<int*>(qweight.data_ptr<int>());
auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
auto _zeros = reinterpret_cast<int*>(qzeros.data_ptr<int>());
auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
dim3 num_blocks(x_blocks, y_blocks);
dim3 threads_per_block(x_num_threads, y_num_threads);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
_qweight, _scales, _zeros, _output, group_size, qweight_cols);
return output;
}
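For readers who prefer to reason about the layout in Python, here is a rough reference of what the kernel computes, assuming the usual AWQ nibble interleaving (output element j of each int32 comes from nibble position [0, 4, 1, 5, 2, 6, 3, 7][j], which is what the lop3/shift sequence in dequantize_s4_to_fp16x2 implements). This is an illustrative sketch, not part of the commit.

import torch

# Nibble position holding output element j (assumed AWQ interleaving).
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]


def awq_dequantize_reference(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    # qweight: [K, N/8] int32, scales: [K/g, N] fp16, qzeros: [K/g, N/8] int32
    shifts = torch.tensor(
        [4 * p for p in AWQ_REVERSE_ORDER], dtype=torch.int32, device=qweight.device
    )
    w = (qweight.unsqueeze(-1) >> shifts) & 0xF  # [K, N/8, 8] unpacked nibbles
    z = (qzeros.unsqueeze(-1) >> shifts) & 0xF   # [K/g, N/8, 8]
    w = w.reshape(qweight.shape[0], -1).to(scales.dtype)  # [K, N]
    z = z.reshape(qzeros.shape[0], -1).to(scales.dtype)   # [K/g, N]
    group_size = qweight.shape[0] // scales.shape[0]
    # Every group_size consecutive rows share one row of zeros and scales.
    z = z.repeat_interleave(group_size, dim=0)
    s = scales.repeat_interleave(group_size, dim=0)
    return (w - z) * s

Under the shapes used in the benchmark above and the test below, this should agree with both the vLLM and sgl_kernel outputs up to fp16 rounding.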
@@ -75,6 +75,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
  /*
   * From csrc/gemm
   */
  m.def("awq_dequantize(Tensor qweight, Tensor scales, Tensor qzeros) -> Tensor");
  m.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
  m.def(
      "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
      "bias) -> Tensor");
...
@@ -112,6 +112,7 @@ void apply_rope_pos_ids_cos_sin_cache(
/*
 * From csrc/gemm
 */
torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros);
torch::Tensor int8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
...
@@ -23,6 +23,7 @@ from sgl_kernel.elementwise import (
    silu_and_mul,
)
from sgl_kernel.gemm import (
    awq_dequantize,
    bmm_fp8,
    cublas_grouped_gemm,
    fp8_blockwise_scaled_mm,
...
@@ -4,6 +4,12 @@ import torch
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream


def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    return torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)


def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    return torch.ops.sgl_kernel.int8_scaled_mm(
        mat_a,
...
@@ -150,6 +150,7 @@ sources = [
    "csrc/elementwise/rope.cu",
    "csrc/gemm/bmm_fp8.cu",
    "csrc/gemm/cublas_grouped_gemm.cu",
    "csrc/gemm/awq_kernel.cu",
    "csrc/gemm/fp8_gemm_kernel.cu",
    "csrc/gemm/fp8_blockwise_gemm_kernel.cu",
    "csrc/gemm/int8_gemm_kernel.cu",
...
import itertools
import pytest
import torch
from sgl_kernel import awq_dequantize
from vllm import _custom_ops as ops
def vllm_awq_dequantize(
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
def sglang_awq_dequantize(
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
return awq_dequantize(qweight, scales, qzeros)
@pytest.mark.parametrize(
"qweight_row,qweight_col",
list(
itertools.product(
[3584, 18944, 128, 256, 512, 1024], [448, 576, 4736, 16, 32, 64, 128]
)
),
)
def test_awq_dequant_compare_implementations(
qweight_row: int,
qweight_col: int,
):
device = torch.device("cuda")
qweight = torch.randint(
0,
torch.iinfo(torch.int32).max,
(qweight_row, qweight_col),
dtype=torch.int32,
device=device,
)
group_size = qweight_row
scales_row = qweight_row // group_size
scales_col = qweight_col * 8
scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
qzeros = torch.randint(
0,
torch.iinfo(torch.int32).max,
(scales_row, qweight_col),
dtype=torch.int32,
device=device,
)
# Run both implementations
vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
# Compare results
torch.testing.assert_close(
vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
)
if __name__ == "__main__":
    # Run the tests in this file directly
pytest.main([__file__])