utils.py 2.14 KB
Newer Older
1
2
3
4
5
6
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Utilities for benchmark tests."""

import os
7
8
import multiprocessing as multiprocessing
from multiprocessing import Process
9

10
from superbench.benchmarks import BenchmarkRegistry
11
from superbench.common.utils import network
12
13
14
15
16
17
18
19
20


def clean_simulated_ddp_distributed_env():
    """Function to clean up the simulated DDP distributed environment variables.

    Removes WORLD_SIZE, RANK, LOCAL_RANK, MASTER_ADDR and MASTER_PORT from os.environ.
    Variables that are not set are skipped, so the function can be called repeatedly
    or after a partial setup without raising KeyError.
    """
    for name in ('WORLD_SIZE', 'RANK', 'LOCAL_RANK', 'MASTER_ADDR', 'MASTER_PORT'):
        # A default makes pop() a no-op for variables that are already absent.
        os.environ.pop(name, None)
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47


def setup_simulated_ddp_distributed_env(world_size, local_rank, port):
    """Function to export the environment variables of a simulated DDP distributed run.

    Args:
        world_size: value stored (as a string) in WORLD_SIZE.
        local_rank: value stored (as a string) in both RANK and LOCAL_RANK.
        port: value stored (as a string) in MASTER_PORT.

    MASTER_ADDR is always set to 'localhost'.
    """
    settings = (
        ('WORLD_SIZE', world_size),
        ('RANK', local_rank),
        ('LOCAL_RANK', local_rank),
        ('MASTER_ADDR', 'localhost'),
        ('MASTER_PORT', port),
    )
    for name, value in settings:
        os.environ[name] = str(value)


def benchmark_in_one_process(context, world_size, local_rank, port, queue):
    """Function to setup env for DDP initialization and run the benchmark in each single process.

    Args:
        context: benchmark context passed to BenchmarkRegistry.launch_benchmark.
        world_size: total number of simulated processes, exported as WORLD_SIZE.
        local_rank: rank of this process, exported as RANK and LOCAL_RANK.
        port: port exported as MASTER_PORT.
        queue: queue on which the finished benchmark object is put.
    """
    setup_simulated_ddp_distributed_env(world_size, local_rank, port)
    try:
        benchmark = BenchmarkRegistry.launch_benchmark(context)
        # The parser object must be removed because it can not be serialized.
        benchmark._parser = None
        queue.put(benchmark)
    finally:
        # Always undo the simulated environment, even when the benchmark launch fails.
        clean_simulated_ddp_distributed_env()


def simulated_ddp_distributed_benchmark(context, world_size):
    """Function to run the benchmark on #world_size number of processes.

    One process per rank is started with the 'spawn' start method; each runs
    benchmark_in_one_process and hands its result back through a multiprocessing queue.

    Args:
        context: benchmark context forwarded to every process.
        world_size (int): number of processes to launch.

    Return:
        results (list): list of benchmark results from #world_size number of processes,
            in the order the queue delivers them (not necessarily rank order),
            or None if no free port could be obtained.
    """
    port = network.get_free_port()
    if not port:
        return None

    # force=True avoids the RuntimeError that set_start_method raises when the start
    # method has already been set, e.g. when this function is called more than once.
    multiprocessing.set_start_method('spawn', force=True)

    queue = multiprocessing.Queue()

    process_list = []
    for rank in range(world_size):
        process = Process(target=benchmark_in_one_process, args=(context, world_size, rank, port, queue))
        process.start()
        process_list.append(process)

    # Drain the queue before joining: a process that has put items on a queue may not
    # terminate until they are consumed, so joining first can deadlock.
    # NOTE(review): get() blocks indefinitely if a child exits without putting a result.
    results = [queue.get() for _ in process_list]
    for process in process_list:
        process.join()
    return results