Commit 02676be8 authored by PanZezhong's avatar PanZezhong
Browse files

issue/80 增加对musa同步支持

parent d7d0889d
...@@ -60,6 +60,20 @@ def get_args(): ...@@ -60,6 +60,20 @@ def get_args():
return parser.parse_args() return parser.parse_args()
def torch_synchronize(_device):
if _device == "cuda":
torch.cuda.synchronize()
elif _device == "musa":
torch.musa.synchronize()
def torch_empty_cache(_device):
if _device == "cuda":
torch.cuda.empty_cache()
elif _device == "musa":
torch.musa.empty_cache()
def create_Qwen3attention_torch(dir_path, *, device, dtype=torch.bfloat16): def create_Qwen3attention_torch(dir_path, *, device, dtype=torch.bfloat16):
config = AutoConfig.from_pretrained(dir_path) config = AutoConfig.from_pretrained(dir_path)
config.num_hidden_layers = 1 config.num_hidden_layers = 1
...@@ -128,12 +142,16 @@ def generate_attention_input_torch( ...@@ -128,12 +142,16 @@ def generate_attention_input_torch(
return req_list return req_list
def benchmark_Qwen3attention_prefill_torch(model, rotary_emb, req_list, test_cases): def benchmark_Qwen3attention_prefill_torch(
model, rotary_emb, test_cases, device, dtype=torch.bfloat16
):
""" """
Test Qwen3attention. Test Qwen3attention.
""" """
req_list = generate_attention_input_torch(
model, rotary_emb, test_cases, device, dtype=dtype
)
req_out_list = [] req_out_list = []
for req in req_list: for req in req_list:
# ----------------------------------------- # # ----------------------------------------- #
...@@ -172,7 +190,7 @@ def benchmark_Qwen3attention_prefill_torch(model, rotary_emb, req_list, test_cas ...@@ -172,7 +190,7 @@ def benchmark_Qwen3attention_prefill_torch(model, rotary_emb, req_list, test_cas
output_host = output_device.to("cpu") output_host = output_device.to("cpu")
req_out_list.append(output_host) req_out_list.append(output_host)
torch.cuda.synchronize() torch_synchronize(device)
for _ in range(WARMUPS): for _ in range(WARMUPS):
for i, req in enumerate(req_list): for i, req in enumerate(req_list):
...@@ -223,7 +241,7 @@ def benchmark_Qwen3attention_prefill_torch(model, rotary_emb, req_list, test_cas ...@@ -223,7 +241,7 @@ def benchmark_Qwen3attention_prefill_torch(model, rotary_emb, req_list, test_cas
origin_len = test_cases["pastlens"][i] origin_len = test_cases["pastlens"][i]
req["past_key_values"].crop(origin_len) req["past_key_values"].crop(origin_len)
torch.cuda.synchronize() torch_synchronize(device)
# ----------------------------------------- # # ----------------------------------------- #
# 重要:每个req都按整个batch的起始时间计算 # 重要:每个req都按整个batch的起始时间计算
# ----------------------------------------- # # ----------------------------------------- #
...@@ -260,7 +278,7 @@ def benchmark_Qwen3attention_prefill_torch(model, rotary_emb, req_list, test_cas ...@@ -260,7 +278,7 @@ def benchmark_Qwen3attention_prefill_torch(model, rotary_emb, req_list, test_cas
past_key_values=past_key_values, past_key_values=past_key_values,
) )
torch.cuda.synchronize() torch_synchronize(device)
end_time = time.time() end_time = time.time()
# 记录每个req从进入所有req进入推理到自己结束的时间 # 记录每个req从进入所有req进入推理到自己结束的时间
...@@ -277,10 +295,15 @@ def benchmark_Qwen3attention_prefill_torch(model, rotary_emb, req_list, test_cas ...@@ -277,10 +295,15 @@ def benchmark_Qwen3attention_prefill_torch(model, rotary_emb, req_list, test_cas
return req_out_list return req_out_list
def benchmark_Qwen3attention_decode_torch(model, rotary_emb, req_list, test_cases): def benchmark_Qwen3attention_decode_torch(
model, rotary_emb, test_cases, device, dtype=torch.bfloat16
):
""" """
Test Qwen3attention_decode. Test Qwen3attention_decode.
""" """
req_list = generate_attention_input_torch(
model, rotary_emb, test_cases, device, dtype=dtype
)
req_out_list = [] req_out_list = []
for req in req_list: for req in req_list:
# ----------------------------------------- # # ----------------------------------------- #
...@@ -314,7 +337,7 @@ def benchmark_Qwen3attention_decode_torch(model, rotary_emb, req_list, test_case ...@@ -314,7 +337,7 @@ def benchmark_Qwen3attention_decode_torch(model, rotary_emb, req_list, test_case
req_out_list.append(output_host) req_out_list.append(output_host)
torch.cuda.synchronize() torch_synchronize(device)
for req in req_list: for req in req_list:
for _ in range(WARMUPS): for _ in range(WARMUPS):
...@@ -353,7 +376,7 @@ def benchmark_Qwen3attention_decode_torch(model, rotary_emb, req_list, test_case ...@@ -353,7 +376,7 @@ def benchmark_Qwen3attention_decode_torch(model, rotary_emb, req_list, test_case
origin_len = test_cases["pastlens"][i] origin_len = test_cases["pastlens"][i]
req["past_key_values"].crop(origin_len) req["past_key_values"].crop(origin_len)
torch.cuda.synchronize() torch_synchronize(device)
start_time = time.time() start_time = time.time()
for i, req in enumerate(req_list): for i, req in enumerate(req_list):
...@@ -393,7 +416,7 @@ def benchmark_Qwen3attention_decode_torch(model, rotary_emb, req_list, test_case ...@@ -393,7 +416,7 @@ def benchmark_Qwen3attention_decode_torch(model, rotary_emb, req_list, test_case
# -------------------------------------------------------------- # # -------------------------------------------------------------- #
req["hidden_states"] = output_device req["hidden_states"] = output_device
torch.cuda.synchronize() torch_synchronize(device)
end_time = time.time() end_time = time.time()
time_consuming = end_time - start_time time_consuming = end_time - start_time
...@@ -425,11 +448,12 @@ if __name__ == "__main__": ...@@ -425,11 +448,12 @@ if __name__ == "__main__":
device = "cuda" device = "cuda"
elif args.moore: elif args.moore:
device = "musa" device = "musa"
import torch_musa
elif args.iluvatar: elif args.iluvatar:
device = "cuda" device = "cuda"
else: else:
print( print(
"Usage: python test/qwen3_atteniton_test.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_path>" "Usage: python test/models/qwen3_moe/attention_test.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_path>"
) )
sys.exit(1) sys.exit(1)
...@@ -444,26 +468,17 @@ if __name__ == "__main__": ...@@ -444,26 +468,17 @@ if __name__ == "__main__":
print("Test Qwen3attention ") print("Test Qwen3attention ")
print("*" * 130) print("*" * 130)
print(f"Test Case PREFILL_TESTCASES : {PREFILL_TESTCASES}") print(f"Test Case PREFILL_TESTCASES : {PREFILL_TESTCASES}")
req_list = generate_attention_input_torch(
model, rotary_emb, PREFILL_TESTCASES, device, dtype=dtype
)
output_prefill = benchmark_Qwen3attention_prefill_torch( output_prefill = benchmark_Qwen3attention_prefill_torch(
model, rotary_emb, req_list, PREFILL_TESTCASES model, rotary_emb, PREFILL_TESTCASES, device, dtype=dtype
) )
print("\n") print("\n")
print("-" * 130) print("-" * 130)
print(f"\nTest DECODE_TESTCASES: {DECODE_TESTCASES}") print(f"\nTest DECODE_TESTCASES: {DECODE_TESTCASES}")
#
req_list = generate_attention_input_torch(
model, rotary_emb, DECODE_TESTCASES, device, dtype=dtype
)
output_decode = benchmark_Qwen3attention_decode_torch( output_decode = benchmark_Qwen3attention_decode_torch(
model, rotary_emb, req_list, DECODE_TESTCASES model, rotary_emb, DECODE_TESTCASES, device, dtype=dtype
) )
# clean up device memory # clean up device memory
del model del model
torch.cuda.empty_cache() torch_empty_cache(device)
...@@ -60,6 +60,20 @@ def get_args(): ...@@ -60,6 +60,20 @@ def get_args():
return parser.parse_args() return parser.parse_args()
def torch_synchronize(_device):
if _device == "cuda":
torch.cuda.synchronize()
elif _device == "musa":
torch.musa.synchronize()
def torch_empty_cache(_device):
if _device == "cuda":
torch.cuda.empty_cache()
elif _device == "musa":
torch.musa.empty_cache()
def create_moe_torch(dir_path, device, dtype=torch.bfloat16): def create_moe_torch(dir_path, device, dtype=torch.bfloat16):
config = AutoConfig.from_pretrained(dir_path) config = AutoConfig.from_pretrained(dir_path)
moe = qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock(config).to( moe = qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock(config).to(
...@@ -95,12 +109,12 @@ def benchmark_moe_torch(moe, testcase, device, dtype): ...@@ -95,12 +109,12 @@ def benchmark_moe_torch(moe, testcase, device, dtype):
for _ in range(WARMUPS): for _ in range(WARMUPS):
moe(input_device) moe(input_device)
torch.cuda.synchronize() torch_synchronize(device)
start_time = time.time() start_time = time.time()
for _ in range(RUNS): for _ in range(RUNS):
moe(input_device) moe(input_device)
torch.cuda.synchronize() torch_synchronize(device)
end_time = time.time() end_time = time.time()
total_time = end_time - start_time total_time = end_time - start_time
...@@ -127,11 +141,12 @@ if __name__ == "__main__": ...@@ -127,11 +141,12 @@ if __name__ == "__main__":
device = "cuda" device = "cuda"
elif args.moore: elif args.moore:
device = "musa" device = "musa"
import torch_musa
elif args.iluvatar: elif args.iluvatar:
device = "cuda" device = "cuda"
else: else:
print( print(
"Usage: python test/qwen3_moe_test.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_path>" "Usage: python test/models/qwen3_moe/moe_test.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_path>"
) )
sys.exit(1) sys.exit(1)
...@@ -158,4 +173,4 @@ if __name__ == "__main__": ...@@ -158,4 +173,4 @@ if __name__ == "__main__":
# clean up device memory # clean up device memory
del moe del moe
torch.cuda.empty_cache() torch_empty_cache(device)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment