"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "def95df2220d34b6cc9d80f671dfa11fc4b9447f"
Unverified Commit 179d6aab authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Distributed] Allow user to pass-in extra env parameters when launching a...


[Distributed] Allow user to pass-in extra env parameters when launching a distributed training task. (#3375)

* Allow user to pass-in extra env parameters when launching a distributed training task.

* Update

* upd
Co-authored-by: default avatarxiangsx <xiangsx@ip-10-3-59-214.eu-west-1.compute.internal>
parent 367a3a34
...@@ -405,6 +405,24 @@ def wrap_cmd_with_local_envvars(cmd: str, env_vars: str) -> str: ...@@ -405,6 +405,24 @@ def wrap_cmd_with_local_envvars(cmd: str, env_vars: str) -> str:
# https://stackoverflow.com/a/45993803 # https://stackoverflow.com/a/45993803
return f"(export {env_vars}; {cmd})" return f"(export {env_vars}; {cmd})"
def wrap_cmd_with_extra_envvars(cmd: str, env_vars: list) -> str:
"""Wraps a CLI command with extra env vars
Example:
>>> cmd = "ls && pwd"
>>> env_vars = ["VAR1=value1", "VAR2=value2"]
>>> wrap_cmd_with_extra_envvars(cmd, env_vars)
"(export VAR1=value1 VAR2=value2; ls && pwd)"
Args:
cmd:
env_vars: A list of strings containing env vars, e.g., ["VAR1=value1", "VAR2=value2"]
Returns:
cmd_with_env_vars:
"""
env_vars = " ".join(env_vars)
return wrap_cmd_with_local_envvars(cmd, env_vars)
def submit_jobs(args, udf_command): def submit_jobs(args, udf_command):
"""Submit distributed jobs (server and client processes) via ssh""" """Submit distributed jobs (server and client processes) via ssh"""
...@@ -453,6 +471,7 @@ def submit_jobs(args, udf_command): ...@@ -453,6 +471,7 @@ def submit_jobs(args, udf_command):
ip, _ = hosts[int(i / server_count_per_machine)] ip, _ = hosts[int(i / server_count_per_machine)]
server_env_vars_cur = f"{server_env_vars} DGL_SERVER_ID={i}" server_env_vars_cur = f"{server_env_vars} DGL_SERVER_ID={i}"
cmd = wrap_cmd_with_local_envvars(udf_command, server_env_vars_cur) cmd = wrap_cmd_with_local_envvars(udf_command, server_env_vars_cur)
cmd = wrap_cmd_with_extra_envvars(cmd, args.extra_envs) if len(args.extra_envs) > 0 else cmd
cmd = 'cd ' + str(args.workspace) + '; ' + cmd cmd = 'cd ' + str(args.workspace) + '; ' + cmd
thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username)) thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
...@@ -480,6 +499,7 @@ def submit_jobs(args, udf_command): ...@@ -480,6 +499,7 @@ def submit_jobs(args, udf_command):
master_port=1234, master_port=1234,
) )
cmd = wrap_cmd_with_local_envvars(torch_dist_udf_command, client_env_vars) cmd = wrap_cmd_with_local_envvars(torch_dist_udf_command, client_env_vars)
cmd = wrap_cmd_with_extra_envvars(cmd, args.extra_envs) if len(args.extra_envs) > 0 else cmd
cmd = 'cd ' + str(args.workspace) + '; ' + cmd cmd = 'cd ' + str(args.workspace) + '; ' + cmd
thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username)) thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
...@@ -536,6 +556,10 @@ def main(): ...@@ -536,6 +556,10 @@ def main():
help='The format of the graph structure of each partition. \ help='The format of the graph structure of each partition. \
The allowed formats are csr, csc and coo. A user can specify multiple \ The allowed formats are csr, csc and coo. A user can specify multiple \
formats, separated by ",". For example, the graph format is "csr,csc".') formats, separated by ",". For example, the graph format is "csr,csc".')
parser.add_argument('--extra_envs', nargs='+', type=str, default=[],
help='Extra environment parameters need to be set. For example, \
you can set the LD_LIBRARY_PATH and NCCL_DEBUG by adding: \
--extra_envs LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH NCCL_DEBUG=INFO ')
args, udf_command = parser.parse_known_args() args, udf_command = parser.parse_known_args()
assert len(udf_command) == 1, 'Please provide user command line.' assert len(udf_command) == 1, 'Please provide user command line.'
assert args.num_trainers is not None and args.num_trainers > 0, \ assert args.num_trainers is not None and args.num_trainers > 0, \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment