Commit 8510c10c authored by lixh's avatar lixh
Browse files

feat: implement FP8 blockwise GEMM with hipblaslt

parent 45a060d6
......@@ -1837,7 +1837,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_W8A8_BACKEND": lambda: int(os.getenv("VLLM_W8A8_BACKEND", "3")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -1804,6 +1804,7 @@ def fused_experts(
expert_map: torch.Tensor | None = None,
quant_config: FusedMoEQuantConfig | None = None,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
......
......@@ -1001,6 +1001,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
assert not self.is_monolithic
......
......@@ -47,6 +47,8 @@ from vllm.utils.flashinfer import (
should_use_flashinfer_for_blockscale_fp8_gemm,
)
from vllm.utils.torch_utils import direct_register_custom_op
from lmslim import quant_ops
from lmslim.quantize.quant_ops import BlockSize
logger = init_logger(__name__)
......@@ -357,6 +359,7 @@ class W8A8BlockFp8LinearOp:
act_quant_group_shape: GroupShape,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
use_blaslt: bool = False,
):
self.weight_group_shape = weight_group_shape
self.act_quant_group_shape = act_quant_group_shape
......@@ -364,14 +367,13 @@ class W8A8BlockFp8LinearOp:
self.is_hopper = current_platform.is_device_capability(90)
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported()
# Get the correct blockscale mul and input quant operations.
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
# to use deepgemm because we don't know the shape of weights (and
# whether deepgemm supports it) at the init time.
self.w8a8_blockscale_op, self.input_quant_op = (
self._dispatch_w8a8_blockscale_op(
cutlass_block_fp8_supported, use_aiter_and_is_supported
cutlass_block_fp8_supported, use_aiter_and_is_supported, use_blaslt
)
)
self.deepgemm_input_quant_op = (
......@@ -397,8 +399,14 @@ class W8A8BlockFp8LinearOp:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_shape = []
output_dtype = input.dtype
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
out_features = int(weight.shape[-1])
else:
output_shape = [*input.shape[:-1], weight.shape[0]]
out_features = int(weight.shape[0])
if should_use_flashinfer_for_blockscale_fp8_gemm(
self.is_flashinfer_supported, output_dtype, input_2d, weight
......@@ -413,7 +421,7 @@ class W8A8BlockFp8LinearOp:
output = self._run_deepgemm(input_2d, weight, weight_scale)
else:
output = self.w8a8_blockscale_op(
input_2d, weight, weight_scale, input_scale
out_features, input_2d, weight, weight_scale, input_scale
)
if bias is not None:
......@@ -535,6 +543,37 @@ class W8A8BlockFp8LinearOp:
input_2d.dtype,
)
def _run_hipblaslt_blockwise(
self,
out_features: int,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
m, k = input_2d.shape
n = out_features
if input_scale is None:
q_input, input_scale = self.input_quant_op(input_2d)
else:
q_input = input_2d
enum_block_size = BlockSize.block_128x128
if hasattr(self, "block_size") and self.block_size[0] == 64:
enum_block_size = BlockSize.block_64x64
output = hipblaslt_w8a8_block_fp8_matmul(
A=q_input,
B=weight,
As=input_scale,
Bs=weight_scale,
block_size=enum_block_size,
output_dtype=torch.bfloat16,
)
return output
def _run_flashinfer(
self,
input_2d: torch.Tensor,
......@@ -562,6 +601,7 @@ class W8A8BlockFp8LinearOp:
self,
use_cutlass: bool,
use_aiter_and_is_supported: bool,
use_blaslt: bool,
) -> tuple[
Callable[
[
......@@ -585,6 +625,16 @@ class W8A8BlockFp8LinearOp:
)
if use_aiter_and_is_supported:
return self._run_aiter, None
if envs.VLLM_W8A8_BACKEND == 3 or use_blaslt:
return (
self._run_hipblaslt_blockwise,
QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=False,
use_ue8m0=False,
),
)
return self._run_triton, (
QuantFP8(
False,
......@@ -1179,6 +1229,19 @@ def get_w8a8_block_fp8_configs(
)
return None
def hipblaslt_w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: BlockSize,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
assert A.shape[1] == B.shape[0]
m, k = A.shape
_, n = B.shape
_, d = quant_ops.hipblaslt_w8a8_blockwise_gemm(A, B, As, Bs, m, n, k, 'NN', output_dtype, block_size, None)
return d
def w8a8_triton_block_scaled_mm(
A: torch.Tensor,
......@@ -1597,6 +1660,10 @@ def process_fp8_weight_block_strategy(
weight=weight, weight_scale=weight_scale
)
if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.T.contiguous()
weight_scale = weight_scale.T.contiguous()
else:
weight = _maybe_pad_fp8_weight(weight)
return weight, weight_scale
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment