Unverified Commit ae339b1a authored by Zhewen Li's avatar Zhewen Li Committed by GitHub
Browse files

[Bugfix] Fix DeepGEMM after #29546 (#30267)


Signed-off-by: default avatarzhewenli <zhewenli@meta.com>
Signed-off-by: default avatarZhewen Li <zhewenli@meta.com>
parent 0ee6416f
......@@ -30,6 +30,7 @@ from vllm.model_executor.parameter import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
fp8_gemm_nt,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
......@@ -268,12 +269,15 @@ class W8A8BlockFp8LinearOp:
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
assert self.deepgemm_input_quant_op is not None
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
input_2d,
group_size=self.act_quant_group_shape.col,
use_ue8m0=True,
)
if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
input_2d,
group_size=self.act_quant_group_shape.col,
use_ue8m0=True,
)
else:
assert self.deepgemm_input_quant_op is not None
q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
output = torch.empty(
(q_input.shape[0], weight.shape[0]),
dtype=torch.bfloat16,
......
......@@ -399,6 +399,7 @@ def should_use_deepgemm_for_fp8_linear_for_nk(
__all__ = [
"calc_diff",
"DeepGemmQuantScaleFMT",
"fp8_gemm_nt",
"m_grouped_fp8_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment