Commit 96d4d18e authored by zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-ep' into 'v0.9.2-dev'

support w4a8 ep

See merge request OpenDAS/vllm!2
parents c11b09df 48a9e546
......@@ -34,7 +34,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
def apply(
def apply_ep(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
......@@ -254,7 +254,7 @@ class EPMoE(FusedMoE):
)
# Matrix multiply.
expert_output = self.quant_method.apply(
expert_output = self.quant_method.apply_ep(
layer=self,
hidden_states=dispatched_input,
tokens_per_expert=tokens_per_expert
......
......@@ -21,7 +21,10 @@ from vllm.utils import W8a8GetCacheJSON
import os
from vllm import _custom_ops as ops
# lmslim is an optional dependency: the module must stay importable without
# it, but the W4A8 expert-parallel kernel is then unavailable.
try:
    from lmslim.layers.fused_moe.fuse_moe_w4a8 import fused_experts_impl_w4a8_ep
except Exception as _lmslim_exc:
    print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
    # The `as` target is cleared when the except block ends, so keep our own
    # reference to chain the root cause into the error raised at call time.
    _LMSLIM_IMPORT_ERROR = _lmslim_exc

    def fused_experts_impl_w4a8_ep(*args, **kwargs):
        """Fallback bound when lmslim could not be imported.

        Accepts any arguments and always raises, so that callers get an
        explicit explanation instead of a bare ``NameError`` on first use.

        Raises:
            ImportError: always; chained to the original import failure.
        """
        raise ImportError(
            "fused_experts_impl_w4a8_ep is unavailable because lmslim could "
            "not be imported; install lmslim to use the W4A8 "
            "expert-parallel MoE path."
        ) from _LMSLIM_IMPORT_ERROR
W8A8_TRITONJSON=W8a8GetCacheJSON()
def baseline_scaled_mm(a: torch.Tensor,
......@@ -328,7 +331,21 @@ class SlimQuantW4A8Int8MoEMethod:
layer.w2_weight_scale.data, requires_grad=False
)
def apply(
def apply_ep(  # dp+ep
    self,
    layer: torch.nn.Module,
    hidden_states: torch.Tensor,
    tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    """Run the expert computation for the dp+ep path.

    Delegates entirely to ``fused_experts_impl_w4a8_ep``; no work is done
    here beyond collecting the layer's parameters.

    Args:
        layer: Module exposing ``w13_weight``, ``w2_weight``,
            ``w13_weight_scale`` and ``w2_weight_scale``.
        hidden_states: Input activations, forwarded unchanged as the first
            kernel argument (presumably tokens already dispatched to the
            local experts -- TODO confirm against the caller).
        tokens_per_expert: Forwarded unchanged as the last kernel argument
            (presumably a per-expert token count -- TODO confirm).

    Returns:
        Whatever ``fused_experts_impl_w4a8_ep`` returns.
    """
    # Read the parameters in the same order the kernel expects them:
    # both weights first, then both scales.
    gate_up_weight = layer.w13_weight
    down_weight = layer.w2_weight
    gate_up_scale = layer.w13_weight_scale
    down_scale = layer.w2_weight_scale
    return fused_experts_impl_w4a8_ep(
        hidden_states,
        gate_up_weight,
        down_weight,
        gate_up_scale,
        down_scale,
        tokens_per_expert,
    )
def apply(# tp
self,
layer: torch.nn.Module,
x: torch.Tensor,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment