import argparse
import os
import warnings
from typing import Any, Callable, Dict, List, Tuple, Type, Union

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.futures import Future

from colossalai.initialize import launch
from colossalai.legacy.pipeline.pipeline_process_group import ppg


def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any:
    if isinstance(obj, process_types):
        return fn(obj)
    elif type(obj) is dict:
        return {k: pyobj_map(obj[k], fn, process_types) for k in obj}
    elif type(obj) is tuple:
        return tuple(pyobj_map(o, fn, process_types) for o in obj)
    elif type(obj) is list:
        return list(pyobj_map(o, fn, process_types) for o in obj)
    else:
        return obj


def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
    """Process an object recursively, as with a pytree.

    Args:
        obj (:class:`Any`): object to process
        fn (:class:`Callable`): a function applied to each sub-object of ``obj``
        process_types (:class:`type | tuple[type]`): the types of sub-objects that ``fn`` is applied to
        map_all (:class:`bool`): if ``map_all`` is True, ``fn`` is applied to every leaf regardless of its type

    Returns:
        :class:`Any`: an object with the same structure as ``obj``, with every element of ``process_types`` mapped by ``fn``
    """
    if isinstance(obj, dict):
        return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj}
    elif isinstance(obj, tuple):
        return tuple(pytree_map(o, fn, process_types, map_all) for o in obj)
    elif isinstance(obj, list):
        return list(pytree_map(o, fn, process_types, map_all) for o in obj)
    elif isinstance(obj, process_types):
        return fn(obj)
    else:
        return fn(obj) if map_all else obj


def tensor_shape_list(obj):
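    """Return a structure mirroring ``obj`` with every ``torch.Tensor`` replaced by its shape."""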
    return pytree_map(obj, fn=lambda x: x.shape, process_types=torch.Tensor)


def get_batch_lengths(batch):
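    """Collect ``len(x)`` (the first-dimension size) of every tensor in ``batch``."""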
    lengths = []
    pytree_map(batch, fn=lambda x: lengths.append(len(x)), process_types=torch.Tensor)
    return lengths


def split_batch(batch: Any, start, stop, device: str):
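    """Slice every tensor in ``batch`` along its first dimension, moving the slices to GPU when ``device`` is 'cuda'."""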
    if device == 'cuda':
        fn = lambda x: x[start:stop].cuda()
    else:
        fn = lambda x: x[start:stop]
    return pytree_map(batch, fn=fn, process_types=torch.Tensor)


def type_detail(obj):
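    """Return a structure mirroring ``obj`` with every leaf replaced by its type."""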
    return pytree_map(obj, lambda x: type(x), map_all=True)
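
# Illustrative usage of the pytree helpers above (the batch contents are made up):
#   >>> batch = {'x': torch.ones(4, 8), 'y': [torch.zeros(4)]}
#   >>> tensor_shape_list(batch)
#   {'x': torch.Size([4, 8]), 'y': [torch.Size([4])]}
#   >>> get_batch_lengths(batch)
#   [4, 4]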

def pytree_filter(fn, obj, process_types):
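    """Collect every element of ``process_types`` in ``obj`` for which ``fn`` returns a truthy value."""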
    if obj is None:
        return None

    filters = []

    def condition_append(obj):
        if fn(obj):
            filters.append(obj)

    pytree_map(obj, fn=condition_append, process_types=process_types)
    return filters
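
# Illustrative usage: collect the zero-dim tensors from a nested structure.
#   >>> pytree_filter(lambda t: t.dim() == 0, [torch.tensor(1.), [torch.ones(2)]], torch.Tensor)
#   [tensor(1.)]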


def get_real_args_kwargs(args_or_kwargs):
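    """Wait on every ``Future`` in ``args_or_kwargs``; non-dict containers are flattened into a list."""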
    args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
    # TODO: combine producer and consumer
    # by default, merge all args in the output args or kwargs
    if args_or_kwargs is not None:
        if isinstance(args_or_kwargs, dict):
            pass
        else:
            flatten_args = []
            pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
            args_or_kwargs = flatten_args

    return args_or_kwargs
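
# Illustrative usage: futures are resolved and positional containers flattened.
#   >>> fut = Future()
#   >>> fut.set_result(torch.ones(2))
#   >>> get_real_args_kwargs((fut,))
#   [tensor([1., 1.])]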


def run_worker(rank, args, master_func):
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port

    device = args.device
    world_size = args.world_size
    dp_degree = args.dp_degree
    tp_degree = args.tp_degree
    num_worker_threads = args.num_worker_threads
    host = args.master_addr
    port = args.master_port
    backend = 'nccl' if device == 'cuda' else 'gloo'

    launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
    ppg.set_global_info(rank=rank,
                        world_size=world_size,
                        dp_degree=dp_degree,
                        tp_degree=tp_degree,
                        num_worker_threads=num_worker_threads,
                        device=device)
    ppg.args = args
    # in RPC mode, only rank 0 needs to run the master function
    if rank == 0:
        master_func(args)
    # barrier here
    if _is_current_rpc_agent_set():
        rpc.shutdown()
    else:
        warnings.warn("RPC has not been initialized")


def rpc_run(args, master_func):
    world_size = args.world_size
    mp.spawn(run_worker, args=(args, master_func), nprocs=world_size)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=1)
    parser.add_argument('--world_size', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--dp_degree', type=int, default=1)
    parser.add_argument('--tp_degree', type=int, default=1)
    parser.add_argument('--num_microbatches', type=int, default=2)
    parser.add_argument('--chunk', type=int, default=1)
    parser.add_argument('--use_checkpoint', action='store_true')
    parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD')
    parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
    parser.add_argument('--master_addr', type=str, default='localhost')
    parser.add_argument('--master_port', type=str, default='29020')
    parser.add_argument('--num_worker_threads', type=int, default=128)
    return parser.parse_args()
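
# Typical entry point (a sketch; `master_func` is a user-provided callable that
# builds and drives the pipeline, executed on rank 0 by run_worker):
#   if __name__ == '__main__':
#       args = parse_args()
#       rpc_run(args, master_func)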