# moe_test.py — latency/throughput benchmark for the Qwen3 MoE sparse block.
import time
import torch
import transformers
import safetensors
import os
from transformers import AutoConfig
from transformers.models import qwen3_moe
import sys

# Benchmark configuration: untimed warm-up iterations, then timed iterations.
WARMUPS = 10
RUNS = 100

# Prefill case: a few longer sequences. NOTE(review): "pastlens" is carried in
# the test cases but not consumed by generate_moe_input_torch below.
PREFILL_TESTCASES = {"seqlens": [64, 128, 256, 256], "pastlens": [512, 0, 0, 256]}

# Decode case: 16 single-token steps across four groups of past lengths.
DECODE_TESTCASES = {
    "seqlens": [1] * 16,
    "pastlens": [50] * 4 + [100] * 4 + [200] * 4 + [400] * 4,
}


def get_args():
    """Parse command-line arguments for the MoE benchmark.

    Returns:
        argparse.Namespace with ``model_path`` (str) and one boolean flag per
        supported backend: ``cpu``, ``nvidia``, ``metax``, ``moore``,
        ``iluvatar``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Test Operator")
    parser.add_argument(
        "--model_path",
        action="store",
        # Fail fast with a clear argparse error instead of passing None
        # into AutoConfig.from_pretrained later.
        required=True,
        help="The directory of the model to be tested",
    )

    parser.add_argument(
        "--cpu",
        action="store_true",
        help="Run cpu test",
    )

    parser.add_argument(
        "--nvidia",
        action="store_true",
        help="Run nvidia test",
    )

    parser.add_argument(
        "--metax",
        action="store_true",
        help="Run metax test",
    )
    parser.add_argument(
        "--moore",
        action="store_true",
        help="Run moore test",
    )
    parser.add_argument(
        "--iluvatar",
        action="store_true",
        help="Run iluvatar test",
    )
    return parser.parse_args()


def torch_synchronize(_device):
    """Block until all queued work on the given backend has finished.

    A no-op for any device string other than "cuda" or "musa" (e.g. "cpu").
    """
    sync_dispatch = {
        "cuda": lambda: torch.cuda.synchronize(),
        "musa": lambda: torch.musa.synchronize(),
    }
    sync = sync_dispatch.get(_device)
    if sync is not None:
        sync()


def torch_empty_cache(_device):
    """Release cached allocator memory on the given backend.

    A no-op for any device string other than "cuda" or "musa" (e.g. "cpu").
    """
    if _device == "cuda":
        torch.cuda.empty_cache()
        return
    if _device == "musa":
        torch.musa.empty_cache()


def create_moe_torch(dir_path, device, dtype=torch.bfloat16):
    """Build a Qwen3 MoE sparse block and load its layer-0 MLP weights.

    Reads the model config from ``dir_path``, instantiates the block on the
    requested device/dtype, then loads every tensor whose key contains
    ``model.layers.0.mlp.`` from the first ``*.safetensors`` shard (in sorted
    filename order). NOTE(review): assumes all layer-0 MLP weights live in
    that single shard — confirm for multi-shard checkpoints.
    """
    prefix = "model.layers.0.mlp."
    config = AutoConfig.from_pretrained(dir_path)
    moe = qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock(config).to(
        device=device, dtype=dtype
    )
    shards = [n for n in sorted(os.listdir(dir_path)) if n.endswith(".safetensors")]
    state = {}
    if shards:
        shard_path = os.path.join(dir_path, shards[0])
        with safetensors.safe_open(shard_path, framework="pt") as f:
            for key in f.keys():
                if prefix in key:
                    state[key[len(prefix) :]] = f.get_tensor(key)
    moe.load_state_dict(state)
    return moe


def generate_moe_input_torch(testcase, dtype=torch.bfloat16, hidden_size=2048):
    """Create a random host-side activation tensor for a benchmark test case.

    Args:
        testcase: dict with a "seqlens" list of per-request token counts; they
            are summed to form the sequence dimension. ("pastlens" is ignored
            here — input length depends only on the new tokens.)
        dtype: element type of the generated tensor.
        hidden_size: model hidden dimension; defaults to 2048 (the previous
            hard-coded value), parameterized so other model sizes can reuse
            this helper.

    Returns:
        CPU tensor of shape (1, sum(seqlens), hidden_size), uniform in [0, 1).
    """
    total_seqlen = sum(testcase["seqlens"])
    return torch.rand((1, total_seqlen, hidden_size), device="cpu", dtype=dtype)


def benchmark_moe_torch(moe, testcase, device, dtype):
    """Time the MoE block over RUNS forward passes and print latency/throughput.

    Args:
        moe: the MoE module; called as ``moe(x)`` and expected to return a
            ``(output, router_logits)`` pair.
        testcase: dict with "seqlens"/"pastlens" describing the batch.
        device: device string ("cpu", "cuda", or "musa").
        dtype: activation dtype for the generated input.

    Returns:
        The forward-pass output copied back to host memory.
    """
    input_host = generate_moe_input_torch(testcase, dtype=dtype)
    input_device = input_host.to(device=device)

    # One untimed pass to capture the output returned to the caller.
    output_device, _ = moe(input_device)
    output_host = output_device.to("cpu")

    # Warm up (lazy init, caches) and drain the device queue before timing.
    for _ in range(WARMUPS):
        moe(input_device)
    torch_synchronize(device)

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with wall-clock adjustments and would skew the measured interval.
    start_time = time.perf_counter()
    for _ in range(RUNS):
        moe(input_device)
    torch_synchronize(device)
    end_time = time.perf_counter()

    total_time = end_time - start_time
    total_tokens = sum(testcase["seqlens"]) * RUNS
    print(
        f"\t WARMUPS={WARMUPS} RUNS={RUNS}, MoE Torch average latency: {round(total_time * 1000 / RUNS, 2)} ms   throughput: {round(total_tokens / total_time, 2)} tok/s"
    )
    return output_host


if __name__ == "__main__":
    args = get_args()
    print(args)

    model_path = args.model_path
    dtype = torch.bfloat16

    # Resolve the torch device string from the backend flag. Branch order
    # matters when several flags are set, so it is kept as-is: cpu, nvidia,
    # metax, moore, iluvatar.
    device = "cpu"
    if args.cpu:
        device = "cpu"
    elif args.nvidia:
        device = "cuda"
    elif args.metax:
        # metax uses the "cuda" device string in this script
        device = "cuda"
    elif args.moore:
        device = "musa"
        import torch_musa
    elif args.iluvatar:
        # iluvatar uses the "cuda" device string in this script
        device = "cuda"
    else:
        print(
            "Usage:  python test/models/qwen3_moe/moe_test.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_path>"
        )
        sys.exit(1)

    # Build the MoE block once and reuse it for both test cases.
    moe = create_moe_torch(model_path, device=device, dtype=dtype)

    banner = "*" * 130
    print(banner)
    print("Test Qwen3 MoE")
    print(banner)

    print(f"Test Case PREFILL_TESTCASES : {PREFILL_TESTCASES}")
    output_prefill = benchmark_moe_torch(
        moe, PREFILL_TESTCASES, device=device, dtype=dtype
    )

    print("\n")
    print("-" * 130)
    print(f"\nTest DECODE_TESTCASES: {DECODE_TESTCASES}")
    output_decode = benchmark_moe_torch(
        moe, DECODE_TESTCASES, device=device, dtype=dtype
    )

    # Clean up device memory.
    del moe
    torch_empty_cache(device)