# qwen3_moe_test.py — latency benchmark for the Qwen3 MoE block.
import time
import torch
import transformers
import safetensors
import os
from transformers import AutoConfig
from transformers.models import qwen3_moe
import sys

# Benchmark iteration counts: untimed warm-up passes, then timed runs.
WARMUPS = 10
RUNS = 100

# Prefill workload: a few longer sequences, some with pre-existing KV cache.
PREFILL_TESTCASES = {"seqlens": [64, 128, 256, 256], "pastlens": [512, 0, 0, 256]}

# Decode workload: 16 single-token steps over caches of growing length.
DECODE_TESTCASES = {
    "seqlens": [1] * 16,
    "pastlens": [50] * 4 + [100] * 4 + [200] * 4 + [400] * 4,
}


def get_args():
    """Parse the command line for the benchmark.

    Returns:
        argparse.Namespace with ``model_path`` (str) and one boolean flag per
        supported vendor backend (``cpu``, ``nvidia``, ``metax``, ``moore``,
        ``iluvatar``).
    """
    import argparse

    parser = argparse.ArgumentParser(description="Test Operator")
    parser.add_argument(
        "--model_path",
        action="store",
        # Required so a missing path fails fast with a clear argparse error
        # instead of crashing later inside AutoConfig.from_pretrained(None).
        required=True,
        help="The directory of the model to be tested",
    )

    parser.add_argument(
        "--cpu",
        action="store_true",
        help="Run cpu test",
    )

    parser.add_argument(
        "--nvidia",
        action="store_true",
        help="Run nvidia test",
    )

    parser.add_argument(
        "--metax",
        action="store_true",
        help="Run metax test",
    )
    parser.add_argument(
        "--moore",
        action="store_true",
        help="Run moore test",
    )
    parser.add_argument(
        "--iluvatar",
        action="store_true",
        help="Run iluvatar test",
    )
    return parser.parse_args()


def create_moe_torch(dir_path, device, dtype=torch.bfloat16):
    """Build a single Qwen3 sparse-MoE block from a checkpoint directory.

    Instantiates ``Qwen3MoeSparseMoeBlock`` from the config in *dir_path* and
    loads the layer-0 MLP weights out of the directory's ``*.safetensors``
    shards.

    Args:
        dir_path: directory holding ``config.json`` and ``*.safetensors`` files.
        device: target device string (e.g. ``"cpu"``, ``"cuda"``, ``"musa"``).
        dtype: parameter dtype for the block (default ``torch.bfloat16``).

    Returns:
        The loaded ``Qwen3MoeSparseMoeBlock`` moved to *device*/*dtype*.

    Raises:
        RuntimeError: from ``load_state_dict`` if the shards do not contain a
            complete set of ``model.layers.0.mlp.*`` weights.
    """
    config = AutoConfig.from_pretrained(dir_path)
    moe = qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock(config).to(
        device=device, dtype=dtype
    )
    prefix = "model.layers.0.mlp."
    tensors = {}
    # Scan EVERY shard: in a sharded checkpoint the layer-0 MLP weights can be
    # split across several .safetensors files. (The previous version broke out
    # of the loop after the first shard, silently missing such keys.)
    for fname in sorted(os.listdir(dir_path)):
        if not fname.endswith(".safetensors"):
            continue
        fpath = os.path.join(dir_path, fname)
        with safetensors.safe_open(fpath, framework="pt") as f:
            for key in f.keys():
                if prefix in key:
                    tensors[key[len(prefix):]] = f.get_tensor(key)
    moe.load_state_dict(tensors)
    return moe


def generate_moe_input_torch(testcase, dtype=torch.bfloat16, hidden_size=2048):
    """Create a random packed input batch for the MoE block.

    All sequences of the test case are packed into a single row of length
    ``sum(seqlens)``; the ``"pastlens"`` entry is not used here (the MoE block
    itself has no KV cache).

    Args:
        testcase: dict with a ``"seqlens"`` list of per-sequence lengths.
        dtype: element dtype of the returned tensor (default ``torch.bfloat16``).
        hidden_size: model hidden dimension. Previously hard-coded to 2048;
            now a parameter whose default preserves the old behavior.

    Returns:
        CPU tensor of shape ``(1, sum(seqlens), hidden_size)`` with values
        uniform in ``[0, 1)``.
    """
    total_seqlen = sum(testcase["seqlens"])
    return torch.rand((1, total_seqlen, hidden_size), device="cpu", dtype=dtype)


def benchmark_moe_torch(moe, input_host, device, dtype):
    """Time *moe* on *input_host* and print the average per-call latency.

    Runs ``WARMUPS`` untimed iterations followed by ``RUNS`` timed iterations
    and reports the mean latency in milliseconds.

    Args:
        moe: callable block; ``moe(x)`` returns ``(hidden_states, router_logits)``.
        input_host: input tensor resident on the CPU.
        device: device string to benchmark on ("cpu", "cuda", "musa", ...).
        dtype: unused; kept for call-site compatibility.

    Returns:
        The block's first output for the input, copied back to the CPU.
    """

    def _sync():
        # Only synchronize when actually timing CUDA kernels: the previous
        # unconditional torch.cuda.synchronize() raised on CUDA-less machines
        # in --cpu mode. NOTE(review): no explicit sync is issued for "musa",
        # so timings there may be optimistic — confirm the torch_musa API.
        if "cuda" in str(device) and torch.cuda.is_available():
            torch.cuda.synchronize()

    input_device = input_host.to(device=device)

    # One functional pass to capture the output that is returned to the caller.
    output_device, _ = moe(input_device)
    output_host = output_device.to("cpu")

    for _ in range(WARMUPS):
        moe(input_device)
    _sync()

    start_time = time.time()
    for _ in range(RUNS):
        moe(input_device)
    _sync()
    end_time = time.time()

    print(f"    MoE Torch average latency: {(end_time - start_time) * 1000 / RUNS} ms")
    return output_host


if __name__ == "__main__":
    args = get_args()
    print(args)

    model_path = args.model_path

    dtype = torch.bfloat16
    # Map the selected vendor flag to a torch device string. metax and
    # iluvatar GPUs are driven through the CUDA backend; moore uses "musa".
    device = "cpu"
    if args.cpu:
        device = "cpu"
    elif args.nvidia:
        device = "cuda"
    elif args.metax:
        device = "cuda"
    elif args.moore:
        device = "musa"
    elif args.iluvatar:
        device = "cuda"
    else:
        print(
            "Usage:  python test/qwen3_moe_test.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_path>"
        )
        sys.exit(1)

    # -----------------------------------------------------------------------------

    moe = create_moe_torch(model_path, device=device, dtype=dtype)

    print("*" * 130)
    print("Test Qwen3 MoE")
    print("*" * 130)
    print(f"Test Case PREFILL_TESTCASES : {PREFILL_TESTCASES}")

    input_prefill = generate_moe_input_torch(PREFILL_TESTCASES)
    output_prefill = benchmark_moe_torch(moe, input_prefill, device=device, dtype=dtype)

    print("\n")
    print("-" * 130)
    print(f"\nTest DECODE_TESTCASES: {DECODE_TESTCASES}")
    input_decode = generate_moe_input_torch(DECODE_TESTCASES)
    output_decode = benchmark_moe_torch(moe, input_decode, device=device, dtype=dtype)

    # clean up device memory
    del moe
    # Guarded so CPU-only machines never touch the (possibly absent) CUDA
    # runtime during cleanup.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()