Commit 9b42963d authored by wanglong3, committed by zhuwenwen
Browse files

V0.11.0 dev lxh channelwise

parent c5980399
...@@ -171,6 +171,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -171,6 +171,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
**_, **_,
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
......
...@@ -917,7 +917,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -917,7 +917,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
topk_weights: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None,
topk_ids: Optional[torch.Tensor] = None,**_, topk_ids: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb: if enable_eplb:
......
...@@ -18,6 +18,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer ...@@ -18,6 +18,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
try: try:
from lmslim.layers.gemm.int8_utils import per_token_quant_int8 from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from lmslim.quantize.quant_ops import hipblaslt_w8a8_channelwise_gemm
from lmslim.quantize import quant_ops from lmslim.quantize import quant_ops
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
...@@ -347,25 +348,14 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -347,25 +348,14 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
# GEMM # GEMM
# This computes C = (X * W). # This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place # Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(qinput, qinput = qinput.view(-1,qinput.shape[-1])
output = triton_scaled_mm_fp8(qinput,
weight, weight,
scale_a=TORCH_DEVICE_IDENTITY, scale_a=scale_a,
scale_b=TORCH_DEVICE_IDENTITY, scale_b=scale_b,
out_dtype=torch.float32) out_dtype=out_dtype,
# A fix for discrepancy in scaled_mm which returns tuple bias=bias)
# for torch < 2.5 and a single value in torch >= 2.5 return output.view(*output_shape)
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, qinput.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * scale_b.t()
if bias is not None:
output = output + bias
return output.to(out_dtype).view(*output_shape)
def dispatch_w8a8_scaled_mm( def dispatch_w8a8_scaled_mm(
preferred_backend: str, per_tensor_weights: bool, preferred_backend: str, per_tensor_weights: bool,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment