# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn
from torch import Tensor

from einops import rearrange

from flash_attn.utils.distributed import reduce_scatter, all_reduce


class GPT2Embeddings(nn.Module):

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
                 device=None, dtype=None):
        """
            If max_position_embeddings <= 0, there are no position embeddings
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
                                            **factory_kwargs)
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
                                                    **factory_kwargs)

    def forward(self, input_ids, position_ids=None):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        return embeddings
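
# Usage sketch (not part of the original module; the sizes below are illustrative
# assumptions only):
#
#   emb = GPT2Embeddings(embed_dim=768, vocab_size=50257, max_position_embeddings=1024)
#   input_ids = torch.randint(0, 50257, (2, 16))   # (batch=2, seqlen=16)
#   hidden = emb(input_ids)                         # (2, 16, 768)
#
# When position_ids is None, positions 0..seqlen-1 are generated once and broadcast
# over the batch dimension.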


class BertEmbeddings(nn.Module):

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
                 padding_idx=None, device=None, dtype=None):
        """
            If max_position_embeddings <= 0, there are no position embeddings
            If type_vocab_size <= 0, there are no token type embeddings
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
                                            **factory_kwargs)
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
                                                    **factory_kwargs)
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim,
                                                      **factory_kwargs)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)
            token_type_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        if self.type_vocab_size > 0:
            if token_type_ids is None:
                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings
        return embeddings
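
# Usage sketch (illustrative assumptions, not part of the original module):
#
#   emb = BertEmbeddings(embed_dim=768, vocab_size=30522, max_position_embeddings=512,
#                        type_vocab_size=2)
#   input_ids = torch.randint(0, 30522, (4, 128))
#   hidden = emb(input_ids)   # (4, 128, 768)
#
# Omitted position_ids default to arange(seqlen) and omitted token_type_ids default to
# zeros(seqlen); both are broadcast over the batch dimension.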


class VocabParallelEmbedding(nn.Embedding):

    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(f'num_embeddings ({num_embeddings}) must be divisible by '
                                 f'world_size ({world_size})')
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError('VocabParallelEmbedding does not support padding_idx')
        else:
            world_size = 1
        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
            # Mask of token ids that fall outside this rank's vocab shard (True = masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings
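
# How the vocab-parallel lookup works (sketch, assuming world_size=2 and
# num_embeddings=8): rank 0 stores rows for ids 0..3, rank 1 for ids 4..7. An input id
# of 5 on rank 0 falls outside [0, 4), so it is masked, looked up at local index 0, and
# its output row zeroed; rank 1 looks it up at local index 5 - 4 = 1. Summing the
# partial outputs across ranks (e.g. the all_reduce / reduce_scatter done by
# ParallelGPT2Embeddings below) recovers the full embedding for every token.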


class ColumnParallelEmbedding(nn.Embedding):

    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by '
                                 f'world_size ({world_size})')
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
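
# Sketch of the column split (assuming world_size=4 and embedding_dim=1024): every rank
# stores all num_embeddings rows but only a 256-wide slice of the embedding dimension;
# ParallelGPT2Embeddings below adds each rank's slice into the matching columns of the
# word embeddings before the final reduction.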


class ParallelGPT2Embeddings(nn.Module):

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
                 padding_idx=None, sequence_parallel=True, device=None, dtype=None):
        """
            If max_position_embeddings <= 0, there are no position embeddings
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
            **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
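                # Each rank holds only a slice of the embedding dimension for the
                # position embeddings, so add it into the matching columns.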
                embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, 'b s d -> (b s) d')
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
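
# Usage sketch for the tensor-parallel path (illustrative; assumes torch.distributed is
# already initialized and `group` is the tensor-parallel process group):
#
#   emb = ParallelGPT2Embeddings(embed_dim=768, vocab_size=50257,
#                                max_position_embeddings=1024, process_group=group,
#                                sequence_parallel=True)
#   hidden = emb(input_ids, combine_batch_seqlen_dim=True)
#
# With sequence_parallel=True the partial results are reduce-scattered along the
# (batch * seqlen) dimension, so each rank keeps a shard of the sequence; with
# sequence_parallel=False they are all-reduced and every rank holds the full output.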