# bert_helper.py: helpers that broadcast BERT batches within the
# tensor-parallel group and split them for sequence parallelism.
import torch

from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc

6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
_MAX_DATA_DIM = 5


def _build_key_size_numel_dictionaries(keys, data):
    """Build per-key shape/numel tables on the source rank and broadcast them.

    The source rank (local rank 0 of the ``ParallelMode.TENSOR`` group, or
    every rank when that group is not initialized) writes the shape of each
    ``data[key]`` into a flat, zero-padded buffer with ``_MAX_DATA_DIM`` slots
    per key. The buffer is broadcast within the tensor-parallel group so that
    ranks which do not hold the data learn the shapes they must allocate.

    Args:
        keys: ordered keys to describe; the order must be identical on all
            ranks because it defines the buffer layout.
        data: mapping from key to tensor. Only read on the source rank.

    Returns:
        ``(key_size, key_numel, total_numel)``: ``key_size`` maps each key to
        its shape as a list of Python ints, ``key_numel`` maps each key to its
        element count, and ``total_numel`` is the sum of all element counts.

    Raises:
        AssertionError: on the source rank, if a tensor has ``_MAX_DATA_DIM``
            or more dimensions.
    """
    max_dim = _MAX_DATA_DIM
    tp_initialized = gpc.is_initialized(ParallelMode.TENSOR)
    is_src = not tp_initialized or gpc.get_local_rank(ParallelMode.TENSOR) == 0

    sizes = [0] * (max_dim * len(keys))

    # Pack the sizes on the source rank.
    if is_src:
        for slot, key in enumerate(keys):
            # Strictly fewer dims than slots guarantees a trailing 0, which
            # terminates this key's shape during unpacking.
            assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
            offset = slot * max_dim
            for i, s in enumerate(data[key].size()):
                sizes[offset + i] = s

    # Move to GPU and broadcast. Without a tensor-parallel group every rank is
    # its own source, and the group lookups below would fail, so skip it.
    sizes_cuda = torch.tensor(sizes, dtype=torch.long, device=torch.cuda.current_device())
    if tp_initialized:
        torch.distributed.broadcast(sizes_cuda,
                                    gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
                                    group=gpc.get_group(ParallelMode.TENSOR))

    # Move back to CPU as plain Python ints and unpack.
    sizes_cpu = sizes_cuda.cpu().tolist()
    key_size = {}
    key_numel = {}
    total_numel = 0
    for slot, key in enumerate(keys):
        offset = slot * max_dim
        size = []
        numel = 1
        # NOTE: a dimension of size 0 is indistinguishable from padding and
        # ends the shape early; zero-sized dims are not supported here.
        for s in sizes_cpu[offset:offset + max_dim]:
            if s <= 0:
                break
            size.append(s)
            numel *= s
        key_size[key] = size
        key_numel[key] = numel
        total_numel += numel

    return key_size, key_numel, total_numel


def broadcast_data(keys, data, datatype):
    """Broadcast data from rank zero of each model parallel group to the
    members of the same model parallel group.

    The source is local rank 0 of ``ParallelMode.TENSOR``. When that group is
    not initialized, every rank acts as its own source and no communication
    happens. Shapes are exchanged first via
    ``_build_key_size_numel_dictionaries``. The tensors are then flattened
    into one buffer on the source, broadcast, and split back per key.

    Arguments:
        keys: list of keys in the data dictionary to be broadcasted. The order
              must be identical on all ranks.
        data: data dictionary of string keys and cpu tensor values. Only read
              on the source rank; other ranks may pass ``None``.
        datatype: torch data type of all tensors in data associated
                  with keys.

    Returns:
        dict mapping each key to a GPU tensor of the original shape. All
        values are views into a single flat buffer.

    Raises:
        TypeError: on the source rank, if ``data[key].dtype != datatype`` for
            some key. Receivers allocate their buffer with ``datatype``, so a
            mismatch would corrupt the broadcast.
    """
    # Build (key, size) and (key, number of elements) dictionaries along
    # with the total number of elements on all ranks.
    key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)

    tp_initialized = gpc.is_initialized(ParallelMode.TENSOR)
    is_src = not tp_initialized or gpc.get_local_rank(ParallelMode.TENSOR) == 0

    # Pack on the source rank.
    if is_src:
        # Check that all keys have the data type the receivers will allocate.
        for key in keys:
            if data[key].dtype != datatype:
                raise TypeError(f'{key} has data type {data[key].dtype} which is different than {datatype}')
        # Flatten the data associated with the keys.
        flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
    else:
        flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype)

    # Broadcast (only meaningful when a tensor-parallel group exists).
    if tp_initialized:
        torch.distributed.broadcast(flatten_data,
                                    gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
                                    group=gpc.get_group(ParallelMode.TENSOR))

    # Unpack: consecutive narrow views, restored to each key's shape.
    output = {}
    offset = 0
    for key in keys:
        numel = key_numel[key]
        output[key] = flatten_data.narrow(0, offset, numel).view(key_size[key])
        offset += numel

    return output


def get_batch(data_iterator):
    """Fetch one BERT batch and share it across the tensor-parallel group.

    Args:
        data_iterator: iterator yielding a mapping from field name to tensor,
            or ``None`` on ranks that only receive the broadcast.

    Returns:
        ``(tokens, types, sentence_order, loss_mask, lm_labels, padding_mask)``;
        ``loss_mask`` is cast to float, every other entry to int64.
    """
    field_names = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']

    raw = None if data_iterator is None else next(data_iterator)
    shared = broadcast_data(field_names, raw, torch.int64)

    return (
        shared['text'].long(),
        shared['types'].long(),
        shared['is_random'].long(),
        shared['loss_mask'].float(),
        shared['labels'].long(),
        shared['padding_mask'].long(),
    )


def get_batch_for_sequence_parallel(data_iterator):
    """Build a BERT batch and keep only this rank's slice of the sequence.

    The full batch is broadcast within the tensor-parallel group. Each rank
    then takes the contiguous chunk ``[local_rank * L, (local_rank + 1) * L)``
    of dimension 1, where ``L = seq_length // tensor_world_size``.

    Args:
        data_iterator: iterator yielding a mapping from field name to tensor,
            or ``None`` on ranks that only receive the broadcast.

    Returns:
        ``(tokens, types, sentence_order, loss_mask, lm_labels, padding_mask)``.
        ``tokens``, ``types``, ``loss_mask`` and ``lm_labels`` are sliced along
        dimension 1. ``sentence_order`` and ``padding_mask`` are returned whole.
        ``loss_mask`` is float and every other entry is int64.
    """
    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
    datatype = torch.int64

    data = next(data_iterator) if data_iterator is not None else None
    data_b = broadcast_data(keys, data, datatype)

    # Position of this rank inside its tensor-parallel group. Ask the group
    # directly (as broadcast_data does) rather than deriving it from the
    # global rank, which is only correct for contiguous rank layouts.
    if gpc.is_initialized(ParallelMode.TENSOR):
        local_world_size = gpc.get_world_size(ParallelMode.TENSOR)
        local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
    else:
        local_world_size = 1
        local_rank = 0

    seq_length = data_b['text'].size(1)
    # NOTE(review): when seq_length is not divisible by local_world_size, the
    # last seq_length % local_world_size positions are assigned to no rank.
    sub_seq_length = seq_length // local_world_size
    sub_seq_start = local_rank * sub_seq_length
    sub_seq_end = sub_seq_start + sub_seq_length
    seq_slice = slice(sub_seq_start, sub_seq_end)

    # Unpack; is_random and padding_mask are intentionally left unsliced.
    tokens = data_b['text'][:, seq_slice].long()
    types = data_b['types'][:, seq_slice].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'][:, seq_slice].float()
    lm_labels = data_b['labels'][:, seq_slice].long()
    padding_mask = data_b['padding_mask'].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


class SequenceParallelDataIterator:
    """Iterator that yields sequence-parallel batches from a wrapped iterator.

    Each ``next()`` call passes the wrapped iterator to
    ``get_batch_for_sequence_parallel`` and returns its tuple. ``data_iter``
    may be ``None``, in which case that function reads nothing locally and
    relies on the broadcast to obtain the batch.
    """

    def __init__(self, data_iter):
        # Underlying iterator of raw batches (or None on receive-only ranks).
        self.data_iter = data_iter

    def __iter__(self):
        # Iterator protocol: return the wrapper itself. Returning the wrapped
        # iterator would make ``for`` loops skip the sequence-parallel split
        # (and fails outright when data_iter is None).
        return self

    def __next__(self):
        return get_batch_for_sequence_parallel(self.data_iter)