Commit 96590097, authored by lixh6 and committed by wanglong3
Browse files

feat: support fp8-blockwise matmul impl.

parent 43155293
......@@ -26,6 +26,8 @@ from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import (fp8_gemm_nt, is_deep_gemm_e8m0_used,
should_use_deepgemm_for_fp8_linear)
from lmslim import quant_ops
from lmslim.quantize.quant_ops import BlockSize
logger = init_logger(__name__)
......@@ -98,7 +100,7 @@ if current_platform.is_rocm():
def dispatch_w8a8_blockscale_func(
use_cutlass: bool, use_aiter_and_is_supported: bool
use_cutlass: bool, use_aiter_and_is_supported: bool, use_blaslt: bool
) -> Callable[[
torch.Tensor,
torch.Tensor,
......@@ -111,6 +113,8 @@ def dispatch_w8a8_blockscale_func(
return cutlass_scaled_mm
if (use_aiter_and_is_supported):
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
if use_blaslt:
return hipblaslt_w8a8_block_fp8_matmul
return w8a8_block_fp8_matmul
......@@ -129,14 +133,17 @@ def apply_w8a8_block_fp8_linear(
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype
if should_use_deepgemm_for_fp8_linear(output_dtype, weight):
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = per_token_group_quant_fp8(
input_2d,
block_size[1],
......@@ -149,9 +156,8 @@ def apply_w8a8_block_fp8_linear(
if bias is not None:
output += bias
return output.to(dtype=output_dtype).view(*output_shape)
w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
cutlass_block_fp8_supported, use_aiter_and_is_supported)
cutlass_block_fp8_supported, use_aiter_and_is_supported, envs.VLLM_W8A8_BACKEND == 3)
if cutlass_block_fp8_supported:
num_pad = 0
if current_platform.is_device_capability(90):
......@@ -195,7 +201,11 @@ def apply_w8a8_block_fp8_linear_fake(
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
output_shape = [*input.shape[:-1], weight.shape[0]]
output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]]
return torch.empty(output_shape, dtype=input.dtype, device=input.device)
......@@ -581,6 +591,26 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
)
return None
def hipblaslt_w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Run a block-wise scaled matmul through lmslim's hipBLASLt GEMM.

    Args:
        A: 2D left operand of shape ``(m, k)``.
        B: 2D right operand of shape ``(k, n)``; its first dimension must
            equal ``A.shape[1]``.
        As: Scale tensor forwarded alongside ``A`` (presumably per-block
            activation scales -- confirm against the lmslim kernel).
        Bs: Scale tensor forwarded alongside ``B`` (presumably per-block
            weight scales -- confirm against the lmslim kernel).
        block_size: Quantization block size. Only ``block_size[0]`` is
            inspected: 64 selects ``BlockSize.block_64x64``, 128 selects
            ``BlockSize.block_128x128``; any other value logs a warning
            and falls back to ``BlockSize.block_128x128``.
        output_dtype: dtype forwarded to the kernel for the result.

    Returns:
        The second element of the tuple returned by
        ``quant_ops.hipblaslt_w8a8_blockwise_gemm``.
    """
    assert A.shape[1] == B.shape[0], (
        f"Inner dimensions must match: A is {tuple(A.shape)}, "
        f"B is {tuple(B.shape)}")
    m, k = A.shape
    _, n = B.shape

    if block_size[0] == 64:
        enum_block_size = BlockSize.block_64x64
    else:
        # 128x128 is both the supported default and the fallback choice.
        enum_block_size = BlockSize.block_128x128
        if block_size[0] != 128:
            # Keep the best-effort fallback, but report it through the
            # module logger instead of a bare print.
            logger.warning(
                "Unsupported block_size: %s. "
                "Falling back to BlockSize.block_128x128", block_size)

    # 'NN' is the layout flag expected by the lmslim kernel (presumably
    # neither operand transposed -- confirm); the trailing None is passed
    # through unchanged from the original call.
    _, d = quant_ops.hipblaslt_w8a8_blockwise_gemm(
        A, B, As, Bs, m, n, k, 'NN', output_dtype, enum_block_size, None)
    return d
def w8a8_block_fp8_matmul(
A: torch.Tensor,
......@@ -898,7 +928,11 @@ def process_fp8_weight_block_strategy(
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale)
weight = _maybe_pad_fp8_weight(weight)
if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.T.contiguous()
weight_scale = weight_scale.T.contiguous()
else:
weight = _maybe_pad_fp8_weight(weight)
return weight, weight_scale
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment