test_moe_align.py

import torch
from sgl_kernel import moe_align_block_size


def test_moe_align_block_size():
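    """Smoke test: allocate the outputs and scratch buffers, then run moe_align_block_size once."""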
    num_experts = 256
    block_size = 128
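    # Random top-k expert assignments for a tiny batch: 3 tokens, top-4 experts per token.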
    topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")

    # Worst case: every expert's token segment is padded up to the next multiple of block_size.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    # topk_ids.numel() is an out-of-range token index, used as the sentinel for padded slots.
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)

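    # Scratch buffers the kernel uses internally: per-expert token counts and their cumulative sum.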
    token_cnts_buffer = torch.empty(
        (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
    )
    cumsum_buffer = torch.empty(
        num_experts + 1, dtype=torch.int32, device=topk_ids.device
    )

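    # Group token indices by expert and pad each expert's segment to a multiple of block_size;
    # results are written into sorted_ids, expert_ids, and num_tokens_post_pad.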
    moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
        token_cnts_buffer,
        cumsum_buffer,
    )


if __name__ == "__main__":
    test_moe_align_block_size()