Commit 1a56d6cb authored by 王敏's avatar 王敏
Browse files

添加mori ep

parent 3cb11400
...@@ -2004,10 +2004,10 @@ class ParallelConfig: ...@@ -2004,10 +2004,10 @@ class ParallelConfig:
logger.info("Disabling V1 multiprocessing for external launcher.") logger.info("Disabling V1 multiprocessing for external launcher.")
if self.enable_eplb: if self.enable_eplb:
if not current_platform.is_cuda(): # if not current_platform.is_cuda():
raise ValueError( # raise ValueError(
"Expert parallelism load balancing is only supported on " # "Expert parallelism load balancing is only supported on "
"CUDA devices now.") # "CUDA devices now.")
if self.num_redundant_experts < 0: if self.num_redundant_experts < 0:
raise ValueError( raise ValueError(
"num_redundant_experts must be non-negative, but got " "num_redundant_experts must be non-negative, but got "
......
...@@ -2,6 +2,7 @@ import os ...@@ -2,6 +2,7 @@ import os
import logging import logging
from typing import Callable, List, Optional, Tuple from typing import Callable, List, Optional, Tuple
from dataclasses import dataclass from dataclasses import dataclass
from collections.abc import Iterable
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -11,7 +12,7 @@ from vllm.platforms import current_platform ...@@ -11,7 +12,7 @@ from vllm.platforms import current_platform
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed.parallel_state import get_ep_group, get_node_count from vllm.distributed.parallel_state import get_ep_group, get_node_count, is_use_cuda_graph
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
...@@ -20,7 +21,6 @@ from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, Unqua ...@@ -20,7 +21,6 @@ from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, Unqua
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from lightop import groupgemm
import mori import mori
import torch.distributed as dist import torch.distributed as dist
...@@ -45,7 +45,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -45,7 +45,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace=True, inplace=True,
...@@ -59,7 +58,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -59,7 +58,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
return self.forward( return self.forward(
hidden_states=hidden_states, hidden_states=hidden_states,
layer=layer, layer=layer,
tokens_per_expert=tokens_per_expert,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=inplace, inplace=inplace,
...@@ -73,7 +71,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -73,7 +71,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace=True, inplace=True,
...@@ -85,62 +82,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -85,62 +82,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
) -> torch.Tensor: ) -> torch.Tensor:
# process MoE # process MoE
def custom_forward(layer, hidden_states, tokens_per_expert): def custom_forward(layer, hidden_states):
if False:
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
w1 = layer.w13_weight[i]
w2 = layer.w2_weight[i]
tokens_for_this_expert = hidden_states[start_idx:end_idx]
gateup_output = torch.matmul(tokens_for_this_expert, w1)
# Act
down_input = torch.zeros(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=hidden_states.dtype
)
torch.ops._C.silu_and_mul(down_input,
gateup_output.view(-1, w1.shape[1]))
expert_out = torch.matmul(down_input, w2)
outputs.append(expert_out)
start_idx = end_idx
if len(outputs) > 0:
expert_output = torch.cat(outputs, dim=0)
else:
assert hidden_states.numel() == 0, f"sorted_tokens: should be empty, but got {hidden_states.shape}"
expert_output = hidden_states
else:
if topk_ids is None:
if self.zero_token_count is None:
self.zero_token_count = torch.zeros(1, dtype=torch.int64, device=hidden_states.device)
total_tokens = tokens_per_expert.sum()
print("#################total_tokens:", total_tokens.tolist())
if total_tokens > self.zero_token_count:
gateup_output = groupgemm(hidden_states, layer.w13_weight, tokens_per_expert, False)
# Act
down_input = torch.zeros(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=hidden_states.dtype
)
torch.ops._C.silu_and_mul(down_input,
gateup_output.view(-1, layer.w13_weight.shape[2]))
expert_output = groupgemm(down_input, layer.w2_weight, tokens_per_expert, False)
else :
expert_output = hidden_states
else:
expert_output = self.fused_experts( expert_output = self.fused_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
w1=layer.w13_weight, w1=layer.w13_weight,
...@@ -155,10 +97,9 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -155,10 +97,9 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
use_nn_moe=use_nn_moe use_nn_moe=use_nn_moe
) )
return expert_output return expert_output
output = custom_forward(layer, hidden_states, tokens_per_expert) output = custom_forward(layer, hidden_states)
return output return output
...@@ -166,7 +107,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -166,7 +107,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace=True, inplace=True,
...@@ -183,7 +123,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -183,7 +123,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace=True, inplace=True,
...@@ -199,7 +138,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -199,7 +138,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace=True, inplace=True,
...@@ -249,7 +187,7 @@ class EPMoE(FusedMoE): ...@@ -249,7 +187,7 @@ class EPMoE(FusedMoE):
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
enable_eplb: bool = False, enable_eplb: bool = False,
num_redundant_experts: int = 0, num_redundant_experts: int = 0,
moe_permute_fusion: bool = True, moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False moe_shared_expert_overlap: bool = False
): ):
super().__init__(num_experts, top_k, hidden_size, super().__init__(num_experts, top_k, hidden_size,
...@@ -296,9 +234,11 @@ class EPMoE(FusedMoE): ...@@ -296,9 +234,11 @@ class EPMoE(FusedMoE):
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1' self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
if True:
self.mori_op = self.get_mori_op() self.mori_op = self.get_mori_op()
self.zero_token_count = None
def get_mori_op(self): def get_mori_op(self):
global _MORI_OP global _MORI_OP
if _MORI_OP is None: if _MORI_OP is None:
...@@ -319,7 +259,7 @@ class EPMoE(FusedMoE): ...@@ -319,7 +259,7 @@ class EPMoE(FusedMoE):
hidden_dim=self.hidden_size, hidden_dim=self.hidden_size,
scale_dim=0, scale_dim=0,
scale_type_size=vllm_config.model_config.dtype.itemsize, scale_type_size=vllm_config.model_config.dtype.itemsize,
max_num_inp_token_per_rank=4096, max_num_inp_token_per_rank=20480,
num_experts_per_rank=self.local_num_experts, num_experts_per_rank=self.local_num_experts,
num_experts_per_token=self.top_k, num_experts_per_token=self.top_k,
max_token_type_size=2, max_token_type_size=2,
...@@ -334,6 +274,7 @@ class EPMoE(FusedMoE): ...@@ -334,6 +274,7 @@ class EPMoE(FusedMoE):
def set_shared_experts(self, shared_experts: torch.nn.Module): def set_shared_experts(self, shared_experts: torch.nn.Module):
if self.shared_experts is None: if self.shared_experts is None:
self.shared_experts = shared_experts self.shared_experts = shared_experts
if self.shared_expert_overlap: if self.shared_expert_overlap:
self.token_dispatcher.set_shared_experts(self.shared_experts) self.token_dispatcher.set_shared_experts(self.shared_experts)
...@@ -357,6 +298,26 @@ class EPMoE(FusedMoE): ...@@ -357,6 +298,26 @@ class EPMoE(FusedMoE):
return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits, return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits,
self.layer_name) self.layer_name)
def get_expert_weights(self) -> Iterable[torch.Tensor]:
weights = list(self.named_parameters())
# Filter out the non-expert weights.
# `e_score_correction_bias` is a bias for each logical expert,
# with shape (num_logical_experts,), not an expert weight.
NON_EXPERT_WEIGHTS = {
"e_score_correction_bias",
"shared_experts.gate_up_proj.weight",
"shared_experts.gate_up_proj.weight_scale",
"shared_experts.down_proj.weight",
"shared_experts.down_proj.weight_scale"
}
return [
weight.view(self.local_num_experts, -1) for name, weight in weights
if name not in NON_EXPERT_WEIGHTS
]
def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
topk_weights, topk_ids = self.select_experts( topk_weights, topk_ids = self.select_experts(
...@@ -376,53 +337,6 @@ class EPMoE(FusedMoE): ...@@ -376,53 +337,6 @@ class EPMoE(FusedMoE):
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
########################test#########################
# probs = None
# if self.apply_router_weight_on_input:
# probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
# routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
# (dispatch_output, tokens_per_expert) = self.token_dispatcher.token_permutation(
# hidden_states, probs, routing_map
# )
# torch.cuda.synchronize()
# print("###########################all2all dispatch_output shape:", dispatch_output.shape)
# print("###########################all2all dispatch_output:", dispatch_output[:10, :10])
# expert_output = self.quant_method.apply_ep(
# layer=self,
# hidden_states=dispatch_output,
# tokens_per_expert=tokens_per_expert,
# topk_weights=None,
# topk_ids=None,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# )
# torch.cuda.synchronize()
# print("###########################grouped gemm out:", expert_output[:10, :10])
# final_hidden_states = self.token_dispatcher.token_unpermutation(expert_output)
# final_hidden_states_all2all = final_hidden_states
# torch.cuda.synchronize()
# print("####################all2all unpermute output:", final_hidden_states[:10, :10].tolist())
########################test##########################
if False:
probs = None
if self.apply_router_weight_on_input:
probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
(dispatch_output, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, routing_map
)
else:
topk_ids = topk_ids.to(torch.int32) topk_ids = topk_ids.to(torch.int32)
scales = torch.rand( scales = torch.rand(
hidden_states.shape[0], hidden_states.shape[0],
...@@ -430,8 +344,7 @@ class EPMoE(FusedMoE): ...@@ -430,8 +344,7 @@ class EPMoE(FusedMoE):
dtype=torch.float32, dtype=torch.float32,
device=hidden_states.device, device=hidden_states.device,
) )
self.sync()
print("##########################topk_weights shape:{} topk_ids shape:{}".format(topk_weights.shape, topk_ids.shape))
( (
dispatch_output, dispatch_output,
dispatch_weights, dispatch_weights,
...@@ -439,39 +352,42 @@ class EPMoE(FusedMoE): ...@@ -439,39 +352,42 @@ class EPMoE(FusedMoE):
dispatch_indices, dispatch_indices,
dispatch_recv_num_token, dispatch_recv_num_token,
) = self.mori_op.dispatch( ) = self.mori_op.dispatch(
hidden_states.contiguous(), hidden_states,
topk_weights.contiguous(), topk_weights,
scales.contiguous(), scales,
topk_ids.contiguous(), topk_ids,
) )
self.sync() #self.sync()
# with torch.inference_mode():
# src_token_pos = self.mori_op.get_dispatch_src_token_pos().tolist()
# print("##################src_token_pos:", src_token_pos[:10].tolist())
tokens_per_expert = dispatch_recv_num_token
print("######################dispatched_input shape:{} dispatch_weights.shape:{} dispatch_indices shape:{}".format(dispatch_output.shape,
dispatch_weights.shape, dispatch_indices.shape))
print("####################dispatch_recv_num_token:", dispatch_recv_num_token)
#dispatch_recv_num_token = dispatch_recv_num_token[0].item()
dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0] dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
print("########################dispatch_output:", dispatch_output[:10, :10].tolist()) dispatch_output = dispatch_output[:dispatch_recv_num_token]
print("########################dispatch_indices:", dispatch_indices[:10, :].tolist()) dispatch_weights = dispatch_weights[:dispatch_recv_num_token]
print("#########################start fused_moe") dispatch_indices = dispatch_indices[:dispatch_recv_num_token]
has_greater_than_255 = torch.any(dispatch_indices > 255).item()
has_less_than_0 = torch.any(dispatch_indices < 0).item()
print("##################################has_greater_than_255:{} has_less_than_0:{}".format(has_greater_than_255, has_less_than_0)) valid_mask = ((dispatch_indices <= 255) & (dispatch_indices >= 0)).all(dim=1)
dispatch_output = dispatch_output[valid_mask]
dispatch_indices = dispatch_indices[valid_mask]
dispatch_weights = dispatch_weights[valid_mask]
dispatch_recv_num_token = dispatch_indices.shape[0]
# dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
# has_greater_than_255 = torch.any(dispatch_indices > 255).item()
# has_less_than_0 = torch.any(dispatch_indices < 0).item()
# print("##################################has_greater_than_255:{} has_less_than_0:{}".format(has_greater_than_255, has_less_than_0))
# if has_greater_than_255 or has_less_than_0:
# print("###################dispatch_indices:", dispatch_indices.tolist())
if dispatch_recv_num_token > 0: if dispatch_recv_num_token > 0:
# Matrix multiply. # Matrix multiply.
#expert_output = self.quant_method.apply_ep( expert_output = self.quant_method.apply_ep(
expert_output = self.quant_method.apply(
layer=self, layer=self,
x=dispatch_output[:dispatch_recv_num_token].contiguous(), x=dispatch_output,
tokens_per_expert=tokens_per_expert, topk_weights=dispatch_weights,
topk_weights=dispatch_weights[:dispatch_recv_num_token].contiguous(), topk_ids=dispatch_indices,
topk_ids=dispatch_indices[:dispatch_recv_num_token].contiguous(),
global_num_experts=self.global_num_experts, global_num_experts=self.global_num_experts,
expert_map=self.expert_map, expert_map=self.expert_map,
activation=self.activation, activation=self.activation,
...@@ -479,26 +395,13 @@ class EPMoE(FusedMoE): ...@@ -479,26 +395,13 @@ class EPMoE(FusedMoE):
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
) )
else: else:
expert_output = dispatch_output[:dispatch_recv_num_token] expert_output = dispatch_output#[:dispatch_recv_num_token]
self.sync() #self.sync()
print("####################fused_moe expert_output:", expert_output[:10, :10].tolist())
combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids)
if False:
final_hidden_states = self.token_dispatcher.token_unpermutation(expert_output)
else:
combine_output, _ = self.mori_op.combine(expert_output.contiguous(), dispatch_weights.contiguous(), topk_ids.contiguous())
final_hidden_states = combine_output[:hidden_states.shape[0], :] final_hidden_states = combine_output[:hidden_states.shape[0], :]
torch.cuda.synchronize() #self.sync()
print("####################mori combine_output:", combine_output[:10, :10].tolist())
self.sync()
####################test#################
# final_hidden_states_close = torch.allclose(final_hidden_states, final_hidden_states_all2all, rtol=1e-2, atol=1e-2)
# print(f"final_hidden_states_close: {final_hidden_states_close}")
#####################test################
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# if shared_expert_overlap is True, the expert calculation happens in # if shared_expert_overlap is True, the expert calculation happens in
......
...@@ -331,7 +331,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -331,7 +331,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.hidden_shape_before_permute = hidden_states.shape self.hidden_shape_before_permute = hidden_states.shape
if True: if False:
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute( permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
hidden_states, hidden_states,
routing_map, routing_map,
...@@ -339,15 +339,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -339,15 +339,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
fused=self.config.moe_permute_fusion fused=self.config.moe_permute_fusion
) )
else: else:
torch.cuda.synchronize()
print("########################hidden_states shape:{} \n #####################routing_map shape:{}\n".format(hidden_states.shape,
routing_map.shape))
print("########################hidden_states:{} \n #####################routing_map:{}\n".format(hidden_states[0, :10].tolist(),
routing_map[0, :10].tolist()))
cuda_permute_result = groupgemm_permute(hidden_states, routing_map) cuda_permute_result = groupgemm_permute(hidden_states, routing_map)
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping, \ permutated_local_input_tokens, self.reversed_local_input_permutation_mapping, \
expert_m, self.expert_m_count, expert_m_max = cuda_permute_result self.expert_m_count = cuda_permute_result
# Perform expert parallel AlltoAll communication # Perform expert parallel AlltoAll communication
# tokens_per_expert = self._maybe_dtoh_and_synchronize( # tokens_per_expert = self._maybe_dtoh_and_synchronize(
...@@ -427,7 +422,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -427,7 +422,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.shared_experts.post_forward_comm() self.shared_experts.post_forward_comm()
# Unpermutation 1: AlltoAll output to output # Unpermutation 1: AlltoAll output to output
if True: if False:
output = unpermute( output = unpermute(
permutated_local_input_tokens, permutated_local_input_tokens,
self.reversed_local_input_permutation_mapping, self.reversed_local_input_permutation_mapping,
......
...@@ -778,19 +778,19 @@ class FusedMoE(torch.nn.Module): ...@@ -778,19 +778,19 @@ class FusedMoE(torch.nn.Module):
assert isinstance(quant_method, FusedMoEMethodBase) assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method self.quant_method = quant_method
if self.enable_eplb: # if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import ( # from vllm.model_executor.layers.quantization.fp8 import (
Fp8MoEMethod) # Fp8MoEMethod)
if not isinstance(quant_method, Fp8MoEMethod): # if not isinstance(quant_method, Fp8MoEMethod):
# TODO: Add support for additional quantization methods. # # TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not # # The implementation for other quantization methods does not
# contain essential differences, but the current quant API # # contain essential differences, but the current quant API
# design causes duplicated work when extending to new # # design causes duplicated work when extending to new
# quantization methods, so I'm leaving it for now. # # quantization methods, so I'm leaving it for now.
# If you plan to add support for more quantization methods, # # If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`. # # please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError("EPLB is only supported for FP8 " # raise NotImplementedError("EPLB is only supported for FP8 "
"quantization for now.") # "quantization for now.")
if quant_config is None: if quant_config is None:
# Not considering quant for now, temporarily # Not considering quant for now, temporarily
......
...@@ -334,29 +334,59 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -334,29 +334,59 @@ class SlimQuantW4A8Int8MoEMethod:
def apply_ep( #dp+ep def apply_ep( #dp+ep
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, x: torch.Tensor,
tokens_per_expert: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
**_
) -> torch.Tensor: ) -> torch.Tensor:
return fused_experts_impl_w4a8_ep(hidden_states, from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
layer.w13_weight_scale, topk_weights=topk_weights,
layer.w2_weight_scale, topk_ids=topk_ids,
tokens_per_expert) inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe,
)
def apply(# tp def apply(# tp
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
tokens_per_expert: torch.Tensor, router_logits: torch.Tensor,
topk_weights: torch.Tensor, top_k: int,
topk_ids: torch.Tensor, renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
enable_eplb: bool = False, enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
**_ **_
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -364,6 +394,20 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -364,6 +394,20 @@ class SlimQuantW4A8Int8MoEMethod:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.") "EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.")
# Expert selection # Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
)
return fused_experts( return fused_experts(
x, x,
......
...@@ -102,11 +102,6 @@ def with_amdsmi_context(fn): ...@@ -102,11 +102,6 @@ def with_amdsmi_context(fn):
def device_id_to_physical_device_id(device_id: int) -> int: def device_id_to_physical_device_id(device_id: int) -> int:
if "CUDA_VISIBLE_DEVICES" in os.environ:
device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
physical_device_id = device_ids[device_id]
return int(physical_device_id)
else:
return device_id return device_id
......
...@@ -441,7 +441,7 @@ class EagleProposer: ...@@ -441,7 +441,7 @@ class EagleProposer:
# [batch_size] # [batch_size]
num_accepted_tokens_tensor: torch.Tensor, num_accepted_tokens_tensor: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
cu_num_tokens = torch.arange(cu_target_query_lens.shape[0], device=cu_target_query_lens.device) cu_num_tokens = torch.arange(cu_target_query_lens.shape[0], device=cu_target_query_lens.device, dtype=torch.int32)
token_indices = num_accepted_tokens_tensor + cu_target_query_lens[:-1] token_indices = num_accepted_tokens_tensor + cu_target_query_lens[:-1]
return cu_num_tokens, token_indices return cu_num_tokens, token_indices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment