Commit 900f4720 authored by wanglong3
Browse files

feat: Support w8a8-fp8 GEMM backend.

parent 9d16d5aa
...@@ -331,6 +331,10 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -331,6 +331,10 @@ class Fp8LinearMethod(LinearMethodBase):
weight = layer.weight.data weight = layer.weight.data
weight_scale_inv = layer.weight_scale_inv.data weight_scale_inv = layer.weight_scale_inv.data
if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.T.contiguous()
weight_scale_inv = weight_scale_inv.T.contiguous()
else:
weight = self._maybe_pad_weight(weight) weight = self._maybe_pad_weight(weight)
# Torch.compile cannot use Parameter subclasses. # Torch.compile cannot use Parameter subclasses.
......
...@@ -6,6 +6,8 @@ import functools ...@@ -6,6 +6,8 @@ import functools
import json import json
import os import os
from typing import Any, Callable, Optional, Union, List from typing import Any, Callable, Optional, Union, List
from lmslim import quant_ops
from lmslim.quantize.quant_ops import BlockSize
import torch import torch
...@@ -83,7 +85,7 @@ if current_platform.is_rocm(): ...@@ -83,7 +85,7 @@ if current_platform.is_rocm():
def dispatch_w8a8_blockscale_func( def dispatch_w8a8_blockscale_func(
use_cutlass: bool, use_aiter_and_is_supported: bool use_cutlass: bool, use_aiter_and_is_supported: bool, use_blaslt: bool
) -> Callable[[ ) -> Callable[[
torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
...@@ -96,6 +98,9 @@ def dispatch_w8a8_blockscale_func( ...@@ -96,6 +98,9 @@ def dispatch_w8a8_blockscale_func(
return cutlass_scaled_mm return cutlass_scaled_mm
if (use_aiter_and_is_supported): if (use_aiter_and_is_supported):
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
if use_blaslt:
return hipblaslt_w8a8_block_fp8_matmul
return w8a8_block_fp8_matmul return w8a8_block_fp8_matmul
...@@ -127,6 +132,10 @@ def apply_w8a8_block_fp8_linear( ...@@ -127,6 +132,10 @@ def apply_w8a8_block_fp8_linear(
assert input_scale is None assert input_scale is None
# View input as 2D matrix for fp8 methods # View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1]) input_2d = input.view(-1, input.shape[-1])
output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]] output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype output_dtype = input.dtype
...@@ -166,9 +175,12 @@ def apply_w8a8_block_fp8_linear( ...@@ -166,9 +175,12 @@ def apply_w8a8_block_fp8_linear(
weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
else: else:
use_cutlass = False use_cutlass = False
use_blaslt = False
if envs.VLLM_W8A8_BACKEND == 3:
use_blaslt = True
w8a8_blockscale_func = dispatch_w8a8_blockscale_func( w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
use_cutlass, use_aiter_and_is_supported) use_cutlass, use_aiter_and_is_supported, use_blaslt)
if use_cutlass: if use_cutlass:
q_input, x_scale = per_token_group_quant_fp8( q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass) input_2d, block_size[1], column_major_scales=use_cutlass)
...@@ -197,6 +209,10 @@ def apply_w8a8_block_fp8_linear_fake( ...@@ -197,6 +209,10 @@ def apply_w8a8_block_fp8_linear_fake(
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False, use_aiter_and_is_supported: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]] output_shape = [*input.shape[:-1], weight.shape[0]]
return torch.empty(output_shape, dtype=input.dtype, device=input.device) return torch.empty(output_shape, dtype=input.dtype, device=input.device)
...@@ -566,6 +582,29 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, ...@@ -566,6 +582,29 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
return None return None
def hipblaslt_w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Run a block-scaled W8A8 GEMM through lmslim's hipBLASLt kernel.

    Thin wrapper around ``quant_ops.hipblaslt_w8a8_blockwise_gemm`` that
    derives the problem sizes from the operands and maps ``block_size`` to
    the ``BlockSize`` enum the kernel expects.

    Args:
        A: 2-D left operand; its shape is read as ``(m, k)``.
        B: 2-D right operand; its second dimension is read as ``n``.
            NOTE(review): presumably laid out as ``(k, n)`` to match the
            'NN' layout flag -- confirm against the lmslim kernel.
        As: Scales for ``A``, forwarded to the kernel unchanged.
        Bs: Scales for ``B``, forwarded to the kernel unchanged.
        block_size: Quantization block shape. Only ``block_size[0]`` is
            inspected: 64 selects ``BlockSize.block_64x64``, anything else
            selects ``BlockSize.block_128x128``.
        output_dtype: Output dtype handed to the kernel.

    Returns:
        The second element of the tuple returned by the kernel.
    """
    # Function-scope import keeps this block self-contained; logging is used
    # instead of print() so the message honours the configured log level.
    import logging

    m, k = A.shape
    _, n = B.shape

    if block_size[0] == 64:
        enum_block_size = BlockSize.block_64x64
    else:
        if block_size[0] != 128:
            # Best-effort fallback kept from the original implementation:
            # unknown block sizes are reported, then treated as 128x128.
            logging.getLogger(__name__).warning(
                "Unsupported block_size: %s. "
                "Falling back to BlockSize.block_128x128", block_size)
        enum_block_size = BlockSize.block_128x128

    # The meaning of the trailing None argument is not visible from this
    # file; it is passed through exactly as before.
    _, d = quant_ops.hipblaslt_w8a8_blockwise_gemm(A, B, As, Bs,
                                                   m, n, k, 'NN', output_dtype,
                                                   enum_block_size, None)
    return d
def w8a8_block_fp8_matmul( def w8a8_block_fp8_matmul(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment