# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn
from torch import Tensor

from einops import rearrange

from flash_attn.utils.distributed import reduce_scatter


class GPT2Embeddings(nn.Module):

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
                 device=None, dtype=None):
        """
            If max_position_embeddings <= 0, there are no position embeddings
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
                                            **factory_kwargs)
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
                                                    **factory_kwargs)

    def forward(self, input_ids, position_ids=None):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        return embeddings
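
# A minimal usage sketch for GPT2Embeddings (illustrative sizes, not defaults of this module):
#
#     emb = GPT2Embeddings(embed_dim=768, vocab_size=50257, max_position_embeddings=1024)
#     input_ids = torch.randint(0, 50257, (2, 128))  # (batch, seqlen)
#     hidden_states = emb(input_ids)                 # (batch, seqlen, embed_dim) = (2, 128, 768)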


class BertEmbeddings(nn.Module):

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
                 padding_idx=None, device=None, dtype=None):
        """
            If max_position_embeddings <= 0, there are no position embeddings
            If type_vocab_size <= 0, there are no token type embeddings
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
                                            **factory_kwargs)
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
                                                    **factory_kwargs)
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim,
                                                      **factory_kwargs)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)
            token_type_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        if self.type_vocab_size > 0:
            if token_type_ids is None:
                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings
        return embeddings
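
# A minimal usage sketch for BertEmbeddings (illustrative sizes; when omitted, position_ids
# default to 0..seqlen-1 and token_type_ids to all zeros):
#
#     emb = BertEmbeddings(embed_dim=768, vocab_size=30522, max_position_embeddings=512,
#                          type_vocab_size=2)
#     input_ids = torch.randint(0, 30522, (2, 128))  # (batch, seqlen)
#     hidden_states = emb(input_ids)                 # (2, 128, 768)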


class VocabParallelEmbedding(nn.Embedding):

    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(f'num_embeddings ({num_embeddings}) must be divisible by '
                                 f'world_size ({world_size})')
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError('ParallelEmbedding does not support padding_idx')
        else:
            world_size = 1
        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
            # Mask token ids that fall outside this rank's vocab partition (True = to be zeroed out).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings
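
# Note: each rank of VocabParallelEmbedding holds a contiguous vocab shard of size
# num_embeddings // world_size. Token ids outside the local shard are looked up at index 0
# and zeroed above, so summing the outputs across ranks (e.g. via the reduce_scatter in
# ParallelGPT2Embeddings below, or an all_reduce) recovers the embedding of every token.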


class ColumnParallelEmbedding(nn.Embedding):

    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by '
                                 f'world_size ({world_size})')
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelGPT2Embeddings(nn.Module):

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
                 padding_idx=None, device=None, dtype=None):
        """
            If max_position_embeddings <= 0, there are no position embeddings
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.process_group = process_group
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
            **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
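                # Only this rank's column slice of the position embeddings exists locally,
                # so add it into the matching columns of the word embeddings.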
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += (
                    position_embeddings
                )
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, 'b s d -> (b s) d')
        return embeddings if world_size <= 1 else reduce_scatter(embeddings, self.process_group)
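
# A usage sketch for ParallelGPT2Embeddings (assumes torch.distributed has been initialized,
# e.g. with torch.distributed.init_process_group; sizes and device are illustrative):
#
#     process_group = torch.distributed.group.WORLD
#     emb = ParallelGPT2Embeddings(768, 50304, 1024, process_group).to('cuda')
#     input_ids = torch.randint(0, 50304, (2, 128), device='cuda')
#     hidden_states = emb(input_ids, combine_batch_seqlen_dim=True)
#     # When world_size > 1 the output is reduce-scattered across the process group, so each
#     # rank keeps a 1/world_size shard of the (batch * seqlen) rows: (2 * 128 // world_size, 768).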