Commit a10c3e15 authored by one
Browse files

Refactor environment variable handling in runner.py

parent 325db60e
...@@ -193,16 +193,67 @@ def __get_mode_command(self, benchmark_name, mode, timeout=None): ...@@ -193,16 +193,67 @@ def __get_mode_command(self, benchmark_name, mode, timeout=None):
','.join(f'{host}:{mode.proc_num}' for host in mode.host_list), ','.join(f'{host}:{mode.proc_num}' for host in mode.host_list),
bind_to=mode.bind_to, bind_to=mode.bind_to,
mca_list=' '.join(f'-mca {k} {v}' for k, v in mode.mca.items()), mca_list=' '.join(f'-mca {k} {v}' for k, v in mode.mca.items()),
env_list=' '.join( env_list=' '.join(self.__format_mpi_env_args(mode.env, mode.proc_rank, mode.proc_num)),
f'-x {k}={str(v).format(proc_rank=mode.proc_rank, proc_num=mode.proc_num)}'
if isinstance(v, str) else f'-x {k}' for k, v in mode.env.items()
),
command=exec_command, command=exec_command,
) )
else: else:
logger.warning('Unknown mode %s.', mode.name) logger.warning('Unknown mode %s.', mode.name)
return mode_command.strip() return mode_command.strip()
def __format_mode_env_value(self, value, proc_rank, proc_num):
    """Render one configured mode env value as text.

    String values are used as ``str.format`` templates in which the
    ``proc_rank`` and ``proc_num`` fields are substituted. Numeric and
    boolean values are converted with ``str()``.

    Args:
        value: Env value from config.
        proc_rank (int): Process rank.
        proc_num (int): Process count.

    Returns:
        str | None: Rendered text, or None when value is None so that the
            caller exports the already existing variable only.

    Raises:
        ValueError: If value is not None, str, int, float or bool.
    """
    if isinstance(value, str):
        return value.format(proc_rank=proc_rank, proc_num=proc_num)
    # bool is a subclass of int, so booleans are covered by this check too.
    if isinstance(value, (int, float)):
        return str(value)
    if value is not None:
        raise ValueError(f'Unsupported env value type: {type(value)}')
    return None
def __format_mpi_env_args(self, env_dict, proc_rank, proc_num):
    """Build the ``-x`` arguments for mpirun from a mode's env config.

    Each entry becomes ``-x KEY`` when its rendered value is None, and
    ``-x "KEY=value"`` (quoted by ``__quote_env_assignment``) otherwise.

    Args:
        env_dict (DictConfig): Mode env config.
        proc_rank (int): Process rank.
        proc_num (int): Process count.

    Returns:
        list[str]: Formatted mpirun env args, in the iteration order of
            env_dict.
    """
    mpi_flags = []
    for name, raw_value in env_dict.items():
        rendered = self.__format_mode_env_value(raw_value, proc_rank, proc_num)
        if rendered is None:
            # No value configured: pass the variable through by name only.
            mpi_flags.append(f'-x {name}')
            continue
        assignment = self.__quote_env_assignment(name, rendered)
        mpi_flags.append(f'-x {assignment}')
    return mpi_flags
def __quote_env_assignment(self, key, value):
    """Build a double-quoted ``KEY=value`` token for a shell command line.

    Backslash, double quote, dollar sign and backtick in the value are each
    prefixed with a backslash. Double quotes are used so the token can sit
    inside a single-quoted ``bash -lc '...'`` wrapper.

    NOTE(review): the key and any single quote in the value are inserted
    unchanged, so a value containing ``'`` is not protected here.

    Args:
        key (str): Env key.
        value (str): Env value; converted with ``str()`` before escaping.

    Returns:
        str: Double-quoted KEY=value assignment.
    """
    # One pass over the text; every special character maps to its
    # backslash-prefixed form.
    escape_table = str.maketrans({
        '\\': '\\\\',
        '"': '\\"',
        '$': '\\$',
        '`': '\\`',
    })
    body = str(value).translate(escape_table)
    return '"{}={}"'.format(key, body)
def get_failure_count(self): def get_failure_count(self):
"""Get failure count during Ansible run. """Get failure count during Ansible run.
...@@ -489,9 +540,13 @@ def _run_proc(self, benchmark_name, mode, vars): ...@@ -489,9 +540,13 @@ def _run_proc(self, benchmark_name, mode, vars):
if self._docker_config.skip: if self._docker_config.skip:
env_list = 'set -o allexport && source /tmp/sb.env && set +o allexport' env_list = 'set -o allexport && source /tmp/sb.env && set +o allexport'
for k, v in mode.env.items(): for k, v in mode.env.items():
if isinstance(v, str): formatted_value = self.__format_mode_env_value(v, mode.proc_rank, mode.proc_num)
envvar = f'{k}={str(v).format(proc_rank=mode.proc_rank, proc_num=mode.proc_num)}' if formatted_value is not None:
env_list += f' -e {envvar}' if not self._docker_config.skip else f' && export {envvar}' envvar = self.__quote_env_assignment(k, formatted_value)
env_list += (
f' -e {envvar}'
if not self._docker_config.skip else f' && export {envvar}'
)
fcmd = "docker exec {env_list} sb-workspace bash -lc '{command}'" fcmd = "docker exec {env_list} sb-workspace bash -lc '{command}'"
if self._docker_config.skip: if self._docker_config.skip:
......
...@@ -212,6 +212,27 @@ def test_get_mode_command(self): ...@@ -212,6 +212,27 @@ def test_get_mode_command(self):
f'sb exec --output-dir {self.sb_output_dir} -c sb.config.yaml -C superbench.enable=foo' f'sb exec --output-dir {self.sb_output_dir} -c sb.config.yaml -C superbench.enable=foo'
), ),
}, },
{
'benchmark_name':
'foo',
'mode': {
'name': 'mpi',
'node_num': 1,
'proc_num': 4,
'proc_rank': 1,
'mca': {},
'env': {
'NCCL_BUFFSIZE': 4194304,
'NCCL_RINGS': '0 1 2 3|0 3 2 1',
'PATH': None,
},
},
'expected_command': (
"mpirun -tag-output -allow-run-as-root -host localhost:4 -bind-to numa "
'-x "NCCL_BUFFSIZE=4194304" -x "NCCL_RINGS=0 1 2 3|0 3 2 1" -x PATH '
f'sb exec --output-dir {self.sb_output_dir} -c sb.config.yaml -C superbench.enable=foo'
),
},
{ {
'benchmark_name': 'benchmark_name':
'foo', 'foo',
...@@ -453,3 +474,34 @@ def test_run_proc_timeout(self): ...@@ -453,3 +474,34 @@ def test_run_proc_timeout(self):
if isinstance(timeout, int): if isinstance(timeout, int):
timeout = max(timeout, 60) timeout = max(timeout, 60)
self.assertEqual(timeout, expected_timeout) self.assertEqual(timeout, expected_timeout)
@mock.patch('superbench.runner.ansible.AnsibleClient.run')
def test_run_proc_quotes_env_values(self, mock_ansible_client_run):
    """Check that _run_proc emits quoted env assignments for docker exec and mpirun."""
    mock_ansible_client_run.return_value = 0
    self.runner._sb_benchmarks = {'foo': {}}

    # Holds the shell command handed to the Ansible client so it can be
    # inspected after _run_proc returns.
    recorded = {}

    def record_shell_config(cmd):
        recorded['cmd'] = cmd
        return {'module_args': cmd, 'cmdline': '', 'host_pattern': 'localhost', 'module': 'shell'}

    self.runner._ansible_client.get_shell_config = record_shell_config

    env_config = {
        'NCCL_BUFFSIZE': 4194304,
        'NCCL_RINGS': '0 1 2 3|0 3 2 1',
        'PATH': None,
    }
    mode = OmegaConf.create({
        'name': 'mpi',
        'proc_num': 4,
        'node_num': 1,
        'mca': {},
        'env': env_config,
    })

    self.runner._run_proc('foo', mode, {'proc_rank': 0})

    shell_cmd = recorded['cmd']
    expected_fragments = (
        '-e "NCCL_BUFFSIZE=4194304"',
        '-e "NCCL_RINGS=0 1 2 3|0 3 2 1"',
        '-x "NCCL_BUFFSIZE=4194304"',
        '-x "NCCL_RINGS=0 1 2 3|0 3 2 1"',
    )
    for fragment in expected_fragments:
        self.assertIn(fragment, shell_cmd)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment