Commit faaa9ec3 authored by pengcheng888's avatar pengcheng888
Browse files

issue/80 模型文件夹名字改为model_path,增加moore平台的参数

parent 7d222d83
......@@ -31,6 +31,16 @@ def get_args():
action="store_true",
help="Run metax test",
)
parser.add_argument(
"--moore",
action="store_true",
help="Run moore test",
)
parser.add_argument(
"--iluvatar",
action="store_true",
help="Run iluvatar test",
)
parser.add_argument(
"--model_path",
type=str,
......@@ -141,9 +151,13 @@ if __name__ == "__main__":
device_type = "cuda"
elif args.metax:
device_type = "cuda"
elif args.moore:
device_type = "musa"
elif args.iluvatar:
device_type = "cuda"
else:
print(
"Usage: python examples/llama.py [--cpu | --nvidia] --model_path=<path/to/model_dir>"
"Usage: python examples/llama.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_dir>"
)
sys.exit(1)
......
......@@ -25,7 +25,7 @@ def get_args():
parser = argparse.ArgumentParser(description="Test Operator")
parser.add_argument(
"--model_dir",
"--model_path",
action="store",
help="The directory of the model to be tested",
)
......@@ -55,7 +55,7 @@ def get_args():
parser.add_argument(
"--iluvatar",
action="store_true",
help="Run moore test",
help="Run iluvatar test",
)
return parser.parse_args()
......@@ -400,7 +400,7 @@ if __name__ == "__main__":
args = get_args()
print(args)
model_dir = args.model_dir
model_path = args.model_path
dtype = torch.bfloat16
# Parse command line arguments
......@@ -417,7 +417,7 @@ if __name__ == "__main__":
device = "cuda"
else:
print(
"Usage: python test/qwen3_atteniton_test.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_dir=<path/to/model_dir>"
"Usage: python test/qwen3_atteniton_test.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_path>"
)
sys.exit(1)
......@@ -425,7 +425,7 @@ if __name__ == "__main__":
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
model, rotary_emb = create_Qwen3attention_torch(
model_dir, device=device, dtype=dtype
model_path, device=device, dtype=dtype
)
print("\n")
print("*" * 130)
......
......@@ -25,7 +25,7 @@ def get_args():
parser = argparse.ArgumentParser(description="Test Operator")
parser.add_argument(
"--model_dir",
"--model_path",
action="store",
help="The directory of the model to be tested",
)
......@@ -55,7 +55,7 @@ def get_args():
parser.add_argument(
"--iluvatar",
action="store_true",
help="Run moore test",
help="Run iluvatar test",
)
return parser.parse_args()
......@@ -111,7 +111,7 @@ if __name__ == "__main__":
args = get_args()
print(args)
model_dir = args.model_dir
model_path = args.model_path
dtype = torch.bfloat16
# Parse command line arguments
device = "cpu"
......@@ -127,7 +127,7 @@ if __name__ == "__main__":
device = "cuda"
else:
print(
"Usage: python test/qwen3_atteniton_test.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_dir=<path/to/model_dir>"
"Usage: python test/qwen3_moe_test.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_path>"
)
sys.exit(1)
......@@ -135,7 +135,7 @@ if __name__ == "__main__":
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
moe = create_moe_torch(model_dir, device=device, dtype=dtype)
moe = create_moe_torch(model_path, device=device, dtype=dtype)
print("*" * 130)
print("Test Qwen3 MoE")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment