# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn
from torch import Tensor

from einops import rearrange

from flash_attn.utils.distributed import reduce_scatter, all_reduce


class GPT2Embeddings(nn.Module):

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
                 word_embed_proj_dim=None, device=None, dtype=None):
        """
            If max_position_embeddings <= 0, there are no position embeddings
            If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
                then project up to embed_dim
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if word_embed_proj_dim is None:
            self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
                                                **factory_kwargs)
            self.project_in = None
        else:
            self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim,
                                                padding_idx=padding_idx, **factory_kwargs)
            self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False,
                                        **factory_kwargs)
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
                                                    **factory_kwargs)

    def forward(self, input_ids, position_ids=None):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.project_in is not None:
            embeddings = self.project_in(embeddings)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        return embeddings
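
# Usage sketch for GPT2Embeddings (illustrative only; the sizes below are assumptions,
# not values taken from this module):
#
#     emb = GPT2Embeddings(embed_dim=768, vocab_size=50257, max_position_embeddings=1024)
#     input_ids = torch.randint(0, 50257, (2, 128))
#     hidden = emb(input_ids)   # (2, 128, 768): word + position embeddings
#
# For OPT-350m-style models, additionally pass word_embed_proj_dim (e.g., 512) to embed
# at that width and project up to embed_dim.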


class BertEmbeddings(nn.Module):

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
                 padding_idx=None, device=None, dtype=None):
        """
            If max_position_embeddings <= 0, there are no position embeddings
            If type_vocab_size <= 0, there are no token type embeddings
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
                                            **factory_kwargs)
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
                                                    **factory_kwargs)
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim,
                                                      **factory_kwargs)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)
            token_type_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        if self.type_vocab_size > 0:
            if token_type_ids is None:
                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings
        return embeddings
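
# Usage sketch for BertEmbeddings (illustrative only; the sizes below are assumptions):
#
#     emb = BertEmbeddings(embed_dim=768, vocab_size=30522, max_position_embeddings=512,
#                          type_vocab_size=2)
#     input_ids = torch.randint(0, 30522, (2, 64))
#     hidden = emb(input_ids)   # (2, 64, 768): word + position + token type embeddings
#
# Note that no LayerNorm or dropout is applied here; this module only sums the embeddings.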


class VocabParallelEmbedding(nn.Embedding):

    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(f'num_embeddings ({num_embeddings}) must be divisible by '
                                 f'world_size ({world_size})')
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError('VocabParallelEmbedding does not support padding_idx')
        else:
            world_size = 1
        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
            # Mask out token ids that fall outside this partition's vocab range
            # (True means the id needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings
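
# How VocabParallelEmbedding works (descriptive note on the code above): each rank owns a
# contiguous shard of num_embeddings // world_size rows. Tokens outside the local shard are
# re-indexed to 0 and their output rows zeroed, so summing the per-rank results (via the
# all_reduce / reduce_scatter in ParallelGPT2Embeddings.forward below) recovers the full
# embedding lookup.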


class ColumnParallelEmbedding(nn.Embedding):

    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by '
                                 f'world_size ({world_size})')
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
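
# ColumnParallelEmbedding splits along the embedding dimension instead: each rank holds a
# contiguous (embedding_dim // world_size)-wide slice of every row. ParallelGPT2Embeddings
# below relies on this layout when it adds position embeddings into
# embeddings[..., rank * partition_dim:(rank + 1) * partition_dim].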


class ParallelGPT2Embeddings(nn.Module):

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
                 padding_idx=None, sequence_parallel=True, device=None, dtype=None):
        """
            If max_position_embeddings <= 0, there are no position embeddings
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
            **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, 'b s d -> (b s) d')
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
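

# Usage sketch for ParallelGPT2Embeddings (illustrative only; assumes torch.distributed has
# already been initialized, e.g. via torchrun, and that vocab_size and embed_dim are both
# divisible by the world size):
#
#     process_group = torch.distributed.new_group()
#     emb = ParallelGPT2Embeddings(embed_dim=1024, vocab_size=50304, max_position_embeddings=1024,
#                                  process_group=process_group, sequence_parallel=False,
#                                  device='cuda')
#     input_ids = torch.randint(0, 50304, (2, 128), device='cuda')
#     hidden = emb(input_ids)   # (2, 128, 1024), all-reduced across ranks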