Unverified Commit e0917e6b authored by Stefan He, committed by GitHub

Remove vllm ops scaled fp8 quant and accelerate per token quant by 20-28% (#4215)


Co-authored-by: Stefan He <bhe@linkedin.com>
parent 7130a7ce
from typing import Optional
import torch
from torch import nn
@@ -40,3 +42,60 @@ class CustomOp(nn.Module):
return self.forward_hip
else:
return self.forward_native
if _is_cuda:
    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8

    def scaled_fp8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        use_per_token_if_dynamic: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize input tensor to FP8 (8-bit floating point) format.

        Args:
            input (torch.Tensor): Input tensor to be quantized
            scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
                If None, scales will be computed dynamically.
            use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
                determines the quantization granularity:
                - True: compute one scale per token
                - False: compute a single scale for the whole tensor

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - quantized_tensor: The FP8 quantized version of input
                - scale_tensor: The scaling factors used for quantization

        Raises:
            AssertionError: If input is not 2D or if a static scale's numel != 1
        """
        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
        shape = input.shape
        out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
        output = torch.empty(shape, device=input.device, dtype=out_dtype)

        if scale is None:
            # Dynamic scaling
            if use_per_token_if_dynamic:
                scale = torch.empty(
                    (shape[0], 1), device=input.device, dtype=torch.float32
                )
                sgl_per_token_quant_fp8(input, output, scale)
            else:
                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
                sgl_per_tensor_quant_fp8(
                    input, output, scale, is_static=False
                )  # False for dynamic
        else:
            # Static scaling
            assert (
                scale.numel() == 1
            ), f"Expected scalar scale, got numel={scale.numel()}"
            sgl_per_tensor_quant_fp8(
                input, output, scale, is_static=True
            )  # True for static

        return output, scale
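A short usage sketch of the wrapper above, assuming a CUDA device with sgl_kernel installed (shapes and values are illustrative):

import torch

from sglang.srt.custom_op import scaled_fp8_quant

x = torch.randn(4, 128, device="cuda", dtype=torch.float16)

# Dynamic per-tensor scaling: one scale for the whole tensor, shape (1,).
q, s = scaled_fp8_quant(x)

# Dynamic per-token scaling: one scale per row, shape (4, 1),
# backed by the sgl_per_token_quant_fp8 kernel.
q_tok, s_tok = scaled_fp8_quant(x, use_per_token_if_dynamic=True)

# Static scaling: reuse a precomputed scalar scale.
q_static, _ = scaled_fp8_quant(x, scale=s)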
@@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Tuple
import torch
from torch.nn import Module
from vllm import _custom_ops as ops
from vllm import _custom_ops as vllm_ops
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
@@ -26,7 +26,13 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.utils import is_hip, set_weight_attrs
from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
_is_cuda = is_cuda()

if _is_cuda:
    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant

logger = logging.getLogger(__name__)
@@ -719,12 +725,20 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
)
for expert in range(layer.num_experts_per_partition):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
if _is_cuda:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
else:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
return
@@ -11,7 +11,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops
from vllm import _custom_ops as vllm_ops
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
@@ -42,6 +42,7 @@ _is_cuda = is_cuda()
if _is_cuda:
    from sgl_kernel import gelu_and_mul, silu_and_mul
    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
    from sglang.srt.layers.quantization.fp8_kernel import (
        sglang_per_token_group_quant_fp8,
    )
@@ -486,7 +487,7 @@ def moe_align_block_size(
cumsum_buffer,
)
else:
ops.moe_align_block_size(
vllm_ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
@@ -527,7 +528,10 @@ def invoke_fused_moe_kernel(
if block_shape is None:
# activation tensor-wise fp8 quantization, dynamic or static
padded_size = padding_size
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
if _is_cuda:
A, A_scale = sgl_scaled_fp8_quant(A, A_scale)
else:
A, A_scale = vllm_ops.scaled_fp8_quant(A, A_scale)
else:
# activation block-wise fp8 quantization
assert len(block_shape) == 2
@@ -1095,12 +1099,16 @@ def fused_experts_impl(
if _is_cuda:
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else:
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
vllm_ops.silu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
elif activation == "gelu":
if _is_cuda:
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else:
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
vllm_ops.gelu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
else:
raise ValueError(f"Unsupported activation: {activation=}")
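For context, silu_and_mul and gelu_and_mul compute a gated activation over the two halves of the last dimension; note that the sgl_kernel variants take (input, output) while the vllm_ops variants take (output, input). A minimal pure-PyTorch sketch of the intended math, with reference helper names that are illustrative rather than part of either API:

import torch
import torch.nn.functional as F


def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x has shape (..., 2 * N): SiLU-gate the first half with the second half.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up


def gelu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Same gating pattern with GELU instead of SiLU.
    gate, up = x.chunk(2, dim=-1)
    return F.gelu(gate) * up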
@@ -1132,7 +1140,7 @@ def fused_experts_impl(
if no_combine:
pass
elif _is_hip:
ops.moe_sum(
vllm_ops.moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx],
)
# Adapted from https://github.com/vllm-project/vllm/blob/8ca7a71df787ad711ad3ac70a5bd2eb2bb398938/tests/quantization/test_fp8.py
import pytest
import torch
from sglang.srt.custom_op import scaled_fp8_quant
from sglang.srt.utils import is_cuda
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant_per_tensor(dtype) -> None:
def quantize_ref_per_tensor(tensor, inv_scale):
# The reference implementation that fully aligns to
# the kernel being tested.
finfo = torch.finfo(torch.float8_e4m3fn)
scale = inv_scale.reciprocal()
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
qweight = qweight.to(torch.float8_e4m3fn)
return qweight
def dequantize_per_tensor(tensor, inv_scale, dtype):
fake_qweight = tensor.to(dtype)
dq_weight = fake_qweight * inv_scale
return dq_weight
    # Note that we use a shape that is not divisible by 8 to cover edge cases,
    # because scaled_fp8_quant is vectorized by 8.
x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
# Test Per Tensor Dynamic quantization
# scale = max(abs(x)) / FP8_E4M3_MAX
y, scale = scaled_fp8_quant(x, None)
ref_y = quantize_ref_per_tensor(x, scale)
torch.testing.assert_close(y, ref_y)
torch.testing.assert_close(
dequantize_per_tensor(y, scale, dtype),
dequantize_per_tensor(ref_y, scale, dtype),
)
# Test Per Tensor Static quantization
y, _ = scaled_fp8_quant(x, scale)
ref_y = quantize_ref_per_tensor(x, scale)
torch.testing.assert_close(y, ref_y)
torch.testing.assert_close(
dequantize_per_tensor(y, scale, dtype),
dequantize_per_tensor(ref_y, scale, dtype),
)
if is_cuda():
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
def quantize_ref_per_token(tensor, inv_scale):
# The reference implementation that fully aligns to
# the kernel being tested.
finfo = torch.finfo(torch.float8_e4m3fn)
scale = inv_scale.reciprocal()
qweight = (tensor.to(torch.float32) * scale).clamp(
min=finfo.min, max=finfo.max
)
qweight = qweight.to(torch.float8_e4m3fn)
return qweight
def dequantize_per_token(tensor, inv_scale, dtype):
fake_qweight = tensor.to(dtype)
dq_weight = fake_qweight * inv_scale
return dq_weight
        # Note that we use a hidden dimension divisible by 8,
        # because per_token_quant_fp8 is vectorized by 8 elements.
        x = (torch.randn(size=(11, 16), device="cuda") * 13).to(dtype)

        # Test Per Token Dynamic quantization
        # scale[i] = max(abs(x[i, :])) / FP8_E4M3_MAX
y, scale = scaled_fp8_quant(x, None, use_per_token_if_dynamic=True)
ref_y = quantize_ref_per_token(x, scale)
torch.testing.assert_close(y, ref_y)
torch.testing.assert_close(
dequantize_per_token(y, scale, dtype),
dequantize_per_token(ref_y, scale, dtype),
)
if __name__ == "__main__":
# Run the specific test function directly
pytest.main([__file__])
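The assertions above validate the quantized values against the scales returned by the kernel; they do not recompute the scales independently. A sketch of what the dynamic scales should be, with FP8_E4M3_MAX taken from torch.finfo (the helper names below are illustrative):

import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0


def expected_per_tensor_scale(x: torch.Tensor) -> torch.Tensor:
    # One scale for the whole tensor: max(|x|) / FP8_E4M3_MAX, shape (1,).
    return (x.float().abs().max() / FP8_E4M3_MAX).reshape(1)


def expected_per_token_scale(x: torch.Tensor) -> torch.Tensor:
    # One scale per row: max(|x_i|) / FP8_E4M3_MAX, shape (num_tokens, 1).
    return x.float().abs().amax(dim=-1, keepdim=True) / FP8_E4M3_MAX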
@@ -14,7 +14,6 @@ __global__ void per_token_quant_fp8_kernel(
const int64_t hidden_dim,
const int64_t num_tokens) {
const int token_idx = blockIdx.x;
if (token_idx >= num_tokens) return;
const int tid = threadIdx.x;
@@ -25,9 +24,20 @@ __global__ void per_token_quant_fp8_kernel(
float max_value = 0.0f;
for (int i = tid; i < hidden_dim; i += block_dim) {
float val = static_cast<float>(token_input[i]);
max_value = fmaxf(max_value, fabsf(val));
constexpr uint32_t vec_size = 16 / sizeof(T);
using vec_t = flashinfer::vec_t<T, vec_size>;
const int32_t num_vec_elems = hidden_dim / vec_size;
// Find max using vectorized loads
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
vec_t input_vec;
input_vec.cast_load(token_input + i * vec_size);
#pragma unroll
for (uint32_t j = 0; j < vec_size; ++j) {
float val = static_cast<float>(input_vec[j]);
max_value = fmaxf(max_value, fabsf(val));
}
}
max_value = blockReduceMax(max_value);
@@ -41,11 +51,7 @@ __global__ void per_token_quant_fp8_kernel(
const float scale_val = 1.0f / block_max;
constexpr uint32_t vec_size = 16 / sizeof(T);
using vec_t = flashinfer::vec_t<T, vec_size>;
const int32_t num_vec_elems = hidden_dim / vec_size;
// Quantize using vectorized loads
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
vec_t input_vec;
input_vec.cast_load(token_input + i * vec_size);
@@ -53,7 +59,7 @@
FP8_TYPE output_arr[vec_size];
#pragma unroll
for (uint32_t j = 0; j < vec_size; ++j) {
float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM
output_arr[j] = static_cast<FP8_TYPE>(val);
#else
@@ -68,18 +74,6 @@ __global__ void per_token_quant_fp8_kernel(
token_output[i * vec_size + j] = output_arr[j];
}
}
const int32_t remaining_start = num_vec_elems * vec_size;
for (int32_t idx = remaining_start + tid; idx < hidden_dim; idx += block_dim) {
float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(token_input[idx]) * scale_val, FP8_E4M3_MAX));
#ifndef USE_ROCM
token_output[idx] = static_cast<FP8_TYPE>(val);
#else
token_output[idx] = c10::Float8_e4m3fnuz(
__hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
c10::Float8_e4m3fnuz::from_bits());
#endif
}
}
void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
@@ -91,7 +85,9 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
const int64_t num_tokens = input_sizes[0];
const int64_t hidden_dim = input_sizes[1];
const int block_size = 128;
TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim);
const int block_size = 256;
const int num_blocks = num_tokens;
dim3 grid(num_blocks);
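On the Python side, this launcher expects pre-allocated outputs, a 2D input, and a hidden dimension divisible by 8; it launches one block of 256 threads per token. A minimal calling sketch, assuming a CUDA build with sgl_kernel installed:

import torch
from sgl_kernel import sgl_per_token_quant_fp8

x = torch.randn(32, 4096, device="cuda", dtype=torch.bfloat16)  # hidden_dim % 8 == 0
q = torch.empty_like(x, dtype=torch.float8_e4m3fn)              # quantized output
s = torch.empty((x.shape[0], 1), device=x.device, dtype=torch.float32)  # per-token scales

sgl_per_token_quant_fp8(x, q, s)  # fills q with FP8 values and s with per-token scales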