test_fused_experts_relu2.py 2.91 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse

import torch
import torch.nn.functional as F

from aiter.test_common import checkAllclose
from aiter.ops.triton.fused_moe import fused_experts_impl


def torch_fused_experts_relu2(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    """Reference MoE forward pass using a squared-ReLU ("relu2") activation.

    For every token ``t`` and each routed slot ``k`` (with ``e = topk_ids[t, k]``)::

        out[t] += topk_weights[t, k] * (relu(x[t] @ w1[e].T) ** 2) @ w2[e].T

    All matmuls, the activation and the weighted top-k sum are evaluated in
    float32, and the result is cast back to ``hidden_states.dtype`` once at the
    end. This keeps the reference from accumulating fp16/bf16 rounding error
    of its own before it is compared against the fused kernel.

    Args:
        hidden_states: ``(num_tokens, hidden_size)`` token activations.
        w1: per-expert up-projection; ``w1[e]`` is multiplied as
            ``x @ w1[e].T`` (``(intermediate_size, hidden_size)`` in this script).
        w2: per-expert down-projection; ``w2[e]`` is multiplied as
            ``act @ w2[e].T`` (``(hidden_size, intermediate_size)`` in this script).
        topk_weights: ``(num_tokens, topk)`` routing weights.
        topk_ids: ``(num_tokens, topk)`` integer expert indices.

    Returns:
        A ``(num_tokens, hidden_size)`` tensor with ``hidden_states.dtype`` on
        ``hidden_states.device``.
    """
    num_tokens, hidden_size = hidden_states.shape
    x = hidden_states.float()
    out = torch.zeros(
        (num_tokens, hidden_size),
        dtype=torch.float32,
        device=hidden_states.device,
    )
    expert_ids = topk_ids.long()

    # Group the work by expert rather than by token: one matmul per active
    # expert instead of a per-(token, slot) loop with an .item() sync each.
    for expert_id in torch.unique(expert_ids).tolist():
        token_idx, slot_idx = torch.nonzero(expert_ids == expert_id, as_tuple=True)
        up = x[token_idx] @ w1[expert_id].float().transpose(0, 1)
        act = torch.square(F.relu(up))
        down = act @ w2[expert_id].float().transpose(0, 1)
        weights = topk_weights[token_idx, slot_idx].float().unsqueeze(-1)
        # index_add_ accumulates correctly even if a token routes to the same
        # expert in more than one slot (duplicate indices in token_idx).
        out.index_add_(0, token_idx, weights * down)

    return out.to(hidden_states.dtype)


def main() -> None:
    """Compare the Triton relu2 fused-experts kernel against the torch reference.

    Parses the problem sizes and dtype from the command line, builds random
    activations, expert weights and top-k routing on CUDA (seeded with 0),
    runs both ``torch_fused_experts_relu2`` and ``fused_experts_impl`` with
    ``activation="relu2"``, and raises ``AssertionError`` when
    ``checkAllclose`` returns a truthy value.
    """
    cli = argparse.ArgumentParser()
    int_options = (
        ("--m", 32),
        ("--hidden_size", 128),
        ("--intermediate_size", 256),
        ("--num_experts", 4),
        ("--topk", 1),
    )
    for flag, default_value in int_options:
        cli.add_argument(flag, type=int, default=default_value)
    cli.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16"])
    opts = cli.parse_args()

    torch.manual_seed(0)
    # argparse `choices` guarantees the key is one of these two.
    dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}[opts.dtype]
    device = "cuda"

    def scaled_randn(*shape: int) -> torch.Tensor:
        # Standard-normal samples divided by 10 to keep magnitudes small.
        return torch.randn(*shape, dtype=dtype, device=device) / 10

    hidden_states = scaled_randn(opts.m, opts.hidden_size)
    w1 = scaled_randn(opts.num_experts, opts.intermediate_size, opts.hidden_size)
    w2 = scaled_randn(opts.num_experts, opts.hidden_size, opts.intermediate_size)

    # Routing: pick the top-k experts from float32 scores, softmax their scores.
    scores = torch.randn(opts.m, opts.num_experts, dtype=torch.float32, device=device)
    top_vals, top_idx = scores.topk(opts.topk, dim=-1)
    topk_weights = top_vals.softmax(dim=-1).to(torch.float32)
    topk_ids = top_idx.to(torch.int32)

    ref = torch_fused_experts_relu2(hidden_states, w1, w2, topk_weights, topk_ids)
    # NOTE(review): `dtype` and `False` are passed positionally; their meaning
    # depends on fused_experts_impl's signature — confirm against aiter.
    out = fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        dtype,
        False,
        activation="relu2",
    )

    err = checkAllclose(ref, out, rtol=2e-2, atol=2e-2, printLog=True)
    if not err:
        print("relu2 fused_experts test passed")
        return
    raise AssertionError(f"relu2 fused_experts mismatch, err_ratio={err}")


# Run the check only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()