Commit 900f4720 authored by wanglong3
Browse files

feat: Support w8a8-fp8 GEMM backend.

parent 9d16d5aa
......@@ -331,8 +331,12 @@ class Fp8LinearMethod(LinearMethodBase):
weight = layer.weight.data
weight_scale_inv = layer.weight_scale_inv.data
weight = self._maybe_pad_weight(weight)
if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.T.contiguous()
weight_scale_inv = weight_scale_inv.T.contiguous()
else:
weight = self._maybe_pad_weight(weight)
# Torch.compile cannot use Parameter subclasses.
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale_inv,
......
......@@ -6,6 +6,8 @@ import functools
import json
import os
from typing import Any, Callable, Optional, Union, List
from lmslim import quant_ops
from lmslim.quantize.quant_ops import BlockSize
import torch
......@@ -83,7 +85,7 @@ if current_platform.is_rocm():
def dispatch_w8a8_blockscale_func(
use_cutlass: bool, use_aiter_and_is_supported: bool
use_cutlass: bool, use_aiter_and_is_supported: bool, use_blaslt: bool
) -> Callable[[
torch.Tensor,
torch.Tensor,
......@@ -96,6 +98,9 @@ def dispatch_w8a8_blockscale_func(
return cutlass_scaled_mm
if (use_aiter_and_is_supported):
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
if use_blaslt:
return hipblaslt_w8a8_block_fp8_matmul
return w8a8_block_fp8_matmul
......@@ -127,7 +132,11 @@ def apply_w8a8_block_fp8_linear(
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype
if should_use_deepgemm(output_dtype, weight):
......@@ -166,9 +175,12 @@ def apply_w8a8_block_fp8_linear(
weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
else:
use_cutlass = False
use_blaslt = False
if envs.VLLM_W8A8_BACKEND == 3:
use_blaslt = True
w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
use_cutlass, use_aiter_and_is_supported)
use_cutlass, use_aiter_and_is_supported, use_blaslt)
if use_cutlass:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass)
......@@ -197,7 +209,11 @@ def apply_w8a8_block_fp8_linear_fake(
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
output_shape = [*input.shape[:-1], weight.shape[0]]
output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]]
return torch.empty(output_shape, dtype=input.dtype, device=input.device)
......@@ -566,6 +582,29 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
return None
def hipblaslt_w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Run a block-scaled w8a8 GEMM via ``quant_ops.hipblaslt_w8a8_blockwise_gemm``.

    Args:
        A: 2-D tensor; its shape is unpacked as ``(m, k)``.
        B: 2-D tensor; its second dimension is taken as ``n``.
            NOTE(review): presumably laid out as ``(k, n)`` to match the
            ``'NN'`` mode passed to the kernel -- confirm against lmslim.
        As: Scale tensor forwarded to the kernel alongside ``A``.
        Bs: Scale tensor forwarded to the kernel alongside ``B``.
        block_size: Only ``block_size[0]`` is inspected. ``64`` selects
            ``BlockSize.block_64x64`` and ``128`` selects
            ``BlockSize.block_128x128``; any other value logs a warning and
            falls back to ``BlockSize.block_128x128``.
        output_dtype: Dtype forwarded to the kernel for the result.

    Returns:
        The second element of the tuple returned by the kernel.
    """
    # Function-scope import so this block does not depend on a module-level
    # logger that may or may not exist in the rest of the file.
    import logging

    m, k = A.shape
    _, n = B.shape

    if block_size[0] == 64:
        enum_block_size = BlockSize.block_64x64
    else:
        if block_size[0] != 128:
            # Report through logging rather than print() so the message
            # respects the application's log configuration and does not
            # pollute stdout.
            logging.getLogger(__name__).warning(
                "Unsupported block_size: %s. "
                "Falling back to BlockSize.block_128x128",
                block_size,
            )
        enum_block_size = BlockSize.block_128x128

    # NOTE(review): the trailing ``None`` is an optional kernel argument whose
    # meaning is not visible here -- confirm against the lmslim signature.
    _, d = quant_ops.hipblaslt_w8a8_blockwise_gemm(
        A, B, As, Bs,
        m, n, k, 'NN', output_dtype,
        enum_block_size, None)
    return d
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment