# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""SuperBench Runner."""

import os
import sys
import json
import random
import signal
from pathlib import Path
from pprint import pformat
from collections import defaultdict

import jsonlines
from natsort import natsorted
from joblib import Parallel, delayed
from omegaconf import ListConfig, OmegaConf

from superbench.common.utils import SuperBenchLogger, logger, gen_ibstat, gen_traffic_pattern_host_groups
from superbench.common.utils.lazy_import import LazyImport
from superbench.benchmarks import ReduceType, Reducer
from superbench.monitor import MonitorRecord

AnsibleClient = LazyImport('superbench.runner.ansible', 'AnsibleClient')


class SuperBenchRunner():
    """SuperBench runner class."""
    def __init__(self, sb_config, docker_config, ansible_config, sb_output_dir):
        """Initialize.

        Args:
            sb_config (DictConfig): SuperBench config object.
            docker_config (DictConfig): Docker config object.
            ansible_config (DictConfig): Ansible config object.
            sb_output_dir (str): SuperBench output directory.
        """
        self._sb_config = sb_config
        self._docker_config = docker_config
        self._ansible_config = ansible_config
        self._sb_output_dir = sb_output_dir
        self._output_path = Path(sb_output_dir).expanduser().resolve()
        self._ansible_client = AnsibleClient(ansible_config)

        self.__set_logger('sb-run.log')
        logger.info('Runner uses config: %s.', pformat(OmegaConf.to_container(self._sb_config, resolve=True)))
        logger.info('Runner writes to: %s.', str(self._output_path))

        self._sb_benchmarks = self._sb_config.superbench.benchmarks
        self.__validate_sb_config()
        self._sb_enabled_benchmarks = self.__get_enabled_benchmarks()
        logger.info('Runner will run: %s', self._sb_enabled_benchmarks)

    @property
    def _container_name(self):
        """Get docker container name with backward-compatible default."""
        return getattr(self._docker_config, 'container_name', 'sb-workspace')

    def __set_logger(self, filename):
        """Set logger and add file handler.

        Args:
            filename (str): Log file name.
        """
        SuperBenchLogger.add_handler(logger.logger, filename=str(self._output_path / filename))

    def __validate_sb_config(self):    # noqa: C901
        """Validate SuperBench config object.

        Raises:
            InvalidConfigError: If input config is invalid.
        """
        # TODO: add validation and defaulting
        if 'env' not in self._sb_config.superbench:
            self._sb_config.superbench.env = {}
        for name in self._sb_benchmarks:
            if 'modes' not in self._sb_benchmarks[name]:
                self._sb_benchmarks[name].modes = []
            for idx, mode in enumerate(self._sb_benchmarks[name].modes):
                if 'env' not in mode:
                    self._sb_benchmarks[name].modes[idx].env = {}
                if mode.name == 'local':
                    if 'proc_num' not in mode:
                        self._sb_benchmarks[name].modes[idx].proc_num = 1
                    if 'prefix' not in mode:
                        self._sb_benchmarks[name].modes[idx].prefix = ''
                elif mode.name == 'torch.distributed':
                    if 'proc_num' not in mode:
                        self._sb_benchmarks[name].modes[idx].proc_num = 8
                elif mode.name == 'mpi':
                    if 'mca' not in mode:
                        self._sb_benchmarks[name].modes[idx].mca = {
                            'pml': 'ob1',
                            'btl': '^openib',
                            'btl_tcp_if_exclude': 'lo,docker0',
                            'coll_hcoll_enable': 0,
                        }
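                        # These defaults match what a user could set explicitly, e.g. this
                        # illustrative YAML snippet in the benchmark's mode config:
                        #   mca: {pml: ob1, btl: '^openib', btl_tcp_if_exclude: 'lo,docker0', coll_hcoll_enable: 0}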
                    if 'bind_to' not in mode:
                        self._sb_benchmarks[name].modes[idx].bind_to = 'numa'
                    for key in ['PATH', 'LD_LIBRARY_PATH', 'SB_MICRO_PATH', 'SB_WORKSPACE']:
                        self._sb_benchmarks[name].modes[idx].env.setdefault(key, None)
                    if 'pattern' in mode:
                        if mode.pattern.type == 'topo-aware' and 'ibstat' not in mode.pattern:
                            self._sb_benchmarks[name].modes[idx].pattern.ibstat = gen_ibstat(
                                self._ansible_config, str(self._output_path / 'ibstate_file.txt')
                            )
    def __get_enabled_benchmarks(self):
        """Get enabled benchmarks list.

        Return:
            list: List of benchmarks which will be executed.
        """
        if 'enable' in self._sb_config.superbench and self._sb_config.superbench.enable:
            if isinstance(self._sb_config.superbench.enable, str):
                return [self._sb_config.superbench.enable]
            elif isinstance(self._sb_config.superbench.enable, (list, ListConfig)):
                return list(self._sb_config.superbench.enable)
        return [k for k, v in self._sb_benchmarks.items() if 'enable' in v and v.enable]

    def __get_mode_command(self, benchmark_name, mode, timeout=None):
        """Get runner command for given mode.

        Args:
            benchmark_name (str): Benchmark name.
            mode (DictConfig): Runner mode.
            timeout (int): The timeout value in seconds.

        Return:
            str: Runner command.
        """
        exec_command = ('sb exec --output-dir {output_dir} -c sb.config.yaml -C superbench.enable={name}').format(
            name=benchmark_name,
            output_dir=self._sb_output_dir,
        )
        if timeout is not None:
            exec_command = 'timeout {timeout} {command}'.format(timeout=timeout, command=exec_command)

        # Enable nsys profiling based on environment variable
        enable_nsys = os.environ.get('SB_ENABLE_NSYS', '') == '1'
        trace_dir = os.environ.get('SB_NSYS_TRACE_DIR', self._sb_output_dir)
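        # e.g. exporting SB_ENABLE_NSYS=1 and SB_NSYS_TRACE_DIR=/tmp/traces (illustrative
        # path) in the runner's environment wraps the benchmark commands below with nsys.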

        mode_command = exec_command
        if mode.name == 'local':
            trace_command = (
                f'nsys profile --output {trace_dir}/{benchmark_name}_{mode.proc_rank}_traces '
                f'--backtrace none --sample none --force-overwrite true --cpuctxsw none --trace cuda,nvtx '
            ) if enable_nsys and mode.proc_rank == 0 else ''
            # Build the command parts, only including trace if it's not empty
            command_parts = []
            prefix = mode.prefix.format(proc_rank=mode.proc_rank, proc_num=mode.proc_num)
            if prefix:
                command_parts.append(prefix)
            if trace_command:
                command_parts.append(trace_command)
            command_parts.append(exec_command)
            mode_command = ' '.join(command_parts)
            mode_command = f'PROC_RANK={mode.proc_rank} {mode_command}'
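            # Illustrative result for proc_rank=1 with a hypothetical prefix
            # 'CUDA_VISIBLE_DEVICES={proc_rank}':
            #   PROC_RANK=1 CUDA_VISIBLE_DEVICES=1 sb exec --output-dir ... -C superbench.enable=...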
        elif mode.name == 'torch.distributed':
            # TODO: replace with torch.distributed.run in v1.9
            # TODO: only supports node_num=1 and node_num=all currently
            torch_dist_params = (
                '' if 'node_num' in mode and mode.node_num == 1 else
                '--nnodes=$NNODES --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT '
            )

            nsys_prefix = (
                f'nsys profile --output {trace_dir}/{benchmark_name}_traces '
                f'--backtrace none --sample none --force-overwrite true --cpuctxsw none --trace cuda,nvtx '
            ) if enable_nsys else ''

            mode_command = (
                f'{nsys_prefix}'
                f'torchrun'
                f' --no_python --nproc_per_node={mode.proc_num} {torch_dist_params}{exec_command}'
                f' superbench.benchmarks.{benchmark_name}.parameters.distributed_impl=ddp'
                f' superbench.benchmarks.{benchmark_name}.parameters.distributed_backend=nccl'
            )
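            # Illustrative multi-node command ($-placeholders resolved at run time):
            #   torchrun --no_python --nproc_per_node=8 --nnodes=$NNODES --node_rank=$NODE_RANK
            #   --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT sb exec ...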
        elif mode.name == 'mpi':
            trace_command = (
                f'nsys profile --output {trace_dir}/{benchmark_name}_{mode.proc_rank}_traces '
                f'--backtrace none --sample none --force-overwrite true --cpuctxsw none --trace cuda,nvtx '
            ) if enable_nsys else ''
            mode_command = (
                '{trace} '
                'mpirun '    # use default OpenMPI in image
                '-tag-output '    # tag mpi output with [jobid,rank]<stdout/stderr> prefix
                '-allow-run-as-root '    # allow mpirun to run when executed by root user
                '{host_list} '    # use prepared hostfile or specify nodes and launch {proc_num} processes on each node
                '-bind-to {bind_to} '    # bind processes according to mode config
                '{mca_list} {env_list} {command}'
            ).format(
                trace=trace_command,
                host_list=f'-host localhost:{mode.proc_num}' if 'node_num' in mode and mode.node_num == 1 else
                f'-hostfile hostfile -map-by ppr:{mode.proc_num}:node' if 'host_list' not in mode else '-host ' +
                ','.join(f'{host}:{mode.proc_num}' for host in mode.host_list),
                bind_to=mode.bind_to,
                mca_list=' '.join(f'-mca {k} {v}' for k, v in mode.mca.items()),
                env_list=' '.join(self.__format_mpi_env_args(mode.env, mode.proc_rank, mode.proc_num)),
                command=exec_command,
            )
        else:
            logger.warning('Unknown mode %s.', mode.name)
        return mode_command.strip()

    def __format_mode_env_value(self, value, proc_rank, proc_num):
        """Format mode env value.

        Args:
            value: Env value from config.
            proc_rank (int): Process rank.
            proc_num (int): Process count.

        Returns:
            str | None: Formatted value, or None to export existing env only.
        """
        if value is None:
            return None
        if isinstance(value, str):
            return value.format(proc_rank=proc_rank, proc_num=proc_num)
        if isinstance(value, (int, float, bool)):
            return str(value)
        raise ValueError(f'Unsupported env value type: {type(value)}')

    def __format_mpi_env_args(self, env_dict, proc_rank, proc_num):
        """Format env args for mpirun.

        Args:
            env_dict (DictConfig): Mode env config.
            proc_rank (int): Process rank.
            proc_num (int): Process count.

        Returns:
            list[str]: Formatted mpirun env args.
        """
        env_args = []
        for key, value in env_dict.items():
            formatted_value = self.__format_mode_env_value(value, proc_rank, proc_num)
            env_args.append(
                f'-x {key}' if formatted_value is None else f'-x {self.__quote_env_assignment(key, formatted_value)}'
            )
        return env_args

    def __quote_env_assignment(self, key, value):
        """Quote env assignment for shell command composition.

        Use double quotes so the result can be embedded inside the existing
        bash -lc '...' wrapper without breaking the outer single quotes.

        Args:
            key (str): Env key.
            value (str): Env value.

        Returns:
            str: Double-quoted KEY=value assignment.
        """
        escaped = str(value).replace('\\', '\\\\').replace('"', '\\"').replace('$', '\\$').replace('`', '\\`')
        return f'"{key}={escaped}"'

    def get_failure_count(self):
        """Get failure count during Ansible run.

        Return:
            int: Failure count.
        """
        return self._ansible_client.failure_count

    def deploy(self):    # pragma: no cover
        """Deploy SuperBench environment."""
        logger.info('Preparing SuperBench environment.')
        extravars = {
            'ssh_port': random.randint(1 << 14, (1 << 15) - 1),
            'output_dir': str(self._output_path),
            'container': self._container_name,
            'docker_image': self._docker_config.image,
            'docker_pull': bool(self._docker_config.pull),
        }
        if bool(self._docker_config.username) and bool(self._docker_config.password):
            extravars.update(
                {
                    'docker_registry': self._docker_config.registry,
                    'docker_username': self._docker_config.username,
                    'docker_password': self._docker_config.password,
                }
            )
        self._ansible_client.run(self._ansible_client.get_playbook_config('deploy.yaml', extravars=extravars))

    def run_sys_info(self):
        """Run the system info on all nodes."""
        self.check_env()

        logger.info('Runner is going to get node system info.')

        fcmd = "docker exec {container} bash -lc '{{command}}'".format(container=self._container_name)

        if 'skip' not in self._docker_config:
            self._docker_config.skip = False
        if self._docker_config.skip:
            fcmd = "bash -c 'cd $SB_WORKSPACE && {command}'"
        ansible_runner_config = self._ansible_client.get_shell_config(
            fcmd.format(command='sb node info --output-dir {output_dir}'.format(output_dir=self._sb_output_dir))
        )
        ansible_rc = self._ansible_client.run(ansible_runner_config, sudo=(not self._docker_config.skip))

        if ansible_rc != 0:
            self.cleanup()
        self.fetch_results()

    def check_env(self):    # pragma: no cover
        """Check SuperBench environment."""
        logger.info('Checking SuperBench environment.')
        OmegaConf.save(config=self._sb_config, f=str(self._output_path / 'sb.config.yaml'))
        self._ansible_client.run(
            self._ansible_client.get_playbook_config(
                'check_env.yaml',
                extravars={
                    'container': self._container_name,
                    'no_docker': False if 'skip' not in self._docker_config else self._docker_config.skip,
                    'output_dir': str(self._output_path),
                    'env': '\n'.join(f'{k}={v}' for k, v in self._sb_config.superbench.env.items()),
                }
            )
        )

    def cleanup(self):    # pragma: no cover
        """Cleanup remaining processes on all nodes."""
        self._ansible_client.run(self._ansible_client.get_playbook_config('cleanup.yaml'))

    def fetch_results(self):    # pragma: no cover
        """Fetch benchmark results on all nodes."""
        try:
            (self._output_path / 'nodes').mkdir(mode=0o755, parents=True, exist_ok=True)
        except Exception:
            logger.exception('Failed to create directory %s.', str(self._output_path / 'nodes'))
            raise
        self._ansible_client.run(
            self._ansible_client.get_playbook_config(
                'fetch_results.yaml',
                extravars={
                    'sb_output_dir': self._sb_output_dir,
                    'absolute_output_dir': str(self._output_path),
                }
            )
        )

    def __signal_handler(self, signum, frame):
        """Signal handler for runner.

        Args:
            signum (int): Signal number.
            frame (FrameType): Timeout frame.
        """
        if signum == signal.SIGINT or signum == signal.SIGTERM:
            logger.info('Killed by %s, exiting ...', signal.Signals(signum).name)
            self.cleanup()
            sys.exit(128 + signum)

    def __create_results_summary(self):    # pragma: no cover
        """Create the result summary file of all nodes."""
        all_results = list()
        for node_path in (self._output_path / 'nodes').glob('*'):
            if not node_path.is_dir():
                continue
            results_summary = self.__create_single_node_summary(node_path)
            results_summary['node'] = node_path.name
            all_results.append(results_summary)

        with (self._output_path / 'results-summary.jsonl').open(mode='w') as f:
            for result in all_results:
                json.dump(result, f)
                f.write('\n')

    def __create_single_node_summary(self, node_path):    # pragma: no cover # noqa: C901
        """Create the result summary file of single node.

        Args:
            node_path (Path): The Path instance of node directory.

        Returns:
            dict: Result summary of single node.
        """
        results_summary = dict()
        reduce_ops = dict()
        file_list = [Path(f) for f in natsorted([str(f) for f in node_path.glob('**/results.json')])]
        for results_file in file_list:
            with results_file.open() as f:
                try:
                    results = json.load(f)
                except ValueError:
                    logger.error('Invalid JSON file: {}'.format(results_file))
                    continue

                for result in results:
                    try:
                        benchmark_name = result['name']
                    except Exception:
                        logger.error('Invalid content in JSON file: {}'.format(results_file))
                        continue
                    if benchmark_name not in results_summary:
                        results_summary[benchmark_name] = defaultdict(list)
                    for metric in result['result']:
                        metric_name = '{}/{}'.format(benchmark_name, metric)
                        if metric_name not in reduce_ops:
                            reduce_ops[metric_name] = result['reduce_op'][metric]
                        elif reduce_ops[metric_name] != result['reduce_op'][metric]:
                            logger.error('Inconsistent reduce type for metric: {}'.format(metric_name))
                            continue

                        results_summary[benchmark_name][metric].append(result['result'][metric])

        results_summary = self.__merge_benchmark_metrics(results_summary, reduce_ops)
        monitor_summary = self.__merge_monitor_metrics(node_path)
        results_summary = {**results_summary, **monitor_summary}
        with (node_path / 'results-summary.json').open(mode='w') as f:
            json.dump(results_summary, f, indent=2)

        return results_summary

    def __generate_metric_name(self, benchmark_name, metric, rank_count, run_count, curr_rank, curr_run):
        """Generate the summarized metrics name.

        The format of metric name is:
               {benchmark_name}/[{run_count}/]{metric_name}[:rank]
        [run_count] and [rank] parts are optional.

        Args:
            benchmark_name (str): The benchmark name.
            metric (str): The metric name.
            rank_count (int): The total count of rank.
            run_count (int): The total count of benchmarking.
            curr_rank (int): The current rank index.
            curr_run (int): The current run index.

        Returns:
            str: The summarized metric name.
        """
        metric_name = benchmark_name
        if run_count > 1:
            metric_name = '{}/{}'.format(metric_name, curr_run)
        metric_name = '{}/{}'.format(metric_name, metric)
        if rank_count > 1:
            metric_name = '{}:{}'.format(metric_name, curr_rank)

        return metric_name

    def __merge_benchmark_metrics(self, results_summary, reduce_ops):
        """Merge metrics of all benchmarks in one node.

        Args:
            results_summary (dict): Summarized result of one node.
            reduce_ops (dict): The reduce type of each metric.

        Returns:
            dict: Flattened result with metric as key.
        """
        metrics_summary = dict()
        for benchmark_name in results_summary:
            for metric in results_summary[benchmark_name]:
                metric_name = '{}/{}'.format(benchmark_name, metric)
                if metric_name not in reduce_ops or (
                    reduce_ops[metric_name] is not None and reduce_ops[metric_name] not in ReduceType.get_values()
                ):
                    logger.error('Unknown reduce type for metric: {}'.format(metric_name))
                    continue

                if reduce_ops[metric_name] is not None:
                    reduce_func = Reducer.get_reduce_func(ReduceType(reduce_ops[metric_name]))
                    values = [reduce_func(list(result)) for result in zip(*results_summary[benchmark_name][metric])]
                    for run in range(len(values)):
                        metric_name = self.__generate_metric_name(benchmark_name, metric, 1, len(values), 0, run)
                        metrics_summary[metric_name] = values[run]
                else:
                    rank_count = len(results_summary[benchmark_name][metric])
                    for rank, rank_value in enumerate(results_summary[benchmark_name][metric]):
                        run_count = len(rank_value)
                        for run, run_value in enumerate(rank_value):
                            metric_name = self.__generate_metric_name(
                                benchmark_name, metric, rank_count, run_count, rank, run
                            )
                            metrics_summary[metric_name] = run_value

        return metrics_summary

    def __merge_monitor_metrics(self, node_path):
        """Merge and summarize monitor metrics of one node.

        Args:
            node_path (Path): The Path instance of node directory.

        Returns:
            dict: Flattened result with metric as key.
        """
        metrics_summary = dict()
        all_samples = list()
        file_list = list(node_path.glob('**/monitor.jsonl'))
        for results_file in file_list:
            try:
                with jsonlines.open(results_file) as reader:
                    # accumulate samples across files instead of overwriting them
                    all_samples.extend(reader)
            except Exception as e:
                logger.error('Invalid Jsonline file: {}, error message: {}'.format(results_file, str(e)))
                continue
        all_samples = sorted(all_samples, key=lambda k: k.get('time', '0'))
        metrics_dict = dict()
        for sample in all_samples:
            for metric, value in sample.items():
                if metric not in metrics_dict:
                    metrics_dict[metric] = list()
                metrics_dict[metric].append(value)

        for metric, values in metrics_dict.items():
            prefix = metric.split(':')[0]
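            # e.g. a metric named 'gpu_temperature:0' (illustrative) yields prefix 'gpu_temperature'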
            for pattern, reduce_type in MonitorRecord.reduce_ops.items():
                if pattern == prefix:
                    reduce_func = Reducer.get_reduce_func(reduce_type)
                    metric_name = 'monitor/{}'.format(metric)
                    metrics_summary[metric_name] = reduce_func(values)
                    # stop scanning once the metric's prefix has matched a reduce op
                    break

        return metrics_summary

    def _run_proc(self, benchmark_name, mode, vars):
        """Run the process.

        Args:
            benchmark_name (str): Benchmark name.
            mode (DictConfig): Runner mode.
            vars (dict): Process variables.

        Returns:
            int: Process return code.
        """
        mode.update(vars)
        if mode.name == 'mpi' and 'pattern' in mode:
            mode.env.update({'SB_MODE_SERIAL_INDEX': mode.serial_index, 'SB_MODE_PARALLEL_INDEX': mode.parallel_index})
        logger.info('Runner is going to run %s in %s mode, proc rank %d.', benchmark_name, mode.name, mode.proc_rank)

        timeout = self._sb_benchmarks[benchmark_name].get('timeout', None)
        if isinstance(timeout, int):
            timeout = max(timeout, 60)

        if 'skip' not in self._docker_config:
            self._docker_config.skip = False
        base_env_cmd = 'set -o allexport && source /root/sb.env && set +o allexport'
        mode_env_cmds = []
        if self._docker_config.skip:
            base_env_cmd = 'set -o allexport && source /tmp/sb.env && set +o allexport'
        for k, v in mode.env.items():
            formatted_value = self.__format_mode_env_value(v, mode.proc_rank, mode.proc_num)
            if formatted_value is not None:
                envvar = self.__quote_env_assignment(k, formatted_value)
                mode_env_cmds.append(f'export {envvar}')

        env_list = base_env_cmd
        if mode_env_cmds:
            env_list = f"{env_list} && {' && '.join(mode_env_cmds)}"

        fcmd = "docker exec {container} bash -lc '{{env_list}} && {{command}}'".format(container=self._container_name)
        if self._docker_config.skip:
            fcmd = "bash -c '{env_list} && cd $SB_WORKSPACE && {command}'"
        ansible_runner_config = self._ansible_client.get_shell_config(
            fcmd.format(env_list=env_list, command=self.__get_mode_command(benchmark_name, mode, timeout))
        )
        if mode.name == 'mpi' and 'node_num' in mode and mode.node_num != 1:
            ansible_runner_config = self._ansible_client.update_mpi_config(ansible_runner_config)

        if isinstance(timeout, int):
            # we do not expect timeout in ansible unless subprocess hangs
            ansible_runner_config['timeout'] = timeout + 60

        # overwrite ansible runner's default signal handler with main process's
        rc = self._ansible_client.run(
            ansible_runner_config, cancel_callback=lambda: None, sudo=(not self._docker_config.skip)
        )
        return rc

    def run(self):
        """Run the SuperBench benchmarks across all nodes."""
        self.check_env()
        signal.signal(signal.SIGINT, self.__signal_handler)
        signal.signal(signal.SIGTERM, self.__signal_handler)
        for benchmark_name in self._sb_benchmarks:
            if benchmark_name not in self._sb_enabled_benchmarks:
                continue
            benchmark_config = self._sb_benchmarks[benchmark_name]
            for mode in benchmark_config.modes:
                ansible_rc = 0
                if mode.name == 'local':
                    rc_list = Parallel(n_jobs=mode.proc_num if mode.parallel else 1)(
                        delayed(self._run_proc)(benchmark_name, mode, {
                            'proc_rank': proc_rank
                        }) for proc_rank in range(mode.proc_num)
                    )
                    ansible_rc = sum(rc_list)
                elif mode.name == 'torch.distributed' or mode.name == 'mpi':
                    if 'pattern' not in mode:
                        ansible_rc = self._run_proc(benchmark_name, mode, {'proc_rank': 0})
                    else:
                        if not os.path.exists(self._output_path / 'hostfile'):
                            logger.warning('No hostfile under %s.', self._output_path)
                            continue
                        with open(self._output_path / 'hostfile', 'r') as f:
                            host_list = f.read().splitlines()
                        host_groups = gen_traffic_pattern_host_groups(
                            host_list, mode.pattern, self._output_path / 'mpi_pattern.txt', benchmark_name
                        )
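                        # host_groups is a list of serial steps, each holding host lists run in
                        # parallel, e.g. [[['node0', 'node1']], [['node0', 'node2']]] (illustrative).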
                        for serial_index, host_group in enumerate(host_groups):
                            para_rc_list = Parallel(n_jobs=len(host_group))(
                                delayed(self._run_proc)(
                                    benchmark_name,
                                    mode,
                                    vars={
                                        'proc_rank': 0,
                                        'host_list': host_list,
                                        'serial_index': str(serial_index),
                                        'parallel_index': str(parallel_index),
                                    }
                                ) for parallel_index, host_list in enumerate(host_group)
                            )
                            ansible_rc = ansible_rc + sum(para_rc_list)
                else:
                    logger.warning('Unknown mode %s.', mode.name)
                if ansible_rc != 0:
                    self.cleanup()
            self.fetch_results()

        self.__create_results_summary()