# Mask of valid local expert indices, length (num_experts_per_partition + 1).
# Original DeepEP marks an invalid rank_id with -1, but aiter does not accept
# -1 as an index, so one extra trailing slot is reserved as the "invalid"
# bucket and masked out instead.
# Example with 4 experts on this rank: expert_mask = [1, 1, 1, 1, 0] —
# indices 0-3 are valid and processed, index 4 (the sentinel) is ignored.
self.expert_mask = torch.ones(
    (self.num_experts_per_partition + 1),
    device=torch.cuda.current_device(),
    dtype=torch.int,
)
# the trailing sentinel slot stands in for the invalid rank_id
self.expert_mask[-1] = 0
else:
self.w13_weight_fp8=(
self.w13_weight,
(
self.w13_weight_scale_inv
ifself.use_block_quant
elseself.w13_weight_scale
),
)
self.w2_weight_fp8=(
self.w2_weight,
(
self.w2_weight_scale_inv
ifself.use_block_quant
elseself.w2_weight_scale
),
)
defforward(
self,
...
...
@@ -1142,6 +1175,9 @@ class DeepEPMoE(EPMoE):
num_recv_tokens_per_expert:List[int],
forward_mode:ForwardMode,
):
if_use_aiter:
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel