Unverified Commit 1e0e0808 authored by littsk's avatar littsk Committed by GitHub
Browse files

[bug] Fix the version check bug in colossalai run when generating the cmd. (#4713)

* Fix the version check bug in colossalai run when generating the cmd.

* polish code
parent 3e05c07b
......@@ -156,7 +156,8 @@ def get_launch_command(
torch_version = version.parse(torch.__version__)
assert torch_version.major >= 1
if torch_version.minor < 9:
if torch_version.major == 1 and torch_version.minor < 9:
# torch distributed launch cmd with torch < 1.9
cmd = [
sys.executable,
"-m",
......@@ -177,7 +178,8 @@ def get_launch_command(
value = extra_launch_args.pop(key)
default_torchrun_rdzv_args[key] = value
if torch_version.minor < 10:
if torch_version.major == 1 and torch_version.minor == 9:
# torch distributed launch cmd with torch == 1.9
cmd = [
sys.executable,
"-m",
......@@ -187,6 +189,7 @@ def get_launch_command(
f"--node_rank={node_rank}",
]
else:
# torch distributed launch cmd with torch > 1.9
cmd = [
"torchrun",
f"--nproc_per_node={nproc_per_node}",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment