"examples/vscode:/vscode.git/clone" did not exist on "4664019ac3a7e3a82bb369b2ea6c1cac3c57b261"
Commit faaa9ec3 authored by pengcheng888's avatar pengcheng888
Browse files

issue/80 模型文件夹名字改为model_path,增加moore平台的参数

parent 7d222d83
...@@ -31,6 +31,16 @@ def get_args(): ...@@ -31,6 +31,16 @@ def get_args():
action="store_true", action="store_true",
help="Run metax test", help="Run metax test",
) )
parser.add_argument(
"--moore",
action="store_true",
help="Run moore test",
)
parser.add_argument(
"--iluvatar",
action="store_true",
help="Run iluvatar test",
)
parser.add_argument( parser.add_argument(
"--model_path", "--model_path",
type=str, type=str,
...@@ -141,9 +151,13 @@ if __name__ == "__main__": ...@@ -141,9 +151,13 @@ if __name__ == "__main__":
device_type = "cuda" device_type = "cuda"
elif args.metax: elif args.metax:
device_type = "cuda" device_type = "cuda"
elif args.moore:
device_type = "musa"
elif args.iluvatar:
device_type = "cuda"
else: else:
print( print(
"Usage: python examples/llama.py [--cpu | --nvidia] --model_path=<path/to/model_dir>" "Usage: python examples/llama.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_dir>"
) )
sys.exit(1) sys.exit(1)
......
...@@ -25,7 +25,7 @@ def get_args(): ...@@ -25,7 +25,7 @@ def get_args():
parser = argparse.ArgumentParser(description="Test Operator") parser = argparse.ArgumentParser(description="Test Operator")
parser.add_argument( parser.add_argument(
"--model_dir", "--model_path",
action="store", action="store",
help="The directory of the model to be tested", help="The directory of the model to be tested",
) )
...@@ -55,7 +55,7 @@ def get_args(): ...@@ -55,7 +55,7 @@ def get_args():
parser.add_argument( parser.add_argument(
"--iluvatar", "--iluvatar",
action="store_true", action="store_true",
help="Run moore test", help="Run iluvatar test",
) )
return parser.parse_args() return parser.parse_args()
...@@ -400,7 +400,7 @@ if __name__ == "__main__": ...@@ -400,7 +400,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
print(args) print(args)
model_dir = args.model_dir model_path = args.model_path
dtype = torch.bfloat16 dtype = torch.bfloat16
# Parse command line arguments # Parse command line arguments
...@@ -417,7 +417,7 @@ if __name__ == "__main__": ...@@ -417,7 +417,7 @@ if __name__ == "__main__":
device = "cuda" device = "cuda"
else: else:
print( print(
"Usage: python test/qwen3_atteniton_test.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_dir=<path/to/model_dir>" "Usage: python test/qwen3_atteniton_test.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_path>"
) )
sys.exit(1) sys.exit(1)
...@@ -425,7 +425,7 @@ if __name__ == "__main__": ...@@ -425,7 +425,7 @@ if __name__ == "__main__":
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
model, rotary_emb = create_Qwen3attention_torch( model, rotary_emb = create_Qwen3attention_torch(
model_dir, device=device, dtype=dtype model_path, device=device, dtype=dtype
) )
print("\n") print("\n")
print("*" * 130) print("*" * 130)
......
...@@ -25,7 +25,7 @@ def get_args(): ...@@ -25,7 +25,7 @@ def get_args():
parser = argparse.ArgumentParser(description="Test Operator") parser = argparse.ArgumentParser(description="Test Operator")
parser.add_argument( parser.add_argument(
"--model_dir", "--model_path",
action="store", action="store",
help="The directory of the model to be tested", help="The directory of the model to be tested",
) )
...@@ -55,7 +55,7 @@ def get_args(): ...@@ -55,7 +55,7 @@ def get_args():
parser.add_argument( parser.add_argument(
"--iluvatar", "--iluvatar",
action="store_true", action="store_true",
help="Run moore test", help="Run iluvatar test",
) )
return parser.parse_args() return parser.parse_args()
...@@ -111,7 +111,7 @@ if __name__ == "__main__": ...@@ -111,7 +111,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
print(args) print(args)
model_dir = args.model_dir model_path = args.model_path
dtype = torch.bfloat16 dtype = torch.bfloat16
# Parse command line arguments # Parse command line arguments
device = "cpu" device = "cpu"
...@@ -127,7 +127,7 @@ if __name__ == "__main__": ...@@ -127,7 +127,7 @@ if __name__ == "__main__":
device = "cuda" device = "cuda"
else: else:
print( print(
"Usage: python test/qwen3_atteniton_test.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_dir=<path/to/model_dir>" "Usage: python test/qwen3_moe_test.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_path>"
) )
sys.exit(1) sys.exit(1)
...@@ -135,7 +135,7 @@ if __name__ == "__main__": ...@@ -135,7 +135,7 @@ if __name__ == "__main__":
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
moe = create_moe_torch(model_dir, device=device, dtype=dtype) moe = create_moe_torch(model_path, device=device, dtype=dtype)
print("*" * 130) print("*" * 130)
print("Test Qwen3 MoE") print("Test Qwen3 MoE")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment