# bert_helper.py
import torch

3
4
5
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc

6
7
8
9
10
11
12
13
14
15
16
17
_MAX_DATA_DIM = 5


def _build_key_size_numel_dictionaries(keys, data):
    """Build the per-key sizes on the source rank and broadcast them.

    The source rank packs the shape of each ``data[key]`` into a flat list with
    ``_MAX_DATA_DIM`` slots per key. The source rank is local rank 0 of
    ``ParallelMode.TENSOR``, or every rank when that mode is not initialized.
    Unused slots stay zero and terminate that key's shape. When a
    tensor-parallel group exists, the list is broadcast over it so that every
    rank learns the shapes, even ranks that have no data.

    Arguments:
        keys: ordered keys to describe; must be identical on every rank.
        data: mapping from key to tensor. Only read on the source rank.

    Returns:
        ``(key_size, key_numel, total_numel)``: the shape of each key as a list
        of ints, the element count of each key, and the sum of all counts.

    Raises:
        ValueError: on the source rank, if a tensor has ``_MAX_DATA_DIM`` or
            more dimensions.
    """
    max_dim = _MAX_DATA_DIM
    sizes = [0] * (max_dim * len(keys))
    tp_initialized = gpc.is_initialized(ParallelMode.TENSOR)

    # Pack the sizes on the source rank.
    if not tp_initialized or gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        for key_idx, key in enumerate(keys):
            tensor = data[key]
            # At least one slot per key must stay zero to terminate the shape.
            # An explicit raise is used because asserts are stripped under -O.
            if tensor.dim() >= max_dim:
                raise ValueError("you should increase MAX_DATA_DIM")
            offset = key_idx * max_dim
            for i, s in enumerate(tensor.size()):
                sizes[offset + i] = s

    # Share the sizes with the rest of the tensor-parallel group. Without such a
    # group every rank packed its own sizes above and there is nobody to share
    # with. Querying the group's ranks/handle would fail in that case.
    if tp_initialized:
        sizes_cuda = torch.tensor(sizes, dtype=torch.long, device=torch.cuda.current_device())
        torch.distributed.broadcast(
            sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR)
        )
        # Back to CPU as plain Python ints (not 0-d tensors) for unpacking.
        sizes = sizes_cuda.cpu().tolist()

    # Unpack: read each key's slots up to the first zero.
    key_size = {}
    key_numel = {}
    total_numel = 0
    for key_idx, key in enumerate(keys):
        offset = key_idx * max_dim
        size = []
        numel = 1
        for this_size in sizes[offset:offset + max_dim]:
            if this_size <= 0:
                break
            size.append(this_size)
            numel *= this_size
        key_size[key] = size
        key_numel[key] = numel
        total_numel += numel

    return key_size, key_numel, total_numel


def broadcast_data(keys, data, datatype):
    """Broadcast data from rank zero of each tensor-parallel group to the
    members of the same group.

    Arguments:
        keys: list of keys in the data dictionary to be broadcast; must be the
              same on every rank of the group.
        data: data dictionary of string keys and tensor values. Only read on
              the source rank (tensor-parallel local rank 0, or every rank
              when ``ParallelMode.TENSOR`` is not initialized).
        datatype: torch data type of all tensors in data associated
                  with keys.

    Returns:
        A dict mapping each key to a tensor on the current CUDA device, shaped
        like the source rank's ``data[key]``. The tensors are views into one
        flat buffer.

    Raises:
        TypeError: on the source rank, if a tensor's dtype is not ``datatype``.
            Receiving ranks allocate their buffer with ``datatype``, so a
            mismatch would desynchronize the broadcast.
    """
    # Build (key, size) and (key, number of elements) dictionaries along
    # with the total number of elements on all ranks.
    key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)
    tp_initialized = gpc.is_initialized(ParallelMode.TENSOR)

    # Pack on the source rank.
    if not tp_initialized or gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        # Check that all keys have the same data type.
        for key in keys:
            if data[key].dtype != datatype:
                raise TypeError(f"{key} has data type {data[key].dtype} which is different than {datatype}")
        # Flatten the data associated with the keys.
        flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
    else:
        flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype)

    # Broadcast only when a tensor-parallel group exists to share with.
    if tp_initialized:
        torch.distributed.broadcast(
            flatten_data, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR)
        )

    # Unpack: carve consecutive per-key views out of the flat buffer.
    output = {}
    offset = 0
    for key in keys:
        numel = int(key_numel[key])  # narrow() needs a plain int length
        output[key] = flatten_data.narrow(0, offset, numel).view(key_size[key])
        offset += numel

    return output


def get_batch(data_iterator):
    """Build the batch.

    Fetches the next item from ``data_iterator`` and broadcasts its fields
    across the tensor-parallel group with ``broadcast_data``.

    Arguments:
        data_iterator: iterator yielding dicts with the keys listed below, or
            ``None``, in which case no item is fetched on this rank.

    Returns:
        ``(tokens, types, sentence_order, loss_mask, lm_labels, padding_mask)``.
        ``loss_mask`` is float. Every other tensor is int64.
    """
    # Items and their type.
    keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"]
    datatype = torch.int64

    # Broadcast data.
    data = next(data_iterator) if data_iterator is not None else None
    data_b = broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b["text"].long()
    types = data_b["types"].long()
    sentence_order = data_b["is_random"].long()
    loss_mask = data_b["loss_mask"].float()
    lm_labels = data_b["labels"].long()
    padding_mask = data_b["padding_mask"].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


def get_batch_for_sequence_parallel(data_iterator):
    """Build the batch and keep this rank's chunk of the sequence.

    The batch is fetched and broadcast as in ``get_batch``. Then the
    token-level fields (``text``, ``types``, ``loss_mask``, ``labels``) are
    sliced along dimension 1 into equal contiguous chunks, one per
    tensor-parallel rank. Dimension 1 is presumably the sequence dimension of a
    ``(batch, seq)`` tensor; TODO confirm. ``is_random`` and ``padding_mask``
    are returned whole.

    Arguments:
        data_iterator: iterator yielding dicts with the keys listed below, or
            ``None``, in which case no item is fetched on this rank.

    Returns:
        ``(tokens, types, sentence_order, loss_mask, lm_labels, padding_mask)``.

    Raises:
        ValueError: if the sequence length is not divisible by the
            tensor-parallel world size. Otherwise the trailing tokens would be
            silently dropped.
    """
    # Items and their type.
    keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"]
    datatype = torch.int64

    # Broadcast data.
    data = next(data_iterator) if data_iterator is not None else None
    data_b = broadcast_data(keys, data, datatype)

    # Position of this rank inside the tensor-parallel group. Ask the context
    # directly rather than deriving it from the global rank, which assumes the
    # group's ranks are laid out contiguously.
    if gpc.is_initialized(ParallelMode.TENSOR):
        local_world_size = gpc.get_world_size(ParallelMode.TENSOR)
        local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
    else:
        local_world_size = 1
        local_rank = 0

    seq_length = data_b["text"].size(1)
    if seq_length % local_world_size != 0:
        raise ValueError(
            f"sequence length {seq_length} is not divisible by the tensor parallel world size {local_world_size}"
        )
    sub_seq_length = seq_length // local_world_size
    sub_seq_start = local_rank * sub_seq_length
    sub_seq_end = sub_seq_start + sub_seq_length

    # Unpack: token-level fields are sliced to this rank's chunk.
    tokens = data_b["text"][:, sub_seq_start:sub_seq_end].long()
    types = data_b["types"][:, sub_seq_start:sub_seq_end].long()
    sentence_order = data_b["is_random"].long()
    loss_mask = data_b["loss_mask"][:, sub_seq_start:sub_seq_end].float()
    lm_labels = data_b["labels"][:, sub_seq_start:sub_seq_end].long()
    padding_mask = data_b["padding_mask"].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


class SequenceParallelDataIterator:
    """Iterator wrapper that yields sequence-parallel batches.

    Every ``next()`` call passes the wrapped iterator to
    ``get_batch_for_sequence_parallel``. The caller therefore receives this
    rank's slice of the broadcast batch instead of the raw item.
    """

    def __init__(self, data_iter):
        # Forwarded as-is to get_batch_for_sequence_parallel, which also
        # accepts None (no item is fetched on this rank in that case).
        self.data_iter = data_iter

    def __iter__(self):
        # Return the wrapper itself so that ``for batch in wrapper`` also goes
        # through __next__. Returning the wrapped iterator would hand out raw
        # items and bypass the broadcast/slicing.
        return self

    def __next__(self):
        return get_batch_for_sequence_parallel(self.data_iter)