utils.py 2.42 KB
Newer Older
1
2
3
4
5
6
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Utilities for benchmark tests."""

import os
7
8
9
10
import socket
from contextlib import closing
import multiprocessing as multiprocessing
from multiprocessing import Process
11

12
from superbench.benchmarks import BenchmarkRegistry
13
14
15
16
17
18
19
20
21


def clean_simulated_ddp_distributed_env():
    """Function to clean up the simulated DDP distributed environment variables.

    Removes WORLD_SIZE, RANK, LOCAL_RANK, MASTER_ADDR and MASTER_PORT from
    os.environ. Variables that are not set are skipped, so the function can be
    called repeatedly or after a partially completed setup without raising.
    """
    for name in ('WORLD_SIZE', 'RANK', 'LOCAL_RANK', 'MASTER_ADDR', 'MASTER_PORT'):
        # Passing a default turns pop() into a no-op for a missing key
        # instead of raising KeyError.
        os.environ.pop(name, None)
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75


def get_free_port():
    """Ask the operating system for a TCP port that is currently unused.

    A temporary IPv4 stream socket is bound to port 0, the port number chosen
    for it is read back, and the socket is closed before returning.

    Return:
        port (int): a free port in current system.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(('', 0))
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # getsockname() on an AF_INET socket yields a (host, port) pair.
        _, port = sock.getsockname()
        return port
    finally:
        sock.close()


def setup_simulated_ddp_distributed_env(world_size, local_rank, port):
    """Export the environment variables describing a simulated DDP process.

    WORLD_SIZE, RANK, LOCAL_RANK, MASTER_ADDR and MASTER_PORT are written to
    os.environ as strings. RANK and LOCAL_RANK both take the value of
    local_rank, and MASTER_ADDR is always 'localhost'.
    """
    settings = (
        ('WORLD_SIZE', world_size),
        ('RANK', local_rank),
        ('LOCAL_RANK', local_rank),
        ('MASTER_ADDR', 'localhost'),
        ('MASTER_PORT', port),
    )
    for name, value in settings:
        # Environment values must be strings.
        os.environ[name] = str(value)


def benchmark_in_one_process(context, world_size, local_rank, port, queue):
    """Function to setup env for DDP initialization and run the benchmark in each single process.

    Args:
        context: benchmark context forwarded to BenchmarkRegistry.launch_benchmark.
        world_size: total number of simulated processes, exported as WORLD_SIZE.
        local_rank: rank of this process, exported as both RANK and LOCAL_RANK.
        port: port exported as MASTER_PORT.
        queue: queue on which the finished benchmark object is put.
    """
    setup_simulated_ddp_distributed_env(world_size, local_rank, port)
    try:
        benchmark = BenchmarkRegistry.launch_benchmark(context)
        # parser object must be removed because it can not be serialized.
        benchmark._parser = None
        queue.put(benchmark)
    finally:
        # Always undo the environment changes, even when launching the
        # benchmark or putting the result on the queue raises.
        clean_simulated_ddp_distributed_env()


def simulated_ddp_distributed_benchmark(context, world_size):
    """Function to run the benchmark on #world_size number of processes.

    Args:
        context: benchmark context forwarded to every worker process.
        world_size (int): number of worker processes to start; ranks are
            0 .. world_size - 1.

    Return:
        results (list): list of benchmark results from #world_size number of processes,
            in the order they arrive on the queue (not necessarily rank order).
    """
    port = get_free_port()
    # Use a private 'spawn' context rather than set_start_method(): the global
    # start method may only be set once per interpreter, so calling this
    # function a second time would otherwise raise RuntimeError.
    spawn_context = multiprocessing.get_context('spawn')
    queue = spawn_context.Queue()

    process_list = []
    for rank in range(world_size):
        process = spawn_context.Process(
            target=benchmark_in_one_process, args=(context, world_size, rank, port, queue)
        )
        process.start()
        process_list.append(process)

    # Drain the queue before joining. A child that has put an item on a queue
    # does not terminate until the item has been flushed to the underlying
    # pipe, so joining first can deadlock when the results are large.
    results = [queue.get() for _ in process_list]
    for process in process_list:
        process.join()
    return results