Commit 96590097 authored by lixh6, committed by wanglong3
Browse files

feat: support fp8-blockwise matmul impl.

parent 43155293
...@@ -26,6 +26,8 @@ from vllm.triton_utils import tl, triton ...@@ -26,6 +26,8 @@ from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import (fp8_gemm_nt, is_deep_gemm_e8m0_used, from vllm.utils.deep_gemm import (fp8_gemm_nt, is_deep_gemm_e8m0_used,
should_use_deepgemm_for_fp8_linear) should_use_deepgemm_for_fp8_linear)
from lmslim import quant_ops
from lmslim.quantize.quant_ops import BlockSize
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -98,7 +100,7 @@ if current_platform.is_rocm(): ...@@ -98,7 +100,7 @@ if current_platform.is_rocm():
def dispatch_w8a8_blockscale_func( def dispatch_w8a8_blockscale_func(
use_cutlass: bool, use_aiter_and_is_supported: bool use_cutlass: bool, use_aiter_and_is_supported: bool, use_blaslt: bool
) -> Callable[[ ) -> Callable[[
torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
...@@ -111,6 +113,8 @@ def dispatch_w8a8_blockscale_func( ...@@ -111,6 +113,8 @@ def dispatch_w8a8_blockscale_func(
return cutlass_scaled_mm return cutlass_scaled_mm
if (use_aiter_and_is_supported): if (use_aiter_and_is_supported):
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
if use_blaslt:
return hipblaslt_w8a8_block_fp8_matmul
return w8a8_block_fp8_matmul return w8a8_block_fp8_matmul
...@@ -129,14 +133,17 @@ def apply_w8a8_block_fp8_linear( ...@@ -129,14 +133,17 @@ def apply_w8a8_block_fp8_linear(
assert input_scale is None assert input_scale is None
# View input as 2D matrix for fp8 methods # View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1]) input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]] output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype output_dtype = input.dtype
if should_use_deepgemm_for_fp8_linear(output_dtype, weight): if should_use_deepgemm_for_fp8_linear(output_dtype, weight):
input_2d = input.view(-1, input.shape[-1]) input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]] output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = per_token_group_quant_fp8( q_input, x_scale = per_token_group_quant_fp8(
input_2d, input_2d,
block_size[1], block_size[1],
...@@ -149,9 +156,8 @@ def apply_w8a8_block_fp8_linear( ...@@ -149,9 +156,8 @@ def apply_w8a8_block_fp8_linear(
if bias is not None: if bias is not None:
output += bias output += bias
return output.to(dtype=output_dtype).view(*output_shape) return output.to(dtype=output_dtype).view(*output_shape)
w8a8_blockscale_func = dispatch_w8a8_blockscale_func( w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
cutlass_block_fp8_supported, use_aiter_and_is_supported) cutlass_block_fp8_supported, use_aiter_and_is_supported, envs.VLLM_W8A8_BACKEND == 3)
if cutlass_block_fp8_supported: if cutlass_block_fp8_supported:
num_pad = 0 num_pad = 0
if current_platform.is_device_capability(90): if current_platform.is_device_capability(90):
...@@ -195,7 +201,11 @@ def apply_w8a8_block_fp8_linear_fake( ...@@ -195,7 +201,11 @@ def apply_w8a8_block_fp8_linear_fake(
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False, use_aiter_and_is_supported: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
output_shape = [*input.shape[:-1], weight.shape[0]] output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]]
return torch.empty(output_shape, dtype=input.dtype, device=input.device) return torch.empty(output_shape, dtype=input.dtype, device=input.device)
...@@ -581,6 +591,26 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, ...@@ -581,6 +591,26 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
) )
return None return None
def hipblaslt_w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Run a block-wise scaled W8A8 matmul via lmslim's hipBLASLt GEMM.

    Args:
        A: 2-D left operand; its shape is unpacked as ``(m, k)``.
        B: 2-D right operand; its first dimension must equal ``k`` and its
            second dimension is taken as ``n``.
        As: Scales forwarded together with ``A``.
            NOTE(review): presumably per-token-group activation scales;
            confirm against the lmslim kernel contract.
        Bs: Scales forwarded together with ``B``.
            NOTE(review): presumably per-block weight scales; confirm.
        block_size: Quantization block shape. Only ``block_size[0]`` is
            inspected: 64 selects ``BlockSize.block_64x64``, anything else
            selects ``BlockSize.block_128x128``.
        output_dtype: Dtype forwarded to the GEMM for its result.

    Returns:
        The second element of the tuple returned by
        ``quant_ops.hipblaslt_w8a8_blockwise_gemm``.

    Raises:
        AssertionError: If the inner dimensions of ``A`` and ``B`` differ.
    """
    assert A.shape[1] == B.shape[0], (
        f"Inner dimensions must match: A is {tuple(A.shape)}, "
        f"B is {tuple(B.shape)}")
    m, k = A.shape
    _, n = B.shape

    if block_size[0] == 64:
        enum_block_size = BlockSize.block_64x64
    else:
        if block_size[0] != 128:
            # Route the fallback notice through the module logger rather than
            # stdout so it honours the configured log level and handlers.
            logger.warning(
                "Unsupported block_size: %s. "
                "Falling back to BlockSize.block_128x128", block_size)
        enum_block_size = BlockSize.block_128x128

    # 'NN' is passed through as the layout argument and the trailing None
    # fills the last positional slot.
    # NOTE(review): 'NN' presumably means neither operand is transposed;
    # confirm against the lmslim API.
    _, d = quant_ops.hipblaslt_w8a8_blockwise_gemm(
        A, B, As, Bs, m, n, k, 'NN', output_dtype, enum_block_size, None)
    return d
def w8a8_block_fp8_matmul( def w8a8_block_fp8_matmul(
A: torch.Tensor, A: torch.Tensor,
...@@ -898,7 +928,11 @@ def process_fp8_weight_block_strategy( ...@@ -898,7 +928,11 @@ def process_fp8_weight_block_strategy(
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale) weight=weight, weight_scale=weight_scale)
weight = _maybe_pad_fp8_weight(weight) if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.T.contiguous()
weight_scale = weight_scale.T.contiguous()
else:
weight = _maybe_pad_fp8_weight(weight)
return weight, weight_scale return weight, weight_scale
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment