"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "9432ed8c7e158def084e0770fd02838292bf57e4"
Commit 96d4d18e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-ep' into 'v0.9.2-dev'

support w4a8 ep

See merge request OpenDAS/vllm!2
parents c11b09df 48a9e546
...@@ -34,7 +34,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -34,7 +34,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
def apply( def apply_ep(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -254,7 +254,7 @@ class EPMoE(FusedMoE): ...@@ -254,7 +254,7 @@ class EPMoE(FusedMoE):
) )
# Matrix multiply. # Matrix multiply.
expert_output = self.quant_method.apply( expert_output = self.quant_method.apply_ep(
layer=self, layer=self,
hidden_states=dispatched_input, hidden_states=dispatched_input,
tokens_per_expert=tokens_per_expert tokens_per_expert=tokens_per_expert
......
...@@ -21,7 +21,10 @@ from vllm.utils import W8a8GetCacheJSON ...@@ -21,7 +21,10 @@ from vllm.utils import W8a8GetCacheJSON
import os import os
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
try:
from lmslim.layers.fused_moe.fuse_moe_w4a8 import fused_experts_impl_w4a8_ep
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
W8A8_TRITONJSON=W8a8GetCacheJSON() W8A8_TRITONJSON=W8a8GetCacheJSON()
def baseline_scaled_mm(a: torch.Tensor, def baseline_scaled_mm(a: torch.Tensor,
...@@ -328,7 +331,21 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -328,7 +331,21 @@ class SlimQuantW4A8Int8MoEMethod:
layer.w2_weight_scale.data, requires_grad=False layer.w2_weight_scale.data, requires_grad=False
) )
def apply( def apply_ep( #dp+ep
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
return fused_experts_impl_w4a8_ep(hidden_states,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
tokens_per_expert)
def apply(# tp
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment