Unverified commit bf86c5e9 authored by Xiaoyu Zhang, committed by GitHub

restruct compressed_tensors_w8a8_fp8 (#5475)

parent dca90f1d
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -16,7 +16,7 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
-    Fp8LinearOp,
+    apply_fp8_linear,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
@@ -29,7 +29,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

     @classmethod
     def get_min_capability(cls) -> int:
@@ -149,11 +148,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        return self.fp8_linear.apply(
+        return apply_fp8_linear(
             input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
             input_scale=layer.input_scale,
             bias=bias,
+            use_per_token_if_dynamic=True,
+            compressed_tensor_quant=True,
         )
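The call-site change above replaces the stateful Fp8LinearOp wrapper with a direct call into apply_fp8_linear. A minimal sketch of the new pattern (not a drop-in file; `layer` is assumed to already carry the fp8 `weight`, `weight_scale`, and `input_scale` set up during weight loading, exactly as in the scheme shown in the diff):

    # Sketch only: mirrors the apply_weights body above.
    from typing import Optional

    import torch

    from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear


    def apply_weights_sketch(layer, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # `layer` is assumed to expose fp8 weight, weight_scale, and input_scale attributes.
        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            use_per_token_if_dynamic=True,  # dynamic activation quant uses per-token scales
            compressed_tensor_quant=True,   # opt into the compressed-tensors branch added here
        )

Callers that leave `compressed_tensor_quant` at its default of `False` keep the previous apply_fp8_linear behavior; only the compressed-tensors scheme opts into the new branch.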
sglang/srt/layers/quantization/fp8_utils.py
+import os
 from typing import List, Optional, Tuple

 import torch

@@ -5,7 +6,7 @@ import torch
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8

 try:
-    from vllm import _custom_ops as vllm_ops
+    from vllm import _custom_ops as ops

     VLLM_AVAILABLE = True
 except ImportError:
@@ -234,6 +235,43 @@ def channel_quant_to_tensor_quant(
     return x_q_tensor, scale


+def _process_scaled_mm_output(output, input_2d_shape, output_shape):
+    if type(output) is tuple and len(output) == 2:
+        output = output[0]
+    return torch.narrow(output, 0, 0, input_2d_shape[0]).view(*output_shape)
+
+
+def _apply_fallback_scaled_mm(
+    qinput,
+    weight,
+    x_scale,
+    weight_scale,
+    input_2d_shape,
+    output_shape,
+    bias,
+    input_dtype,
+):
+    global TORCH_DEVICE_IDENTITY
+    if TORCH_DEVICE_IDENTITY is None:
+        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32, device=weight.device)
+
+    output = torch._scaled_mm(
+        qinput,
+        weight,
+        scale_a=TORCH_DEVICE_IDENTITY,
+        scale_b=TORCH_DEVICE_IDENTITY,
+        out_dtype=torch.float32,
+    )
+
+    output = _process_scaled_mm_output(output, input_2d_shape, output_shape)
+    x_scale = torch.narrow(x_scale, 0, 0, input_2d_shape[0])
+
+    output = output * x_scale * weight_scale.t()
+    if bias is not None:
+        output = output + bias
+    return output.to(dtype=input_dtype)
+
+
 def apply_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -241,43 +279,38 @@ def apply_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_fp8_supported: bool = True,
+    cutlass_fp8_supported: bool = cutlass_fp8_supported(),
     use_per_token_if_dynamic: bool = False,
+    pad_output: Optional[bool] = None,
+    compressed_tensor_quant: bool = False,
 ) -> torch.Tensor:
+    # Note: we pad the input because torch._scaled_mm is more performant
+    # for matrices with batch dimension > 16.
+    # This could change in the future.
+    # We also don't pad when using torch.compile,
+    # as it breaks with dynamic shapes.
+    if pad_output is None:
+        pad_output = not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
+    output_padding = 17 if pad_output else None
+
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[1]]

-    # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
-    if input_scale is not None:
-        assert input_scale.numel() == 1
-        # broadcast per-tensor scale to per-token scale when supporting cutlass
-        qinput, x_scale = static_quant_fp8(
-            input_2d, input_scale, repeat_scale=cutlass_fp8_supported
-        )
-    else:
-        # default use per-token quantization if dynamic
-        if _is_cuda:
-            qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
-        else:
-            # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
-            # final solution should be: 1. add support to per-tensor activation scaling.
-            # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
-            if _is_hip and weight_scale.numel() == 1:
-                qinput, x_scale = vllm_ops.scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
-            else:
-                qinput, x_scale = per_token_group_quant_fp8(
-                    input_2d, group_size=input_2d.shape[1]
-                )
-
-    if cutlass_fp8_supported:
-        if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-            # Fall back to vllm cutlass w8a8 fp8 kernel
-            output = vllm_ops.cutlass_scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
+    if compressed_tensor_quant:
+        # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
+        # for sgl-kernel fp8_scaled_mm, it support per channel W now
+        if cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
+            qinput, x_scale = scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                use_per_token_if_dynamic=use_per_token_if_dynamic,
+            )
+
+            # Fused GEMM_DQ
+            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
+                # Fall back to vllm cutlass w8a8 fp8 kernel
+                output = ops.cutlass_scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
@@ -302,6 +335,23 @@ def apply_fp8_linear(
-    # torch.scaled_mm supports per tensor weights + activations only
-    # so fallback to naive if per channel or per token
-    else:
-        per_tensor_weights = weight_scale.numel() == 1
-        per_tensor_activations = x_scale.numel() == 1
+        # torch.scaled_mm supports per tensor weights + activations only
+        # so fallback to naive if per channel or per token
+        else:
+            # Maybe apply padding to output, see comment in __init__
+            qinput, x_scale = (
+                scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    num_token_padding=output_padding,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
+                if _is_cuda
+                else ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    num_token_padding=output_padding,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
+            )
+
+            per_tensor_weights = weight_scale.numel() == 1
+            per_tensor_activations = x_scale.numel() == 1
@@ -315,12 +365,30 @@ def apply_fp8_linear(
-                scale_b=weight_scale,
-                bias=bias,
-            )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
-        else:
-            # Fallback for channelwise case, where we use unfused DQ
+                    scale_b=weight_scale,
+                    bias=bias,
+                )
+                return _process_scaled_mm_output(output, input_2d.shape, output_shape)
+
+            elif (
+                use_per_token_if_dynamic
+                and not per_tensor_weights
+                and not per_tensor_activations
+                and USE_ROWWISE_TORCH_SCALED_MM
+            ):
+                # For now validated on ROCm platform
+                # fp8 rowwise scaling in torch._scaled_mm is introduced in
+                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+                # and ROCm 6.3, which only exists in torch 2.7 and above.
+                # For CUDA platform please validate if the
+                # torch._scaled_mm support rowwise scaled GEMM
+                # Fused GEMM_DQ Rowwise GEMM
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale.t(),
+                    bias=bias,
+                )
+                return _process_scaled_mm_output(output, input_2d.shape, output_shape)
+            else:
+                # Fallback for channelwise case, where we use unfused DQ
@@ -337,115 +405,48 @@ def apply_fp8_linear(
             #
             # For the scaled_mm fallback case, we break this down, since it
             # does not support s_w being a vector.
-
-            # Making sure the dummy tensor is on the same device as the weight
-            global TORCH_DEVICE_IDENTITY
-            if TORCH_DEVICE_IDENTITY is None:
-                TORCH_DEVICE_IDENTITY = torch.ones(
-                    1, dtype=torch.float32, device=weight.device
-                )
-
-            # GEMM
-            # This computes C = (X * W).
-            # Output in fp32 to allow subsequent ops to happen in-place
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                scale_a=TORCH_DEVICE_IDENTITY,
-                scale_b=TORCH_DEVICE_IDENTITY,
-                out_dtype=torch.float32,
-            )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-            # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input_2d.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
-
-            # DQ
-            # C = sw * sx * (X * W) + bias
-            output = output * x_scale * weight_scale.t()
-            if bias is not None:
-                output = output + bias
-
-            return output.to(dtype=input.dtype).view(*output_shape)
-
-
-# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
-# TODO(luka): follow similar pattern for marlin and block-fp8-linear
-# https://github.com/vllm-project/vllm/issues/14397
-class Fp8LinearOp:
-    """
-    This class executes a FP8 linear layer using cutlass if supported and
-    torch.scaled_mm otherwise.
-    It needs to be a class instead of a method so that config can be read
-    in the __init__ method, as reading config is not allowed inside forward.
-    """
-
-    def __init__(
-        self,
-        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
-        use_per_token_if_dynamic: bool = False,
-        pad_output: Optional[bool] = None,
-    ):
-        self.cutlass_fp8_supported = cutlass_fp8_supported
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
-
-        # Note: we pad the input because torch._scaled_mm is more performant
-        # for matrices with batch dimension > 16.
-        # This could change in the future.
-        # We also don't pad when using torch.compile,
-        # as it breaks with dynamic shapes.
-        if pad_output is None:
-            enable_torch_compile = get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
-            pad_output = not enable_torch_compile
-        self.output_padding = 17 if pad_output else None
-
-    def apply(
-        self,
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        input_scale: Optional[torch.Tensor] = None,
-        input_scale_ub: Optional[torch.Tensor] = None,
-        bias: Optional[torch.Tensor] = None,
-        # TODO(luka) remove this parameter in favor of __init__
-        use_per_token_if_dynamic: Optional[bool] = None,
-    ) -> torch.Tensor:
-        # ops.scaled_fp8_quant supports both dynamic and static quant.
-        # If dynamic, layer.input_scale is None and x_scale computed from x.
-        # If static, layer.input_scale is scalar and x_scale is input_scale.
-
-        # View input as 2D matrix for fp8 methods
-        input_2d = input.view(-1, input.shape[-1])
-        output_shape = [*input.shape[:-1], weight.shape[1]]
-
-        # TODO(luka) this is here because currently MLA only decides this
-        # during the forward method instead of in __init__.
-        if use_per_token_if_dynamic is None:
-            use_per_token_if_dynamic = self.use_per_token_if_dynamic
-
-        # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
-        # for sgl-kernel fp8_scaled_mm, it support per channel W now
-        if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
-            if _is_cuda:
-                qinput, x_scale = scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
-            else:
-                qinput, x_scale = vllm_ops.scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    scale_ub=input_scale_ub,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
-
-            # Fused GEMM_DQ
-            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-                # Fall back to vllm cutlass w8a8 fp8 kernel
-                output = vllm_ops.cutlass_scaled_mm(
-                    qinput,
-                    weight,
-                    out_dtype=input.dtype,
+                return _apply_fallback_scaled_mm(
+                    qinput,
+                    weight,
+                    x_scale,
+                    weight_scale,
+                    input_2d.shape,
+                    output_shape,
+                    bias,
+                    input.dtype,
+                )
+    else:
+        # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
+        if input_scale is not None:
+            assert input_scale.numel() == 1
+            # broadcast per-tensor scale to per-token scale when supporting cutlass
+            qinput, x_scale = static_quant_fp8(
+                input_2d, input_scale, repeat_scale=cutlass_fp8_supported
+            )
+        else:
+            # default use per-token quantization if dynamic
+            if _is_cuda:
+                qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
+            else:
+                # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
+                # final solution should be: 1. add support to per-tensor activation scaling.
+                # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
+                if _is_hip and weight_scale.numel() == 1:
+                    qinput, x_scale = ops.scaled_fp8_quant(
+                        input_2d,
+                        input_scale,
+                        use_per_token_if_dynamic=use_per_token_if_dynamic,
+                    )
+                else:
+                    qinput, x_scale = per_token_group_quant_fp8(
+                        input_2d, group_size=input_2d.shape[1]
+                    )
+
+        if cutlass_fp8_supported:
+            try:
+                if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
+                    # Fall back to vllm cutlass w8a8 fp8 kernel
+                    output = ops.cutlass_scaled_mm(
+                        qinput,
+                        weight,
+                        out_dtype=input.dtype,
@@ -466,26 +467,11 @@ class Fp8LinearOp:
-                    bias=bias,
-                )
-            return output.view(*output_shape)
-
-        # torch.scaled_mm supports per tensor weights + activations only
-        # so fallback to naive if per channel or per token
-        else:
-            # Maybe apply padding to output, see comment in __init__
-            if _is_cuda:
-                qinput, x_scale = scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    num_token_padding=self.output_padding,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
-            else:
-                qinput, x_scale = vllm_ops.scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    num_token_padding=self.output_padding,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
-
-            per_tensor_weights = weight_scale.numel() == 1
-            per_tensor_activations = x_scale.numel() == 1
+                        bias=bias,
+                    )
+                return output.view(*output_shape)
+            except (ImportError, NameError, AttributeError):
+                pass
+
+        # torch.scaled_mm supports per tensor weights + activations only
+        # so fallback to naive if per channel or per token
+        per_tensor_weights = weight_scale.numel() == 1
+        per_tensor_activations = x_scale.numel() == 1
@@ -499,38 +485,7 @@ class Fp8LinearOp:
-                    scale_b=weight_scale,
-                    bias=bias,
-                )
-                # A fix for discrepancy in scaled_mm which returns tuple
-                # for torch < 2.5 and a single value in torch >= 2.5
-                if type(output) is tuple and len(output) == 2:
-                    output = output[0]
-                return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
-
-            elif (
-                use_per_token_if_dynamic
-                and not per_tensor_weights
-                and not per_tensor_activations
-                and USE_ROWWISE_TORCH_SCALED_MM
-            ):
-                # For now validated on ROCm platform
-                # fp8 rowwise scaling in torch._scaled_mm is introduced in
-                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
-                # and ROCm 6.3, which only exists in torch 2.7 and above.
-                # For CUDA platform please validate if the
-                # torch._scaled_mm support rowwise scaled GEMM
-                # Fused GEMM_DQ Rowwise GEMM
-                output = torch._scaled_mm(
-                    qinput,
-                    weight,
-                    out_dtype=input.dtype,
-                    scale_a=x_scale,
-                    scale_b=weight_scale.t(),
-                    bias=bias,
-                )
-                output = torch.narrow(output, 0, 0, input_2d.shape[0])
-                output = output.view(*output_shape)
-                return output
-            else:
-                # Fallback for channelwise case, where we use unfused DQ
+                scale_b=weight_scale,
+                bias=bias,
+            )
+            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
+        else:
+            # Fallback for channelwise case, where we use unfused DQ
@@ -547,36 +502,13 @@ class Fp8LinearOp:
                 #
                 # For the scaled_mm fallback case, we break this down, since it
                 # does not support s_w being a vector.
-
-                # GEMM
-                # This computes C = (X * W).
-                # Output in fp32 to allow subsequent ops to happen in-place
-                # Making sure the dummy tensor is on the same device as the weight
-                global TORCH_DEVICE_IDENTITY
-                if TORCH_DEVICE_IDENTITY is None:
-                    TORCH_DEVICE_IDENTITY = torch.ones(
-                        1, dtype=torch.float32, device=weight.device
-                    )
-                output = torch._scaled_mm(
-                    qinput,
-                    weight,
-                    scale_a=TORCH_DEVICE_IDENTITY,
-                    scale_b=TORCH_DEVICE_IDENTITY,
-                    out_dtype=torch.float32,
-                )
-                # A fix for discrepancy in scaled_mm which returns tuple
-                # for torch < 2.5 and a single value in torch >= 2.5
-                if type(output) is tuple and len(output) == 2:
-                    output = output[0]
-                # Unpad (undo num_token_padding)
-                output = torch.narrow(output, 0, 0, input_2d.shape[0])
-                x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
-
-                # DQ
-                # C = sw * sx * (X * W) + bias
-                output = output * x_scale * weight_scale.t()
-                if bias is not None:
-                    output = output + bias
-                return output.to(dtype=input.dtype).view(*output_shape)
+            return _apply_fallback_scaled_mm(
+                qinput,
+                weight,
+                x_scale,
+                weight_scale,
+                input_2d.shape,
+                output_shape,
+                bias,
+                input.dtype,
+            )
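Both branches of the restructured apply_fp8_linear now funnel their channelwise fallback through _apply_fallback_scaled_mm, which runs the GEMM with identity scales and applies the per-token and per-channel scales afterwards. A rough sketch of that dequantization arithmetic with made-up shapes, using plain fp32 tensors in place of the fp8 operands and torch._scaled_mm:

    # Sketch only: illustrates C = x_scale * w_scale^T * (X_q @ W_q) + bias on CPU.
    import torch

    M, K, N = 4, 8, 16
    x_q = torch.randn(M, K)      # stands in for the fp8-quantized activations
    w_q = torch.randn(K, N)      # stands in for the fp8-quantized weight
    x_scale = torch.rand(M, 1)   # per-token activation scales
    w_scale = torch.rand(N, 1)   # per-channel weight scales (one per output column)
    bias = torch.randn(N)

    # GEMM on quantized operands, then unfused dequantization:
    out = (x_q @ w_q) * x_scale * w_scale.t() + bias
    assert out.shape == (M, N)

The `weight_scale.t()` in the helper plays the same role as `w_scale.t()` here: a column of per-output-channel scales transposed so it broadcasts across the rows of the product.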
models/test_compressed_tensors_models.py (new file)
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestCompressedTensorsLlama3FP8(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "RedHatAI/Meta-Llama-3.1-8B-FP8"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )

        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.45)


if __name__ == "__main__":
    unittest.main()
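The same GSM8K check can be run ad hoc against a server you launch yourself for RedHatAI/Meta-Llama-3.1-8B-FP8 (for example via sglang's server launch entry point). A sketch, where port 30000 is an assumed example value rather than something taken from the test:

    # Ad-hoc variant of test_gsm8k above, pointed at an already-running server.
    from types import SimpleNamespace

    from sglang.test.few_shot_gsm8k import run_eval

    args = SimpleNamespace(
        num_shots=5,
        data_path=None,
        num_questions=200,
        max_new_tokens=512,
        parallel=128,
        host="http://127.0.0.1",
        port=30000,  # assumed port of the manually launched server
    )

    metrics = run_eval(args)
    print(f"{metrics=}")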
@@ -20,6 +20,7 @@ suites = {
         TestFile("models/test_generation_models.py", 103),
         TestFile("models/test_grok_models.py", 60),
         TestFile("models/test_qwen_models.py", 82),
+        TestFile("models/test_compressed_tensors_models.py", 100),
         TestFile("models/test_reward_models.py", 83),
         TestFile("models/test_gme_qwen_models.py", 45),
         TestFile("models/test_clip_models.py", 100),