Commit bc9aee38 authored by zhuwenwen
Browse files

update op.moe_fused_gate

parent a54ab95d
...@@ -230,31 +230,31 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, ...@@ -230,31 +230,31 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
# @pytest.mark.parametrize("n_heads", [4, 8, 13]) # @pytest.mark.parametrize("n_heads", [4, 8, 13])
# @pytest.mark.parametrize("d_head", [5, 16, 21, 32]) # @pytest.mark.parametrize("d_head", [5, 16, 21, 32])
# @pytest.mark.parametrize( # @pytest.mark.parametrize(
"seq_len_chunk_size_cases", # "seq_len_chunk_size_cases",
[ # [
# small-ish chunk_size (8) # # small-ish chunk_size (8)
(64, 8, 2, [(64, 32), (64, 32)]), # (64, 8, 2, [(64, 32), (64, 32)]),
(64, 8, 2, [(32, 32), (32, 32), (32, 32)]), # (64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
(64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary # (64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
(64, 8, 2, [(4, 4), (4, 4), (4, 4), # (64, 8, 2, [(4, 4), (4, 4), (4, 4),
(4, 4)]), # chunk_size larger than cont batches # (4, 4)]), # chunk_size larger than cont batches
(64, 8, 5, [ # (64, 8, 5, [
(64, 32, 16, 8, 8), # (64, 32, 16, 8, 8),
(8, 16, 32, 16, 8), # (8, 16, 32, 16, 8),
(8, 8, 16, 32, 16), # (8, 8, 16, 32, 16),
]), # mode examples with varied lengths # ]), # mode examples with varied lengths
# odd chunk_size # # odd chunk_size
(64, 29, 2, [(11, 4), (13, 23), (19, 22), # (64, 29, 2, [(11, 4), (13, 23), (19, 22),
(21, 15)]), # irregular sizes # (21, 15)]), # irregular sizes
# large-ish chunk_size (256) # # large-ish chunk_size (256)
(64, 256, 1, [(5, ), (1, ), (1, ), # (64, 256, 1, [(5, ), (1, ), (1, ),
(1, )]), # irregular sizes with small sequences # (1, )]), # irregular sizes with small sequences
(64, 256, 2, [(5, 30), (1, 2), (1, 2), # (64, 256, 2, [(5, 30), (1, 2), (1, 2),
(1, 2)]), # irregular sizes with small sequences # (1, 2)]), # irregular sizes with small sequences
]) # ]
# def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, # def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
# itype): # itype):
......
...@@ -1285,19 +1285,19 @@ class FusedMoE(torch.nn.Module): ...@@ -1285,19 +1285,19 @@ class FusedMoE(torch.nn.Module):
num_expert_group, num_expert_group,
topk_group, topk_group,
top_k, top_k,
routed_scaling_factor=routed_scaling_factor, 0,
n_share_experts_fusion=0, routed_scaling_factor,
) )
else: else:
topk_weights, topk_ids = ops.moe_fused_gate( topk_weights, topk_ids = ops.moe_fused_gate(
router_logits, router_logits,
e_score_correction_bias, e_score_correction_bias,
num_expert_group, num_expert_group,
topk_group, topk_group,
top_k, top_k,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
n_share_experts_fusion=0, n_share_experts_fusion=0,
) )
else: else:
topk_weights, topk_ids = grouped_topk( topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states, hidden_states=hidden_states,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment