import numpy as np
import torch
import torch.distributed as dist


def allgather_sizes(send_data, world_size, num_parts, return_sizes=False):
    """
    Perform all_gather on list lengths, used to compute prefix sums
    to determine the offsets on each rank. This is used to allocate
    global ids for edges/nodes on each rank.

    Parameters
    ----------
    send_data : numpy array
        Data on which allgather is performed.
    world_size : integer
        No. of processes configured for execution.
    num_parts : integer
        No. of output graph partitions.
    return_sizes : bool
        Boolean flag indicating whether to return the raw sizes from each
        process or to perform a prefix sum on the raw sizes.

    Returns
    -------
    numpy array
        array with the prefix sums
    """

    # num_parts must be an integer multiple of world_size
    assert (num_parts % world_size) == 0

    # compute the length of the local data
    send_length = len(send_data)
    out_tensor = torch.as_tensor(send_data, dtype=torch.int64)
    # one receive buffer per rank; all_gather assumes equal send lengths on all ranks
    in_tensor = [
        torch.zeros(send_length, dtype=torch.int64) for _ in range(world_size)
    ]

    # all_gather message
    dist.all_gather(in_tensor, out_tensor)

    # Return the raw sizes from each process
    if return_sizes:
        return torch.cat(in_tensor).numpy()

    # gather sizes in one array to return to the invoking function
    rank_sizes = np.zeros(num_parts + 1, dtype=np.int64)
    part_counts = torch.cat(in_tensor).numpy()

    # compute prefix sums over the per-partition counts; rank r's count for
    # its local partition p sits at part_counts[r * (num_parts // world_size) + p]
    count = rank_sizes[0]
    idx = 1
    for local_part_id in range(num_parts // world_size):
        for r in range(world_size):
            count += part_counts[r * (num_parts // world_size) + local_part_id]
            rank_sizes[idx] = count
            idx += 1

    return rank_sizes
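

# Illustrative sketch (not part of the original module): driving
# ``allgather_sizes`` from a worker that has already called
# ``dist.init_process_group(backend="gloo")``. The helper name and the sample
# counts below are hypothetical.
def _example_allgather_sizes(rank, world_size, num_parts):
    # Each rank reports one count per local partition, i.e.
    # num_parts // world_size entries.
    local_counts = np.full(
        num_parts // world_size, 100 * (rank + 1), dtype=np.int64
    )
    # The result has num_parts + 1 entries, starting at 0 and ending at the
    # global total; consecutive entries give each partition's global-id range.
    return allgather_sizes(local_counts, world_size, num_parts)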


def __alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list):
    """
    Each process scatters its list of input tensors to all processes in a cluster
    and returns the gathered tensors in the output list. The tensors should have
    the same shape.

    Parameters
    ----------
    rank : int
        The rank of current worker
    world_size : int
        The size of the entire cluster
    output_tensor_list : List of tensor
        The received tensors
    input_tensor_list : List of tensor
        The tensors to exchange
    """
    input_tensor_list = [
        tensor.to(torch.device("cpu")) for tensor in input_tensor_list
    ]
    # TODO(#5002): As Boolean data is not supported in
    # ``torch.distributed.scatter()``, we convert boolean into uint8 before
    # scatter and convert it back afterwards.
    dtypes = [t.dtype for t in input_tensor_list]
    for i, dtype in enumerate(dtypes):
        if dtype == torch.bool:
            input_tensor_list[i] = input_tensor_list[i].to(torch.int8)
            output_tensor_list[i] = output_tensor_list[i].to(torch.int8)
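    # Emulate all_to_all with ``world_size`` scatter rounds: in round ``i``,
    # rank ``i`` scatters its input tensors and every rank receives that
    # round's tensor into ``output_tensor_list[i]``.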
    for i in range(world_size):
        dist.scatter(
            output_tensor_list[i], input_tensor_list if i == rank else [], src=i
        )
    # Convert back to original dtype
    for i, dtype in enumerate(dtypes):
        if dtype == torch.bool:
            input_tensor_list[i] = input_tensor_list[i].to(dtype)
            output_tensor_list[i] = output_tensor_list[i].to(dtype)


def alltoallv_cpu(rank, world_size, input_tensor_list, retain_nones=True):
    """
    Wrapper function providing alltoallv functionality on top of the underlying
    alltoall messaging primitive. In its current implementation, this function
    supports exchanging messages of arbitrary dimensions and is not tied to any
    particular user of this function.

    This function pads the input tensors, except the largest one, so that all the
    messages are of the same size. Once the messages are padded, it first sends a
    vector whose first two elements are 1) the actual message size along the first
    dimension, and 2) the padded message size along the first dimension, which is
    used for communication. The remaining dimensions are assumed to be the same
    across all the input tensors. After receiving the message sizes, the receiving
    end creates buffers of appropriate sizes, slices the received messages to
    remove the added padding, if any, and returns them to the caller.

    Parameters:
    -----------
    rank : int
        The rank of current worker
    world_size : int
        The size of the entire cluster
    input_tensor_list : List of tensor
        The tensors to exchange
    retain_nones : bool
        Indicates whether to retain ``None`` data in returned value.

    Returns:
    --------
    list :
        list of tensors received from other processes during alltoall message
    """
    # ensure len of input_tensor_list is same as the world_size.
    assert input_tensor_list is not None
    assert len(input_tensor_list) == world_size

    # ensure that all tensors in input_tensor_list have the same number of
    # dimensions, dtype, and trailing dimensions (only the first dim may vary).
    sizes = [list(x.size()) for x in input_tensor_list]
    for idx in range(1, len(sizes)):
        assert len(sizes[idx - 1]) == len(
            sizes[idx]
        )  # no. of dimensions should be same
        assert (
            input_tensor_list[idx - 1].dtype == input_tensor_list[idx].dtype
        )  # dtype should be same
        assert (
            sizes[idx - 1][1:] == sizes[idx][1:]
        )  # except first dimension remaining dimensions should all be the same

    # decide how much to pad.
    # always use the first-dimension for padding.
    ll = [x[0] for x in sizes]

    # dims of the padding needed, if any
    diff_dims = [[np.amax(ll) - l[0]] + l[1:] for l in sizes]

    # pad the actual message
    input_tensor_list = [
        torch.cat((x, torch.zeros(diff_dims[idx]).type(x.dtype)))
        for idx, x in enumerate(input_tensor_list)
    ]

    # send useful message sizes to all
    send_counts = []
    recv_counts = []
    for idx in range(world_size):
        # send a vector [a, b, ...] where a = useful (unpadded) message size
        # along the first dimension, b = outgoing (padded) message size along
        # the first dimension used for communication, and the remaining
        # elements are the remaining dimensions of the tensor
        send_counts.append(
            torch.from_numpy(
                np.array([sizes[idx][0]] + [np.amax(ll)] + sizes[idx][1:])
            ).type(torch.int64)
        )
        recv_counts.append(
            torch.zeros((1 + len(sizes[idx])), dtype=torch.int64)
        )
    __alltoall_cpu(rank, world_size, recv_counts, send_counts)

    # allocate buffers for receiving message
    output_tensor_list = []
    recv_counts = [tsize.numpy() for tsize in recv_counts]
    for idx, tsize in enumerate(recv_counts):
        output_tensor_list.append(
            torch.zeros(tuple(tsize[1:])).type(input_tensor_list[idx].dtype)
        )

    # send actual message itself.
    __alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list)

    # extract un-padded message from the output_tensor_list and return it
    return_vals = []
    for s, t in zip(recv_counts, output_tensor_list):
        if s[0] == 0:
            if retain_nones:
                return_vals.append(None)
        else:
            return_vals.append(t[0 : s[0]])
    return return_vals
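

# Illustrative sketch (not part of the original module): exchanging
# variable-length 1-D int64 tensors with ``alltoallv_cpu``. The helper name
# and message contents are hypothetical; ``dist.init_process_group("gloo")``
# is assumed to have been called already.
def _example_alltoallv(rank, world_size):
    # This rank sends a tensor of length (p + 1), filled with its own rank id,
    # to each peer p. The per-destination lengths differ, which is exactly the
    # case alltoallv_cpu pads for.
    send_list = [
        torch.full((p + 1,), rank, dtype=torch.int64)
        for p in range(world_size)
    ]
    # recv_list[p] is the un-padded tensor sent by rank p (length rank + 1
    # here), or None if rank p sent an empty message and retain_nones is True.
    recv_list = alltoallv_cpu(rank, world_size, send_list)
    return recv_list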


def gather_metadata_json(metadata, rank, world_size):
    """
203
204
205
206
    Gather an object (json schema on `rank`)
    Parameters:
    -----------
    metadata : json dictionary object
207
        json schema formed on each rank with graph level data.
208
209
210
211
        This will be used as input to the distributed training in the later steps.
    Returns:
    --------
    list : list of json dictionary objects
212
        The result of the gather operation, which is the list of json dicitonary
213
214
215
        objects from each rank in the world
    """

    # Populate input obj and output obj list on rank-0 and non-rank-0 machines.
    # Note: rank-0 contributes ``None``, so output_objs[0] is None after the gather.
    input_obj = None if rank == 0 else metadata
    output_objs = [None for _ in range(world_size)] if rank == 0 else None

    # invoke the gloo method to perform gather on rank-0
    dist.gather_object(input_obj, output_objs, dst=0)
    return output_objs
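

# Illustrative sketch (not part of the original module): gathering per-rank
# metadata dictionaries on rank-0 with ``gather_metadata_json``. The helper
# name and the metadata contents are hypothetical; the gloo process group is
# assumed to be initialized.
def _example_gather_metadata(rank, world_size):
    metadata = {"rank": rank, "num_nodes": 1000 * (rank + 1)}
    # On rank-0 the returned list has world_size entries; entry 0 is None
    # because rank-0 excludes its own metadata. All other ranks get None back.
    return gather_metadata_json(metadata, rank, world_size)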