Unverified commit 485a023b authored by ChangyiYang, committed by GitHub

refactor apply_w8a8_block_fp8_linear in fp (#6545)

parent 7e412900
@@ -10,7 +10,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
 )
-from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
+from sglang.srt.layers.quantization.fp8_kernel import (
+    w8a8_block_fp8_matmul_deepgemm as w8a8_block_fp8_matmul,
+)

 # Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
......
@@ -49,8 +49,8 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
-    apply_w8a8_block_fp8_linear,
     cutlass_fp8_supported,
+    dispatch_w8a8_block_fp8_linear,
     input_to_float8,
     is_sm100_supported,
     normalize_e4m3fn_to_e4m3fnuz,
@@ -209,6 +209,8 @@ class Fp8LinearMethod(LinearMethodBase):
             # Marlin doesn't support block-wise fp8
             self.use_marlin = False

+        self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -417,7 +419,7 @@ class Fp8LinearMethod(LinearMethodBase):
             )

         if self.block_quant:
-            return apply_w8a8_block_fp8_linear(
+            return self.w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
......
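A minimal sketch of the pattern this hunk introduces, assuming the fp8_utils module from this branch is importable: the backend is resolved once when the method object is constructed, and the returned callable is reused on every forward pass. The holder class below is a hypothetical stand-in for Fp8LinearMethod, not code from the commit.

from sglang.srt.layers.quantization.fp8_utils import dispatch_w8a8_block_fp8_linear


class _BlockFp8LinearHolder:  # hypothetical stand-in for Fp8LinearMethod
    def __init__(self):
        # Pick FlashInfer, CUTLASS, AITER, DeepGEMM, or Triton exactly once,
        # instead of re-checking environment flags on every apply() call.
        self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()

    def apply(self, x, weight, weight_scale, block_size, bias=None):
        # Keyword names follow the per-backend signatures added in this commit.
        return self.w8a8_block_fp8_linear(
            input=x,
            weight=weight,
            block_size=block_size,
            weight_scale=weight_scale,
            input_scale=None,
            bias=bias,
        )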
@@ -740,7 +740,59 @@ if _is_hip:
     return _w8a8_block_fp8_matmul


-def w8a8_block_fp8_matmul(
+def prepare_block_fp8_matmul_inputs(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> Tuple[int, int, int, torch.Tensor]:
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+
+    assert A.shape[-1] == B.shape[-1]
+    assert A.shape[:-1] == As.shape[:-1]
+    assert A.is_contiguous()
+    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+
+    M = A.numel() // A.shape[-1]
+
+    assert B.ndim == 2
+    assert B.is_contiguous()
+    assert Bs.ndim == 2
+    N, K = B.shape
+    assert triton.cdiv(N, block_n) == Bs.shape[0]
+    assert triton.cdiv(K, block_k) == Bs.shape[1]
+
+    C_shape = A.shape[:-1] + (N,)
+    C = A.new_empty(C_shape, dtype=output_dtype)
+
+    return M, N, K, C
+
+
+def w8a8_block_fp8_matmul_deepgemm(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype,
+) -> torch.Tensor:
+    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
+
+    # Deepgemm only supports output tensor type as bfloat16
+    assert C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM
+
+    if supports_custom_op():
+        torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
+    else:
+        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+
+    return C
+
+
+def w8a8_block_fp8_matmul_triton(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -764,81 +816,81 @@ def w8a8_block_fp8_matmul(
     Returns:
         torch.Tensor: The result of matmul.
     """
-    assert len(block_size) == 2
-    block_n, block_k = block_size[0], block_size[1]
-
-    assert A.shape[-1] == B.shape[-1]
-    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
-    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
-    M = A.numel() // A.shape[-1]
-
-    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
-    N, K = B.shape
-    assert triton.cdiv(N, block_n) == Bs.shape[0]
-    assert triton.cdiv(K, block_k) == Bs.shape[1]
-
-    C_shape = A.shape[:-1] + (N,)
-    C = A.new_empty(C_shape, dtype=output_dtype)
-
-    # deepgemm only support bf16
-    if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
-        if supports_custom_op():
-            torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
-        else:
-            deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
-    else:
-        configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Default config
-            # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
-            config = {
-                "BLOCK_SIZE_M": 64,
-                "BLOCK_SIZE_N": block_size[0],
-                "BLOCK_SIZE_K": block_size[1],
-                "GROUP_SIZE_M": 32,
-                "num_warps": 4,
-                "num_stages": 3,
-            }
-
-        def grid(META):
-            return (
-                triton.cdiv(M, META["BLOCK_SIZE_M"])
-                * triton.cdiv(N, META["BLOCK_SIZE_N"]),
-            )
-
-        kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)
-
-        kernel[grid](
-            A,
-            B,
-            C,
-            As,
-            Bs,
-            M,
-            N,
-            K,
-            block_n,
-            block_k,
-            A.stride(-2),
-            A.stride(-1),
-            B.stride(1),
-            B.stride(0),
-            C.stride(-2),
-            C.stride(-1),
-            As.stride(-2),
-            As.stride(-1),
-            Bs.stride(1),
-            Bs.stride(0),
-            **config,
-        )
+    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
+
+    block_n, block_k = block_size
+
+    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
+    if configs:
+        # If an optimal configuration map has been found, look up the
+        # optimal config
+        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+    else:
+        # Default config
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": block_size[0],
+            "BLOCK_SIZE_K": block_size[1],
+            "GROUP_SIZE_M": 32,
+            "num_warps": 4,
+            "num_stages": 3,
+        }
+
+    def grid(META):
+        return (
+            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        )
+
+    kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)
+
+    kernel[grid](
+        A,
+        B,
+        C,
+        As,
+        Bs,
+        M,
+        N,
+        K,
+        block_n,
+        block_k,
+        A.stride(-2),
+        A.stride(-1),
+        B.stride(1),
+        B.stride(0),
+        C.stride(-2),
+        C.stride(-1),
+        As.stride(-2),
+        As.stride(-1),
+        Bs.stride(1),
+        Bs.stride(0),
+        **config,
+    )

     return C


+# universal entry point, for testing purposes
+def w8a8_block_fp8_matmul(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    if output_dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
+        return w8a8_block_fp8_matmul_deepgemm(
+            A, B, As, Bs, block_size, output_dtype=output_dtype
+        )
+
+    return w8a8_block_fp8_matmul_triton(
+        A, B, As, Bs, block_size, output_dtype=output_dtype
+    )
+
+
 @triton.jit
 def _per_tensor_quant_mla_fp8_stage1(
     x_ptr,
......
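A hedged sketch of exercising the split kernels directly, assuming a CUDA device with float8_e4m3fn support and that the sglang fp8_kernel module imports cleanly; the shapes are chosen only to satisfy the assertions in prepare_block_fp8_matmul_inputs (one scale entry per 128-wide block) and are not taken from this diff.

import torch

from sglang.srt.layers.quantization.fp8_kernel import (
    w8a8_block_fp8_matmul,         # universal entry point added here
    w8a8_block_fp8_matmul_triton,  # always-available Triton path
)

M, N, K = 16, 512, 256
block_size = [128, 128]

A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
As = torch.rand(M, K // 128, device="cuda", dtype=torch.float32)         # per-token-group scales
Bs = torch.rand(N // 128, K // 128, device="cuda", dtype=torch.float32)  # per-block weight scales

# Routed to DeepGEMM when the output dtype is bfloat16 and JIT DeepGEMM is
# enabled, otherwise to the Triton kernel.
out = w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.bfloat16)

# Or pin the Triton implementation explicitly:
out_triton = w8a8_block_fp8_matmul_triton(
    A, B, As, Bs, block_size, output_dtype=torch.float16
)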
 import os
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple

 import torch
@@ -21,7 +22,8 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
     static_quant_fp8,
-    w8a8_block_fp8_matmul,
+    w8a8_block_fp8_matmul_deepgemm,
+    w8a8_block_fp8_matmul_triton,
 )
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -134,7 +136,20 @@ if ENABLE_FLASHINFER_GEMM:
     from flashinfer.gemm import gemm_fp8_nt_groupwise


-def apply_w8a8_block_fp8_linear(
+def dispatch_w8a8_block_fp8_linear() -> Callable:
+    if ENABLE_FLASHINFER_GEMM:
+        return flashinfer_gemm_w8a8_block_fp8_linear
+    elif CUTLASS_BLOCK_FP8_SUPPORTED:
+        return cutlass_w8a8_block_fp8_linear_with_fallback
+    elif _is_hip and use_aiter_moe:
+        return aiter_w8a8_block_fp8_linear
+    elif _ENABLE_JIT_DEEPGEMM:
+        return deepgemm_w8a8_block_fp8_linear_with_fallback
+    else:
+        return triton_w8a8_block_fp8_linear
+
+
+def flashinfer_gemm_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
     block_size: List[int],
@@ -143,58 +158,148 @@ def apply_w8a8_block_fp8_linear(
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     assert input_scale is None
-    # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
-    # TODO: add more robust shape check here
-    shape_supported_by_cutlass = (
-        weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
-    )
-    if ENABLE_FLASHINFER_GEMM:
-        q_input, x_scale = sglang_per_token_group_quant_fp8(
-            input_2d, block_size[1], column_major_scales=False
-        )
-        x_scale_input = x_scale.T.contiguous()
-        weight_scale_input = weight_scale.T.contiguous()
-        output = gemm_fp8_nt_groupwise(
-            q_input, weight, x_scale_input, weight_scale_input, out_dtype=input.dtype
-        )
-    elif CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
-        q_input, x_scale = per_token_group_quant_fp8(
-            input_2d, block_size[1], column_major_scales=True
-        )
-        output = fp8_blockwise_scaled_mm(
-            q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
-        )
-    elif _is_hip and use_aiter_moe:
-        q_input, x_scale = per_token_group_quant_fp8(
-            input_2d, block_size[1], column_major_scales=False
-        )
-        output = torch.zeros(
-            [q_input.shape[0], weight.shape[0]],
-            dtype=input.dtype,
-            device=q_input.device,
-        )
-        gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
-    else:
-        if _ENABLE_JIT_DEEPGEMM:
-            q_input, x_scale = sglang_per_token_group_quant_fp8(
-                input_2d,
-                block_size[1],
-                column_major_scales=True,
-                scale_tma_aligned=True,
-            )
-        else:
-            q_input, x_scale = per_token_group_quant_fp8(
-                input_2d, block_size[1], column_major_scales=False
-            )
-        output = w8a8_block_fp8_matmul(
-            q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
-        )
-    if bias is not None:
-        output = output + bias
-    return output.to(dtype=input.dtype).view(*output_shape)
+
+    q_input, x_scale = sglang_per_token_group_quant_fp8(
+        input_2d, block_size[1], column_major_scales=False
+    )
+
+    x_scale_input = x_scale.T.contiguous()
+    weight_scale_input = weight_scale.T.contiguous()
+
+    output = gemm_fp8_nt_groupwise(
+        q_input, weight, x_scale_input, weight_scale_input, out_dtype=input_2d.dtype
+    )
+
+    if bias is not None:
+        output += bias
+
+    return output.to(dtype=input_2d.dtype).view(*output_shape)
+
+
+def cutlass_w8a8_block_fp8_linear_with_fallback(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+
+    # TODO: add more robust shape check here
+    shape_supported = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
+    if not shape_supported:
+        # fallback to triton
+        return triton_w8a8_block_fp8_linear(
+            input, weight, block_size, weight_scale, input_scale, bias
+        )
+
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = per_token_group_quant_fp8(
+        input_2d, block_size[1], column_major_scales=True
+    )
+    output = fp8_blockwise_scaled_mm(
+        q_input, weight.T, x_scale, weight_scale.T, out_dtype=input_2d.dtype
+    )
+    if bias is not None:
+        output += bias
+    return output.to(dtype=input_2d.dtype).view(*output_shape)
+
+
+def deepgemm_w8a8_block_fp8_linear_with_fallback(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+
+    output_dtype = input.dtype
+    dtype_supported = output_dtype == torch.bfloat16
+
+    # TODO: add more robust shape check here
+    shape_supported = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
+
+    if not (shape_supported and dtype_supported):
+        # fall back to triton
+        return triton_w8a8_block_fp8_linear(
+            input, weight, block_size, weight_scale, input_scale, bias
+        )
+
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = sglang_per_token_group_quant_fp8(
+        input_2d,
+        block_size[1],
+        column_major_scales=True,
+        scale_tma_aligned=True,
+    )
+    output = w8a8_block_fp8_matmul_deepgemm(
+        q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
+    )
+    if bias is not None:
+        output += bias
+    return output.to(dtype=output_dtype).view(*output_shape)
+
+
+def aiter_w8a8_block_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = per_token_group_quant_fp8(
+        input_2d, block_size[1], column_major_scales=False
+    )
+    output = torch.zeros(
+        [q_input.shape[0], weight.shape[0]],
+        dtype=input_2d.dtype,
+        device=q_input.device,
+    )
+    gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
+    if bias is not None:
+        output += bias
+    return output.to(dtype=input_2d.dtype).view(*output_shape)
+
+
+def triton_w8a8_block_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = per_token_group_quant_fp8(
+        input_2d, block_size[1], column_major_scales=False
+    )
+    output = w8a8_block_fp8_matmul_triton(
+        q_input, weight, x_scale, weight_scale, block_size, output_dtype=input_2d.dtype
+    )
+    if bias is not None:
+        output += bias
+    return output.to(dtype=input_2d.dtype).view(*output_shape)


 def input_to_float8(
......
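For reference, a usage sketch of the new per-backend helpers, assuming a CUDA device; on a machine where FlashInfer, CUTLASS, AITER, and DeepGEMM are all unavailable, dispatch_w8a8_block_fp8_linear() resolves to triton_w8a8_block_fp8_linear, which can also be called directly as shown. Tensor shapes are illustrative placeholders.

import torch

from sglang.srt.layers.quantization.fp8_utils import (
    dispatch_w8a8_block_fp8_linear,
    triton_w8a8_block_fp8_linear,
)

block_size = [128, 128]
x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)           # activations
weight = torch.randn(512, 256, device="cuda").to(torch.float8_e4m3fn)  # fp8 weight, shape (N, K)
weight_scale = torch.rand(512 // 128, 256 // 128, device="cuda", dtype=torch.float32)

# What Fp8LinearMethod now does once at construction time:
linear_fn = dispatch_w8a8_block_fp8_linear()

# Equivalent direct call to the Triton fallback:
out = triton_w8a8_block_fp8_linear(
    input=x,
    weight=weight,
    block_size=block_size,
    weight_scale=weight_scale,
    input_scale=None,
    bias=None,
)
print(out.shape)  # torch.Size([4, 512])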
@@ -9,7 +9,9 @@ from deep_gemm import get_col_major_tma_aligned_tensor
 from sgl_kernel import fp8_blockwise_scaled_mm
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm

-from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
+from sglang.srt.layers.quantization.fp8_kernel import (
+    w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
+)


 def get_weight_shapes(args):
......