"vscode:/vscode.git/clone" did not exist on "2874bac618052a079efd837fc82cf3f3519079c7"
Commit 5ca1c279 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-ds' of...

Merge branch 'v0.9.2-dev-ds' of https://developer.sourcefind.cn/codes/OpenDAS/vllm into v0.9.2-dev-ds
parents 8419f911 e8cf079b
...@@ -30,11 +30,11 @@ try: ...@@ -30,11 +30,11 @@ try:
except ImportError: except ImportError:
is_mori_available = False is_mori_available = False
logger = init_logger(__name__) logger = init_logger(__name__)
_MORI_OP = None _MORI_OP = None
@CustomOp.register("unquantized_ep_moe") @CustomOp.register("unquantized_ep_moe")
class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
"""MoE method without quantization.""" """MoE method without quantization."""
...@@ -167,6 +167,7 @@ class EPMoE(FusedMoE): ...@@ -167,6 +167,7 @@ class EPMoE(FusedMoE):
dp+ep MoE Expert Parallel Impl dp+ep MoE Expert Parallel Impl
""" """
def __init__( def __init__(
self, self,
num_experts: int, # Global number of experts num_experts: int, # Global number of experts
...@@ -247,7 +248,6 @@ class EPMoE(FusedMoE): ...@@ -247,7 +248,6 @@ class EPMoE(FusedMoE):
self.mori_op = self.get_mori_op() self.mori_op = self.get_mori_op()
self.first = True self.first = True
def get_mori_op(self): def get_mori_op(self):
global _MORI_OP global _MORI_OP
if _MORI_OP is None: if _MORI_OP is None:
...@@ -263,7 +263,7 @@ class EPMoE(FusedMoE): ...@@ -263,7 +263,7 @@ class EPMoE(FusedMoE):
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
multi_node = self.ep_size / 8 > 1 multi_node = self.ep_size / 8 > 1
mori_data_type=vllm_config.model_config.dtype mori_data_type = vllm_config.model_config.dtype
mori_scale_type_size = vllm_config.model_config.dtype.itemsize mori_scale_type_size = vllm_config.model_config.dtype.itemsize
if self.use_int8_dispatch: if self.use_int8_dispatch:
mori_scale_type_size = 4 mori_scale_type_size = 4
...@@ -281,7 +281,7 @@ class EPMoE(FusedMoE): ...@@ -281,7 +281,7 @@ class EPMoE(FusedMoE):
max_token_type_size=2, max_token_type_size=2,
block_num=80, block_num=80,
warp_num_per_block=16, warp_num_per_block=16,
#kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode # kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \ kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \
mori.ops.EpDispatchCombineKernelType.IntraNode mori.ops.EpDispatchCombineKernelType.IntraNode
) )
...@@ -308,7 +308,7 @@ class EPMoE(FusedMoE): ...@@ -308,7 +308,7 @@ class EPMoE(FusedMoE):
return quant_method return quant_method
def sync(self): def sync(self):
#torch.cuda.synchronize() # torch.cuda.synchronize()
dist.barrier() dist.barrier()
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
...@@ -335,7 +335,6 @@ class EPMoE(FusedMoE): ...@@ -335,7 +335,6 @@ class EPMoE(FusedMoE):
if name not in NON_EXPERT_WEIGHTS if name not in NON_EXPERT_WEIGHTS
] ]
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor):
...@@ -369,8 +368,7 @@ class EPMoE(FusedMoE): ...@@ -369,8 +368,7 @@ class EPMoE(FusedMoE):
) )
scales = self.scales scales = self.scales
# self.sync()
#self.sync()
( (
dispatch_output, dispatch_output,
...@@ -384,7 +382,7 @@ class EPMoE(FusedMoE): ...@@ -384,7 +382,7 @@ class EPMoE(FusedMoE):
scales, scales,
topk_ids, topk_ids,
) )
#self.sync() # self.sync()
# expect_m = topk_ids.shape[0] * self.ep_size # expect_m = topk_ids.shape[0] * self.ep_size
# dispatch_output_clip = dispatch_output[:expect_m] # dispatch_output_clip = dispatch_output[:expect_m]
...@@ -421,14 +419,14 @@ class EPMoE(FusedMoE): ...@@ -421,14 +419,14 @@ class EPMoE(FusedMoE):
num_local_tokens=dispatch_recv_num_token, num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0], config_select_bs=hidden_states.shape[0],
scales=dispatch_scales if self.use_int8_dispatch else None scales=dispatch_scales if self.use_int8_dispatch else None
#routed_scaling_factor=self.routed_scaling_factor, # routed_scaling_factor=self.routed_scaling_factor,
) )
#self.sync() # self.sync()
combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids) combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids)
final_hidden_states = combine_output[:hidden_states.shape[0], :] final_hidden_states = combine_output[:hidden_states.shape[0], :]
#self.sync() # self.sync()
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# if shared_expert_overlap is True, the expert calculation happens in # if shared_expert_overlap is True, the expert calculation happens in
...@@ -447,6 +445,7 @@ class EPMoE(FusedMoE): ...@@ -447,6 +445,7 @@ class EPMoE(FusedMoE):
return final_hidden_states return final_hidden_states
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
...@@ -467,5 +466,5 @@ direct_register_custom_op( ...@@ -467,5 +466,5 @@ direct_register_custom_op(
mutates_args=["hidden_states", "router_logits"], mutates_args=["hidden_states", "router_logits"],
fake_impl=ep_moe_forward_fake, fake_impl=ep_moe_forward_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order,),
) )
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment