"vllm/multimodal/processing.py" did not exist on "bb7991aa291054a30f408e626273caa6769a07eb"
Commit 41f98782 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-channel-lxh' into 'v0.9.2-dev'

V0.9.2 dev channel lxh

See merge request dcutoolkit/deeplearing/vllm!390
parents 9f68733a 747cd248
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
from vllm import envs
import torch
from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Parameter
......@@ -61,6 +61,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
# If channelwise, scales are already lined up, so just transpose.
elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight
if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.t()
if current_platform.is_fp8_fnuz():
input_scale = getattr(layer, 'input_scale', None)
......
......@@ -11,8 +11,10 @@ from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from lmslim.quantize import quant_ops
try:
from lmslim.layers.gemm.fp8_utils import triton_scaled_mm_fp8
from lmslim.quantize.quant_ops import hipblaslt_w8a8_channelwise_gemm
except Exception:
print("INFO: Please updata lmslim if you want to use fp8_utils.\n")
# Input scaling factors are no longer optional in _scaled_mm starting
......@@ -255,6 +257,39 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
output = output.view(*output_shape)
return output
def hipblaslt_w8a8_channelwise_scaled_mm(
qinput: torch.Tensor,
input_2d: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor,
output_shape: list,
**kwargs
) -> torch.Tensor:
assert qinput.is_contiguous() and weight.is_contiguous()
assert qinput.shape[-1] == weight.shape[-1]
assert qinput.dtype == weight.dtype
m = qinput.shape[0]
k = qinput.shape[1]
n = weight.shape[0]
success, output = quant_ops.hipblaslt_w8a8_channelwise_gemm(
a = qinput,
b = weight,
scale_a = scale_a,
scale_b = scale_b,
m = m,
n = n,
k = k,
transpose_flag = "NT",
out_dtype = out_dtype,
bias = bias,
)
return output.view(m, n)
def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
......@@ -315,6 +350,8 @@ def dispatch_w8a8_scaled_mm(
if current_platform.is_rocm():
return rocm_per_tensor_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm
if envs.VLLM_W8A8_BACKEND == 3:
return hipblaslt_w8a8_channelwise_scaled_mm
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
if (use_per_token_if_dynamic and not per_tensor_weights
......
......@@ -1300,6 +1300,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_decode:
assert attn_metadata.decode is not None
kv_cache_dtype_str = None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
decode_q = q_quant[:num_decode_tokens]
decode_q_nope, decode_q_pe = decode_q.split(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment