Unverified Commit 485a023b authored by ChangyiYang, committed by GitHub

refactor apply_w8a8_block_fp8_linear in fp (#6545)

parent 7e412900
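
The heart of the change: instead of apply_w8a8_block_fp8_linear branching over the flashinfer/cutlass/aiter/deepgemm/triton paths on every forward call, a new dispatch_w8a8_block_fp8_linear() resolves the backend once when Fp8LinearMethod is constructed, and apply simply invokes the cached callable. Below is a minimal sketch of that dispatch-once pattern; the flags and the single placeholder backend are hypothetical stand-ins for illustration, not the real sglang kernels.

```python
# Minimal sketch of the dispatch-once pattern this commit introduces.
# The flags and the placeholder backend are hypothetical stand-ins,
# not the real sglang kernels.
from typing import Callable, List, Optional

import torch

ENABLE_FLASHINFER_GEMM = False        # stand-in feature flags
CUTLASS_BLOCK_FP8_SUPPORTED = False
ENABLE_JIT_DEEPGEMM = False


def triton_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Placeholder backend: a plain matmul standing in for the quantized kernel;
    # the quantization arguments are accepted but ignored in this sketch.
    out = input @ weight.t()
    return out + bias if bias is not None else out


def dispatch_w8a8_block_fp8_linear() -> Callable:
    # Resolve the backend once, up front, instead of branching per forward call.
    if ENABLE_FLASHINFER_GEMM or CUTLASS_BLOCK_FP8_SUPPORTED or ENABLE_JIT_DEEPGEMM:
        raise NotImplementedError("other backends omitted from this sketch")
    return triton_w8a8_block_fp8_linear


class Fp8LinearMethodSketch:
    def __init__(self) -> None:
        # Mirrors Fp8LinearMethod.__init__ caching the dispatched callable.
        self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()

    def apply(self, x, weight, weight_scale):
        # Mirrors the new apply path: call the pre-selected backend.
        return self.w8a8_block_fp8_linear(
            input=x,
            weight=weight,
            block_size=[128, 128],
            weight_scale=weight_scale,
        )
```
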
......@@ -10,7 +10,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
)
from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
from sglang.srt.layers.quantization.fp8_kernel import (
w8a8_block_fp8_matmul_deepgemm as w8a8_block_fp8_matmul,
)
# Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
......
......@@ -49,8 +49,8 @@ from sglang.srt.layers.quantization.fp8_kernel import (
)
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
apply_w8a8_block_fp8_linear,
cutlass_fp8_supported,
dispatch_w8a8_block_fp8_linear,
input_to_float8,
is_sm100_supported,
normalize_e4m3fn_to_e4m3fnuz,
......@@ -209,6 +209,8 @@ class Fp8LinearMethod(LinearMethodBase):
# Marlin doesn't support block-wise fp8
self.use_marlin = False
self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
def create_weights(
self,
layer: torch.nn.Module,
......@@ -417,7 +419,7 @@ class Fp8LinearMethod(LinearMethodBase):
)
if self.block_quant:
return apply_w8a8_block_fp8_linear(
return self.w8a8_block_fp8_linear(
input=x,
weight=layer.weight,
block_size=self.quant_config.weight_block_size,
......
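
The fp8.py hunk above works because every backend that dispatch_w8a8_block_fp8_linear() can return accepts the same keyword arguments. Here is a sketch of that implied contract expressed as a typing.Protocol; the protocol name W8A8BlockFp8Linear is my own label and does not exist in sglang.

```python
# Sketch of the call contract shared by every dispatched backend
# (flashinfer, cutlass, aiter, deepgemm, triton). The Protocol name
# "W8A8BlockFp8Linear" is illustrative only; it does not exist in sglang.
from typing import List, Optional, Protocol

import torch


class W8A8BlockFp8Linear(Protocol):
    def __call__(
        self,
        input: torch.Tensor,                         # activations, any leading dims
        weight: torch.Tensor,                        # fp8 weight, shape (N, K)
        block_size: List[int],                       # [block_n, block_k], e.g. [128, 128]
        weight_scale: torch.Tensor,                  # per-block weight scales
        input_scale: Optional[torch.Tensor] = None,  # must be None for these paths
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor: ...
```
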
......@@ -740,7 +740,59 @@ if _is_hip:
return _w8a8_block_fp8_matmul
def w8a8_block_fp8_matmul(
def prepare_block_fp8_matmul_inputs(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype = torch.float16,
) -> Tuple[int, int, int, torch.Tensor]:
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1]
assert A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2
assert B.is_contiguous()
assert Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
return M, N, K, C
def w8a8_block_fp8_matmul_deepgemm(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
# DeepGEMM only supports bfloat16 output tensors
assert C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM
if supports_custom_op():
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
else:
deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
return C
def w8a8_block_fp8_matmul_triton(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
......@@ -764,81 +816,81 @@ def w8a8_block_fp8_matmul(
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
block_n, block_k = block_size
# deepgemm only support bf16
if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
if supports_custom_op():
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
else:
deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"])
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)
kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)
kernel[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
kernel[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
return C
# Universal entry point, kept for testing purposes
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
if output_dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
return w8a8_block_fp8_matmul_deepgemm(
A, B, As, Bs, block_size, output_dtype=output_dtype
)
return w8a8_block_fp8_matmul_triton(
A, B, As, Bs, block_size, output_dtype=output_dtype
)
@triton.jit
def _per_tensor_quant_mla_fp8_stage1(
x_ptr,
......
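
prepare_block_fp8_matmul_inputs now centralizes the shape checks that relate the fp8 operands to their per-block scales: As holds one scale per block_k-wide slice of each row of A, and Bs holds one scale per block_n x block_k tile of B. The following is a small sketch of building consistently shaped inputs, using a local cdiv helper so it runs without Triton; the [128, 128] block size is illustrative.

```python
# Sketch: operand and scale shapes that satisfy the assertions in
# prepare_block_fp8_matmul_inputs. Runs on CPU; needs a PyTorch build
# with float8 dtypes (>= 2.1) but no Triton.
import torch


def cdiv(a: int, b: int) -> int:
    # Ceiling division, same result as triton.cdiv for positive ints.
    return -(a // -b)


M, N, K = 16, 512, 7168
block_n, block_k = 128, 128                          # block_size = [block_n, block_k]

A = torch.randn(M, K).to(torch.float8_e4m3fn)        # fp8 activations, row-major
B = torch.randn(N, K).to(torch.float8_e4m3fn)        # fp8 weight, shape (N, K)
As = torch.rand(M, cdiv(K, block_k))                 # one scale per block_k slice of each row
Bs = torch.rand(cdiv(N, block_n), cdiv(K, block_k))  # one scale per block_n x block_k tile

# Mirror the checks the helper performs before allocating C with shape (M, N).
assert A.shape[-1] == B.shape[-1]
assert cdiv(A.shape[-1], block_k) == As.shape[-1]
assert cdiv(N, block_n) == Bs.shape[0] and cdiv(K, block_k) == Bs.shape[1]
```
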
import os
from typing import List, Optional, Tuple
from typing import Callable, List, Optional, Tuple
import torch
......@@ -21,7 +22,8 @@ from sglang.srt.layers.quantization.fp8_kernel import (
scaled_fp8_quant,
sglang_per_token_quant_fp8,
static_quant_fp8,
w8a8_block_fp8_matmul,
w8a8_block_fp8_matmul_deepgemm,
w8a8_block_fp8_matmul_triton,
)
from sglang.srt.utils import (
get_bool_env_var,
......@@ -134,7 +136,20 @@ if ENABLE_FLASHINFER_GEMM:
from flashinfer.gemm import gemm_fp8_nt_groupwise
def apply_w8a8_block_fp8_linear(
def dispatch_w8a8_block_fp8_linear() -> Callable:
if ENABLE_FLASHINFER_GEMM:
return flashinfer_gemm_w8a8_block_fp8_linear
elif CUTLASS_BLOCK_FP8_SUPPORTED:
return cutlass_w8a8_block_fp8_linear_with_fallback
elif _is_hip and use_aiter_moe:
return aiter_w8a8_block_fp8_linear
elif _ENABLE_JIT_DEEPGEMM:
return deepgemm_w8a8_block_fp8_linear_with_fallback
else:
return triton_w8a8_block_fp8_linear
def flashinfer_gemm_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
......@@ -143,58 +158,148 @@ def apply_w8a8_block_fp8_linear(
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
# TODO: add more robust shape check here
shape_supported_by_cutlass = (
weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
q_input, x_scale = sglang_per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
if ENABLE_FLASHINFER_GEMM:
q_input, x_scale = sglang_per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
x_scale_input = x_scale.T.contiguous()
weight_scale_input = weight_scale.T.contiguous()
output = gemm_fp8_nt_groupwise(
q_input, weight, x_scale_input, weight_scale_input, out_dtype=input.dtype
)
elif CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=True
)
output = fp8_blockwise_scaled_mm(
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
)
elif _is_hip and use_aiter_moe:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
output = torch.zeros(
[q_input.shape[0], weight.shape[0]],
dtype=input.dtype,
device=q_input.device,
x_scale_input = x_scale.T.contiguous()
weight_scale_input = weight_scale.T.contiguous()
output = gemm_fp8_nt_groupwise(
q_input, weight, x_scale_input, weight_scale_input, out_dtype=input_2d.dtype
)
if bias is not None:
output += bias
return output.to(dtype=input_2d.dtype).view(*output_shape)
def cutlass_w8a8_block_fp8_linear_with_fallback(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
# TODO: add more robust shape check here
shape_supported = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
if not shape_supported:
# fallback to triton
return triton_w8a8_block_fp8_linear(
input, weight, block_size, weight_scale, input_scale, bias
)
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
else:
if _ENABLE_JIT_DEEPGEMM:
q_input, x_scale = sglang_per_token_group_quant_fp8(
input_2d,
block_size[1],
column_major_scales=True,
scale_tma_aligned=True,
)
else:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
output = w8a8_block_fp8_matmul(
q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=True
)
output = fp8_blockwise_scaled_mm(
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input_2d.dtype
)
if bias is not None:
output += bias
return output.to(dtype=input_2d.dtype).view(*output_shape)
def deepgemm_w8a8_block_fp8_linear_with_fallback(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
output_dtype = input.dtype
dtype_supported = output_dtype == torch.bfloat16
# TODO: add more robust shape check here
shape_supported = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
if not (shape_supported and dtype_supported):
# fall back to triton
return triton_w8a8_block_fp8_linear(
input, weight, block_size, weight_scale, input_scale, bias
)
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = sglang_per_token_group_quant_fp8(
input_2d,
block_size[1],
column_major_scales=True,
scale_tma_aligned=True,
)
output = w8a8_block_fp8_matmul_deepgemm(
q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
output += bias
return output.to(dtype=output_dtype).view(*output_shape)
def aiter_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
output = torch.zeros(
[q_input.shape[0], weight.shape[0]],
dtype=input_2d.dtype,
device=q_input.device,
)
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
if bias is not None:
output += bias
return output.to(dtype=input_2d.dtype).view(*output_shape)
def triton_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
output = w8a8_block_fp8_matmul_triton(
q_input, weight, x_scale, weight_scale, block_size, output_dtype=input_2d.dtype
)
if bias is not None:
output += bias
return output.to(dtype=input_2d.dtype).view(*output_shape)
def input_to_float8(
......
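
Both fallback wrappers in the hunk above gate on the same cheap support test before delegating to the Triton path: each weight dimension must be a multiple of 128, and the DeepGEMM wrapper additionally requires a bfloat16 output dtype. A compact sketch of that decision follows; the helper name block_fp8_backend_supported is hypothetical and not part of the diff.

```python
# Sketch of the shared fallback check used by the cutlass and deepgemm
# wrappers before they delegate to triton_w8a8_block_fp8_linear. The helper
# name block_fp8_backend_supported is illustrative; it does not exist in sglang.
import torch


def block_fp8_backend_supported(
    weight: torch.Tensor,
    output_dtype: torch.dtype,
    require_bf16: bool,
) -> bool:
    # Same TODO applies as in the diff: a more robust shape check is pending.
    shape_supported = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
    dtype_supported = output_dtype == torch.bfloat16 if require_bf16 else True
    return shape_supported and dtype_supported


# A (511, 7168) weight fails the multiple-of-128 check, so both wrappers
# would fall back to the Triton path for it (the dtype is irrelevant here).
w = torch.empty(511, 7168)
print(block_fp8_backend_supported(w, torch.bfloat16, require_bf16=True))    # False
print(block_fp8_backend_supported(w, torch.float16, require_bf16=False))    # False
```
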
......@@ -9,7 +9,9 @@ from deep_gemm import get_col_major_tma_aligned_tensor
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
from sglang.srt.layers.quantization.fp8_kernel import (
w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
)
def get_weight_shapes(args):
......
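
The benchmark hunk above only re-aliases the Triton kernel; code that wants the automatic DeepGEMM-versus-Triton routing can instead call the universal w8a8_block_fp8_matmul entry point kept in fp8_kernel.py. A usage sketch, assuming a CUDA device, a float8-capable PyTorch build (>= 2.1), and sglang importable; the shapes and the [128, 128] block size are illustrative.

```python
# Usage sketch for the universal w8a8_block_fp8_matmul entry point.
# Assumes CUDA, a PyTorch build with float8 dtypes, and sglang on the path;
# shapes and the [128, 128] block size are illustrative only.
import torch

from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul

M, N, K = 64, 512, 7168
block_n, block_k = 128, 128

A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)   # fp8 activations (M, K)
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)   # fp8 weight (N, K)
As = torch.rand(M, (K + block_k - 1) // block_k,
                device="cuda", dtype=torch.float32)            # per-group activation scales
Bs = torch.rand((N + block_n - 1) // block_n,
                (K + block_k - 1) // block_k,
                device="cuda", dtype=torch.float32)            # per-tile weight scales

# float16 output keeps this on the Triton kernel; a bfloat16 output with JIT
# DeepGEMM enabled would route to w8a8_block_fp8_matmul_deepgemm instead.
C = w8a8_block_fp8_matmul(A, B, As, Bs, [block_n, block_k],
                          output_dtype=torch.float16)
print(C.shape)  # torch.Size([64, 512])
```
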