"examples/pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "c03046a08f1407bce914905509acbeb343e75db8"
Unverified commit e0917e6b authored by Stefan He, committed by GitHub

Remove vllm ops scaled fp8 quant and accelerate per token quant by 20-28% (#4215)


Co-authored-by: Stefan He <bhe@linkedin.com>
parent 7130a7ce
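The 20-28% figure in the commit title refers to the dynamic per-token quantization path. A hypothetical micro-benchmark sketch (not part of this commit) for reproducing such a comparison, assuming a CUDA device with sgl_kernel built and assuming vllm's _custom_ops.scaled_fp8_quant accepts the same use_per_token_if_dynamic flag:

# Illustrative benchmark sketch; names and sizes are assumptions, not from the commit.
import torch

from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
from vllm import _custom_ops as vllm_ops


def time_cuda(fn, iters=100, warmup=10):
    # Mean milliseconds per call, measured with CUDA events.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
t_sgl = time_cuda(lambda: sgl_scaled_fp8_quant(x, None, use_per_token_if_dynamic=True))
t_vllm = time_cuda(lambda: vllm_ops.scaled_fp8_quant(x, None, use_per_token_if_dynamic=True))
print(f"sgl per-token quant: {t_sgl:.3f} ms, vllm per-token quant: {t_vllm:.3f} ms")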
+from typing import Optional
+
 import torch
 from torch import nn

@@ -40,3 +42,60 @@ class CustomOp(nn.Module):
             return self.forward_hip
         else:
             return self.forward_native

if _is_cuda:
    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8

    def scaled_fp8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        use_per_token_if_dynamic: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize input tensor to FP8 (8-bit floating point) format.

        Args:
            input (torch.Tensor): Input tensor to be quantized
            scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
                If None, scales will be computed dynamically.
            use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
                determines the quantization granularity:
                - True: compute scale per token
                - False: compute single scale per tensor

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - quantized_tensor: The FP8 quantized version of input
                - scale_tensor: The scaling factors used for quantization

        Raises:
            AssertionError: If input is not 2D or if static scale's numel != 1
        """
        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
        shape = input.shape
        out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
        output = torch.empty(shape, device=input.device, dtype=out_dtype)

        if scale is None:
            # Dynamic scaling
            if use_per_token_if_dynamic:
                scale = torch.empty(
                    (shape[0], 1), device=input.device, dtype=torch.float32
                )
                sgl_per_token_quant_fp8(input, output, scale)
            else:
                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
                sgl_per_tensor_quant_fp8(
                    input, output, scale, is_static=False
                )  # False for dynamic
        else:
            # Static scaling
            assert (
                scale.numel() == 1
            ), f"Expected scalar scale, got numel={scale.numel()}"
            sgl_per_tensor_quant_fp8(
                input, output, scale, is_static=True
            )  # True for static

        return output, scale
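A brief usage sketch (illustrative, not part of the diff) exercising the three modes described in the docstring above, assuming a CUDA build where sgl_kernel is available:

import torch
from sglang.srt.custom_op import scaled_fp8_quant

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)

# Dynamic per-tensor: a single scale is computed on the fly (s.numel() == 1).
q, s = scaled_fp8_quant(x)

# Dynamic per-token: one scale per row, shape (num_tokens, 1).
q_tok, s_tok = scaled_fp8_quant(x, use_per_token_if_dynamic=True)

# Static per-tensor: reuse a precomputed scalar scale.
q_static, _ = scaled_fp8_quant(x, scale=s)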
@@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Tuple

 import torch
 from torch.nn import Module

-from vllm import _custom_ops as ops
+from vllm import _custom_ops as vllm_ops

 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
@@ -26,7 +26,13 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
-from sglang.srt.utils import is_hip, set_weight_attrs
+from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant

 logger = logging.getLogger(__name__)
@@ -719,12 +725,20 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
             )

             for expert in range(layer.num_experts_per_partition):
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
+                if _is_cuda:
+                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                    )
+                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                    )
+                else:
+                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                    )
+                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                    )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
             return
...
@@ -11,7 +11,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
 import triton
 import triton.language as tl
-from vllm import _custom_ops as ops
+from vllm import _custom_ops as vllm_ops

 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
@@ -42,6 +42,7 @@ _is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul

+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
     from sglang.srt.layers.quantization.fp8_kernel import (
         sglang_per_token_group_quant_fp8,
     )
@@ -486,7 +487,7 @@ def moe_align_block_size(
             cumsum_buffer,
         )
     else:
-        ops.moe_align_block_size(
+        vllm_ops.moe_align_block_size(
             topk_ids,
             num_experts,
             block_size,
@@ -527,7 +528,10 @@ def invoke_fused_moe_kernel(
         if block_shape is None:
             # activation tensor-wise fp8 quantization, dynamic or static
             padded_size = padding_size
-            A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+            if _is_cuda:
+                A, A_scale = sgl_scaled_fp8_quant(A, A_scale)
+            else:
+                A, A_scale = vllm_ops.scaled_fp8_quant(A, A_scale)
         else:
             # activation block-wise fp8 quantization
             assert len(block_shape) == 2
@@ -1095,12 +1099,16 @@ def fused_experts_impl(
             if _is_cuda:
                 silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
-                ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+                vllm_ops.silu_and_mul(
+                    intermediate_cache2, intermediate_cache1.view(-1, N)
+                )
         elif activation == "gelu":
             if _is_cuda:
                 gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
-                ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+                vllm_ops.gelu_and_mul(
+                    intermediate_cache2, intermediate_cache1.view(-1, N)
+                )
         else:
             raise ValueError(f"Unsupported activation: {activation=}")
@@ -1132,7 +1140,7 @@ def fused_experts_impl(
         if no_combine:
             pass
         elif _is_hip:
-            ops.moe_sum(
+            vllm_ops.moe_sum(
                 intermediate_cache3.view(*intermediate_cache3.shape),
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
             )
...
# Adapted from https://github.com/vllm-project/vllm/blob/8ca7a71df787ad711ad3ac70a5bd2eb2bb398938/tests/quantization/test_fp8.py
import pytest
import torch

from sglang.srt.custom_op import scaled_fp8_quant
from sglang.srt.utils import is_cuda


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant_per_tensor(dtype) -> None:
    def quantize_ref_per_tensor(tensor, inv_scale):
        # The reference implementation that fully aligns to
        # the kernel being tested.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = inv_scale.reciprocal()
        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
        qweight = qweight.to(torch.float8_e4m3fn)
        return qweight

    def dequantize_per_tensor(tensor, inv_scale, dtype):
        fake_qweight = tensor.to(dtype)
        dq_weight = fake_qweight * inv_scale
        return dq_weight

    # Note that we use a shape % 8 != 0 to cover edge cases,
    # because scaled_fp8_quant is vectorized by 8.
    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)

    # Test per-tensor dynamic quantization
    # scale = max(abs(x)) / FP8_E4M3_MAX
    y, scale = scaled_fp8_quant(x, None)
    ref_y = quantize_ref_per_tensor(x, scale)
    torch.testing.assert_close(y, ref_y)
    torch.testing.assert_close(
        dequantize_per_tensor(y, scale, dtype),
        dequantize_per_tensor(ref_y, scale, dtype),
    )

    # Test per-tensor static quantization
    y, _ = scaled_fp8_quant(x, scale)
    ref_y = quantize_ref_per_tensor(x, scale)
    torch.testing.assert_close(y, ref_y)
    torch.testing.assert_close(
        dequantize_per_tensor(y, scale, dtype),
        dequantize_per_tensor(ref_y, scale, dtype),
    )

if is_cuda():

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
        def quantize_ref_per_token(tensor, inv_scale):
            # The reference implementation that fully aligns to
            # the kernel being tested.
            finfo = torch.finfo(torch.float8_e4m3fn)
            scale = inv_scale.reciprocal()
            qweight = (tensor.to(torch.float32) * scale).clamp(
                min=finfo.min, max=finfo.max
            )
            qweight = qweight.to(torch.float8_e4m3fn)
            return qweight

        def dequantize_per_token(tensor, inv_scale, dtype):
            fake_qweight = tensor.to(dtype)
            dq_weight = fake_qweight * inv_scale
            return dq_weight

        # Note that we use a hidden dim with shape % 8 == 0,
        # because per_token_quant_fp8 is vectorized by 8 elements.
        x = (torch.randn(size=(11, 16), device="cuda") * 13).to(dtype)

        # Test per-token dynamic quantization
        # scale[i] = max(abs(x[i, :])) / FP8_E4M3_MAX
        y, scale = scaled_fp8_quant(x, None, use_per_token_if_dynamic=True)
        ref_y = quantize_ref_per_token(x, scale)
        torch.testing.assert_close(y, ref_y)
        torch.testing.assert_close(
            dequantize_per_token(y, scale, dtype),
            dequantize_per_token(ref_y, scale, dtype),
        )


if __name__ == "__main__":
    # Run the specific test function directly
    pytest.main([__file__])
@@ -14,7 +14,6 @@ __global__ void per_token_quant_fp8_kernel(
     const int64_t hidden_dim,
     const int64_t num_tokens) {
   const int token_idx = blockIdx.x;
   if (token_idx >= num_tokens) return;

   const int tid = threadIdx.x;
@@ -25,9 +24,20 @@ __global__ void per_token_quant_fp8_kernel(
   float max_value = 0.0f;

-  for (int i = tid; i < hidden_dim; i += block_dim) {
-    float val = static_cast<float>(token_input[i]);
-    max_value = fmaxf(max_value, fabsf(val));
+  constexpr uint32_t vec_size = 16 / sizeof(T);
+  using vec_t = flashinfer::vec_t<T, vec_size>;
+  const int32_t num_vec_elems = hidden_dim / vec_size;
+
+  // Find max using vectorized loads
+  for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
+    vec_t input_vec;
+    input_vec.cast_load(token_input + i * vec_size);
+#pragma unroll
+    for (uint32_t j = 0; j < vec_size; ++j) {
+      float val = static_cast<float>(input_vec[j]);
+      max_value = fmaxf(max_value, fabsf(val));
+    }
   }

   max_value = blockReduceMax(max_value);
@@ -41,11 +51,7 @@ __global__ void per_token_quant_fp8_kernel(
   const float scale_val = 1.0f / block_max;

-  constexpr uint32_t vec_size = 16 / sizeof(T);
-  using vec_t = flashinfer::vec_t<T, vec_size>;
-  const int32_t num_vec_elems = hidden_dim / vec_size;
-
+  // Quantize using vectorized loads
   for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
     vec_t input_vec;
     input_vec.cast_load(token_input + i * vec_size);
@@ -53,7 +59,7 @@ __global__ void per_token_quant_fp8_kernel(
     FP8_TYPE output_arr[vec_size];
 #pragma unroll
     for (uint32_t j = 0; j < vec_size; ++j) {
-      float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
+      float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
 #ifndef USE_ROCM
       output_arr[j] = static_cast<FP8_TYPE>(val);
 #else
@@ -68,18 +74,6 @@ __global__ void per_token_quant_fp8_kernel(
       token_output[i * vec_size + j] = output_arr[j];
     }
   }
-
-  const int32_t remaining_start = num_vec_elems * vec_size;
-  for (int32_t idx = remaining_start + tid; idx < hidden_dim; idx += block_dim) {
-    float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(token_input[idx]) * scale_val, FP8_E4M3_MAX));
-#ifndef USE_ROCM
-    token_output[idx] = static_cast<FP8_TYPE>(val);
-#else
-    token_output[idx] = c10::Float8_e4m3fnuz(
-        __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
-        c10::Float8_e4m3fnuz::from_bits());
-#endif
-  }
 }

 void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
@@ -91,7 +85,9 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
   const int64_t num_tokens = input_sizes[0];
   const int64_t hidden_dim = input_sizes[1];

-  const int block_size = 128;
+  TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim);
+
+  const int block_size = 256;
   const int num_blocks = num_tokens;

   dim3 grid(num_blocks);
...
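For reference, the per-token quantization computed by the kernel above corresponds to the following PyTorch sketch (illustrative only, not part of the commit; the zero-row guard and the ROCm FP8 path are omitted), and it mirrors the new requirement that hidden_dim be divisible by 8:

import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def per_token_quant_fp8_reference(x: torch.Tensor):
    assert x.ndim == 2 and x.shape[1] % 8 == 0, "hidden_dim must be divisible by 8"
    # One float32 scale per token (row): amax(|x[i, :]|) / FP8_E4M3_MAX.
    scale = x.abs().amax(dim=1, keepdim=True).float() / FP8_E4M3_MAX
    # Scale, clamp to the representable range, then cast to FP8.
    # (A guard for all-zero rows is omitted in this sketch.)
    q = (x.float() / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return q, scale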