gen_traffic_pattern_config.py 1.85 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Utilities for traffic pattern config."""
from superbench.common.utils import logger


def gen_all_nodes_config(n):
    """Build the config in which every participant belongs to one single group.

    Args:
        n (int): the number of participants.

    Returns:
        config (list): a one-item list holding all participant indices joined by
            commas, e.g. ["0,1,2,3"] for n=4; an empty list (after logging a
            warning) when n is not positive.
    """
    if n <= 0:
        logger.warning('n is not positive')
        return []
    return [','.join(str(index) for index in range(n))]


def __convert_config_to_host_group(config, host_list):
    """Map index-based traffic pattern config entries onto hostnames.

    Each config entry is a string of groups separated by ';'. Every group is a
    comma-separated list of integer indices into ``host_list`` (e.g. "0,1;2,3").
    Surrounding whitespace and then surrounding ';' are stripped from an entry
    before it is split into groups.

    Args:
        config (list): the traffic pattern config.
        host_list (list): the list of hostnames read from hostfile.

    Returns:
        host_groups (list): one item per config entry. Each item is a list of
            groups, and each group is the list of hostnames selected by its indices.
    """
    def resolve(group):
        # Every comma-separated token is converted with int() and used to index host_list.
        return [host_list[int(token)] for token in group.split(',')]

    return [
        [resolve(group) for group in entry.strip().strip(';').split(';')]
        for entry in config
    ]


def gen_tarffic_pattern_host_group(host_list, pattern):
    """Resolve a named traffic pattern into groups of hostnames.

    NOTE: "tarffic" in the function name is a misspelling of "traffic"; the name
    is kept unchanged so existing callers keep working.

    Args:
        host_list (list): the list of hostnames read from hostfile.
        pattern (DictConfig): the mpi pattern dict; only its ``name`` is read here.

    Returns:
        host_group (list): the host group generated from the traffic pattern.
            It is empty when the pattern name is not supported; an error is
            logged in that case.
    """
    node_count = len(host_list)
    # 'all-nodes' is the only pattern name handled here.
    if pattern.name == 'all-nodes':
        config = gen_all_nodes_config(node_count)
    else:
        logger.error('Unsupported traffic pattern: {}'.format(pattern.name))
        config = []
    return __convert_config_to_host_group(config, host_list)