Unverified commit 5fbfa8d9, authored by Jinzhen Lin, committed by GitHub
Browse files

[Quantization] fix marlin w8a8 check (#30961)


Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
parent 23a1946e
...@@ -11,7 +11,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -11,7 +11,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_permute_bias, marlin_permute_bias,
marlin_permute_scales, marlin_permute_scales,
marlin_quant_input,
should_use_atomic_add_reduce, should_use_atomic_add_reduce,
) )
from vllm.model_executor.utils import replace_parameter from vllm.model_executor.utils import replace_parameter
...@@ -63,13 +62,11 @@ def apply_fp8_marlin_linear( ...@@ -63,13 +62,11 @@ def apply_fp8_marlin_linear(
inputs = reshaped_x inputs = reshaped_x
a_scales = None a_scales = None
if input_dtype is not None and input_dtype.itemsize == 1: if input_dtype is not None and input_dtype.itemsize == 1:
if input_dtype != torch.float8_e4m3fn: # inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
raise RuntimeError("FP8 weight + INT8 activation is not supported.") raise RuntimeError("Marlin W8A8 is not supported.")
inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
a=reshaped_x, a=inputs,
c=None, c=None,
b_q_weight=weight, b_q_weight=weight,
b_bias=bias, b_bias=bias,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment