Commit 5ca1c279 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-ds' of...

Merge branch 'v0.9.2-dev-ds' of https://developer.sourcefind.cn/codes/OpenDAS/vllm into v0.9.2-dev-ds
parents 8419f911 e8cf079b
......@@ -30,11 +30,11 @@ try:
except ImportError:
is_mori_available = False
logger = init_logger(__name__)
_MORI_OP = None
@CustomOp.register("unquantized_ep_moe")
class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
"""MoE method without quantization."""
......@@ -44,20 +44,20 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self.topk_indices_dtype = None
self.moe = moe
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
def apply_ep(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor:
return self.forward(
......@@ -73,17 +73,17 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
use_nn_moe=use_nn_moe)
def forward_cuda(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor:
# process MoE
......@@ -109,48 +109,48 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
return output
def forward_cpu(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
**kwargs,
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
**kwargs,
):
raise NotImplementedError
def forward_hpu(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor:
raise NotImplementedError
def forward_tpu(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor:
raise NotImplementedError
......@@ -167,49 +167,50 @@ class EPMoE(FusedMoE):
dp+ep MoE Expert Parallel Impl
"""
def __init__(
self,
num_experts: int, # Global number of experts
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
ep_size: Optional[int] = None,
dp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
enable_eplb: bool = False,
num_redundant_experts: int = 0,
moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False
self,
num_experts: int, # Global number of experts
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
ep_size: Optional[int] = None,
dp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
enable_eplb: bool = False,
num_redundant_experts: int = 0,
moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False
):
super().__init__(num_experts, top_k, hidden_size,
intermediate_size, params_dtype,
reduce_results, renormalize,
use_grouped_topk, num_expert_group,
topk_group, quant_config, tp_size,
ep_size, dp_size, prefix,
ep_size, dp_size, prefix,
custom_routing_function, scoring_func,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
activation,
routed_scaling_factor=routed_scaling_factor,
enable_eplb=enable_eplb,
num_redundant_experts=num_redundant_experts,
)
self.ep_moe_config: EpMoeConfig = EpMoeConfig.make(
moe_router_topk=self.top_k,
# TODO: support fusion permute
......@@ -222,7 +223,7 @@ class EPMoE(FusedMoE):
)
local_expert_indices_offset = (
self.ep_rank * self.local_num_experts
self.ep_rank * self.local_num_experts
)
self.local_expert_indices = [
local_expert_indices_offset + i for i in range(self.local_num_experts)
......@@ -230,10 +231,10 @@ class EPMoE(FusedMoE):
self.use_shared_expert = False
self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.local_num_experts, self.local_expert_indices,
config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher",
)
self.local_num_experts, self.local_expert_indices,
config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher",
)
self.shared_expert_overlap = moe_shared_expert_overlap
self.shared_experts = None
......@@ -246,8 +247,7 @@ class EPMoE(FusedMoE):
self.max_num_inp_token_per_rank = vllm_config.scheduler_config.max_num_seqs
self.mori_op = self.get_mori_op()
self.first = True
def get_mori_op(self):
global _MORI_OP
if _MORI_OP is None:
......@@ -259,14 +259,14 @@ class EPMoE(FusedMoE):
# assert world_group is not None
# torch._C._distributed_c10d._register_process_group("default", world_group)
# mori.shmem.shmem_torch_process_group_init("default")
vllm_config = get_current_vllm_config()
multi_node = self.ep_size / 8 > 1
mori_data_type=vllm_config.model_config.dtype
mori_data_type = vllm_config.model_config.dtype
mori_scale_type_size = vllm_config.model_config.dtype.itemsize
if self.use_int8_dispatch:
mori_scale_type_size = 4
mori_scale_type_size = 4
config = mori.ops.EpDispatchCombineConfig(
data_type=mori_data_type,
......@@ -281,12 +281,12 @@ class EPMoE(FusedMoE):
max_token_type_size=2,
block_num=80,
warp_num_per_block=16,
#kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
# kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \
mori.ops.EpDispatchCombineKernelType.IntraNode
)
_MORI_OP = mori.ops.EpDispatchCombineOp(config)
return _MORI_OP
def set_shared_experts(self, shared_experts: torch.nn.Module):
......@@ -306,15 +306,15 @@ class EPMoE(FusedMoE):
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
return quant_method
def sync(self):
#torch.cuda.synchronize()
# torch.cuda.synchronize()
dist.barrier()
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits,
self.layer_name)
self.layer_name)
def get_expert_weights(self) -> Iterable[torch.Tensor]:
weights = list(self.named_parameters())
......@@ -333,30 +333,29 @@ class EPMoE(FusedMoE):
return [
weight.view(self.local_num_experts, -1) for name, weight in weights
if name not in NON_EXPERT_WEIGHTS
]
]
def forward_impl(self, hidden_states: torch.Tensor,
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
topk_weights, topk_ids = self.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
use_grouped_topk=self.use_grouped_topk,
top_k=self.top_k,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
indices_type=torch.int32,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate)
hidden_states=hidden_states,
router_logits=router_logits,
use_grouped_topk=self.use_grouped_topk,
top_k=self.top_k,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
indices_type=torch.int32,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate)
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
if self.use_int8_dispatch:
hidden_states, scales = per_token_quant_int8(hidden_states)
else:
......@@ -369,23 +368,22 @@ class EPMoE(FusedMoE):
)
scales = self.scales
#self.sync()
# self.sync()
(
dispatch_output,
dispatch_weights,
dispatch_scales,
dispatch_indices,
dispatch_recv_num_token,
dispatch_output,
dispatch_weights,
dispatch_scales,
dispatch_indices,
dispatch_recv_num_token,
) = self.mori_op.dispatch(
hidden_states,
topk_weights,
scales,
topk_ids,
)
#self.sync()
# self.sync()
# expect_m = topk_ids.shape[0] * self.ep_size
# dispatch_output_clip = dispatch_output[:expect_m]
# dispatch_weights_clip = dispatch_weights[:expect_m]
......@@ -421,14 +419,14 @@ class EPMoE(FusedMoE):
num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0],
scales=dispatch_scales if self.use_int8_dispatch else None
#routed_scaling_factor=self.routed_scaling_factor,
# routed_scaling_factor=self.routed_scaling_factor,
)
#self.sync()
# self.sync()
combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids)
final_hidden_states = combine_output[:hidden_states.shape[0], :]
#self.sync()
# self.sync()
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# if shared_expert_overlap is True, the expert calculation happens in
......@@ -443,12 +441,13 @@ class EPMoE(FusedMoE):
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
* (1. / self.routed_scaling_factor)
return final_hidden_states
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None
......@@ -457,7 +456,7 @@ def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -467,5 +466,5 @@ direct_register_custom_op(
mutates_args=["hidden_states", "router_logits"],
fake_impl=ep_moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
tags=(torch.Tag.needs_fixed_stride_order,),
)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment