"platforms/cuda/vscode:/vscode.git/clone" did not exist on "31d0993fb6ae6874152d67209b7983e232bf19dc"
Commit 9b42963d authored by wanglong3's avatar wanglong3 Committed by zhuwenwen
Browse files

V0.11.0 dev lxh channelwise

parent c5980399
......@@ -171,6 +171,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
**_,
) -> torch.Tensor:
if enable_eplb:
......
......@@ -917,7 +917,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
topk_weights: Optional[torch.Tensor] = None,
topk_ids: Optional[torch.Tensor] = None,**_,
topk_ids: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb:
......
......@@ -18,6 +18,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
try:
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from lmslim.quantize.quant_ops import hipblaslt_w8a8_channelwise_gemm
from lmslim.quantize import quant_ops
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
......@@ -347,25 +348,14 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(qinput,
qinput = qinput.view(-1,qinput.shape[-1])
output = triton_scaled_mm_fp8(qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, qinput.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * scale_b.t()
if bias is not None:
output = output + bias
return output.to(out_dtype).view(*output_shape)
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
bias=bias)
return output.view(*output_shape)
def dispatch_w8a8_scaled_mm(
preferred_backend: str, per_tensor_weights: bool,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment