"vscode:/vscode.git/clone" did not exist on "2bcf71b9c0305e2a3d645e7a5ced4460262b4a6c"
Unverified Commit 8d75fe48 authored by Tyler Michael Smith's avatar Tyler Michael Smith Committed by GitHub
Browse files

[Kernel] Switch fp8 layers to use the CUTLASS kernels (#5183)

Switching from torch._scaled_mm to vLLM's cutlass fp8 kernels when supported as we are seeing 5-15% improvement in e2e performance on neuralmagic/Meta-Llama-3-8B-Instruct-FP8

see https://docs.google.com/spreadsheets/d/1GiAnmzyGHgZ6zL_LDSTm35Bdrt4A8AaFEurDlISYYA4/ for some quick e2e benchmarks and #5144 for comparisons across different GEMM sizes.
parent 388596c9
...@@ -179,7 +179,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -179,7 +179,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# cutlass # cutlass
def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
a_scales: torch.Tensor, b_scales: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor,
out_dtype: Type[torch.dtype]) -> torch.Tensor: out_dtype: Type[torch.dtype]) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
...@@ -188,7 +188,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, ...@@ -188,7 +188,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
n = b.shape[1] n = b.shape[1]
out = torch.empty((m, n), dtype=out_dtype, device=a.device) out = torch.empty((m, n), dtype=out_dtype, device=a.device)
vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales) vllm_ops.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
return out return out
......
...@@ -17,6 +17,24 @@ ACTIVATION_SCHEMES = ["static", "dynamic"] ...@@ -17,6 +17,24 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__) logger = init_logger(__name__)
def cutlass_fp8_supported() -> bool:
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
version = torch.version.cuda
version = version[0] * 10 + version[1]
# CUTLASS FP8 kernels need at least
# CUDA 12.0 on SM90 systems (Hopper)
# CUDA 12.4 on SM89 systems (Lovelace)
gpu_is_supported = False
if capability >= 900:
gpu_is_supported = version > 120
elif capability >= 890:
gpu_is_supported = version > 124
return gpu_is_supported
class Fp8Config(QuantizationConfig): class Fp8Config(QuantizationConfig):
"""Config class for FP8.""" """Config class for FP8."""
...@@ -92,6 +110,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -92,6 +110,7 @@ class Fp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: Fp8Config): def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()
def _create_scale_param( def _create_scale_param(
self, self,
...@@ -233,9 +252,24 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -233,9 +252,24 @@ class Fp8LinearMethod(LinearMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant. # ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.act_scale is None and x_scale computed from x. # If dynamic, layer.act_scale is None and x_scale computed from x.
# If static, layer.act_scale is scalar and x_scale set to act_scale. # If static, layer.act_scale is scalar and x_scale set to act_scale.
if bias is None and self.cutlass_fp8_supported:
qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm_dq(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
)
else:
qinput, x_scale = ops.scaled_fp8_quant(x, qinput, x_scale = ops.scaled_fp8_quant(x,
layer.act_scale, layer.act_scale,
batch_dim_padding=17) batch_dim_padding=17)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment