# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn

from einops import rearrange

from flash_attn.utils.distributed import reduce_scatter


class GPT2Embeddings(nn.Module):

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
                 device=None, dtype=None):
        """
            If max_position_embeddings <= 0, there are no position embeddings
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
                                            **factory_kwargs)
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
                                                    **factory_kwargs)

    def forward(self, input_ids, position_ids=None):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        return embeddings


class BertEmbeddings(nn.Module):

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
                 padding_idx=None, device=None, dtype=None):
        """
            If max_position_embeddings <= 0, there are no position embeddings
            If type_vocab_size <= 0, there are no token type embeddings
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
                                            **factory_kwargs)
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
                                                    **factory_kwargs)
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim,
                                                      **factory_kwargs)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)
            token_type_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        if self.type_vocab_size > 0:
            if token_type_ids is None:
                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings
        return embeddings


class ParallelGPT2Embeddings(nn.Module):

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
                 padding_idx=None, device=None, dtype=None):
        """
            If max_position_embeddings <= 0, there are no position embeddings
        """
        world_size = torch.distributed.get_world_size(process_group)
        if vocab_size % world_size != 0:
            raise ValueError(f'vocab_size ({vocab_size}) must be divisible by '
                             f'world_size ({world_size})')
        if embed_dim % world_size != 0:
            raise ValueError(f'embed_dim ({embed_dim}) must be divisible by '
                             f'world_size ({world_size})')
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.process_group = process_group
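        # The word embedding table is sharded across ranks along the vocab dimension;
        # the position embedding table (if any) is sharded along the hidden dimension.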
        self.word_embeddings = nn.Embedding(vocab_size // world_size, embed_dim,
                                            padding_idx=padding_idx, **factory_kwargs)
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim // world_size, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        if world_size <= 1:
            embeddings = self.word_embeddings(input_ids)
            if self.max_position_embeddings > 0:
                if position_ids is None:
                    position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings
            if combine_batch_seqlen_dim:
                embeddings = rearrange(embeddings, 'b s d -> (b s) d')
            return embeddings
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.word_embeddings.num_embeddings
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
            # Mask tokens whose ids fall outside this rank's vocab partition (True means masked out).
            input_ids_mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index)
            input_ids = input_ids - vocab_start_index
            input_ids[input_ids_mask] = 0
            embeddings = self.word_embeddings(input_ids)
            embeddings[input_ids_mask] = 0.0
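            # The position embedding table is sharded along the hidden dimension, so each
            # rank adds only its slice of columns here; the reduce_scatter below sums the
            # per-rank partial results.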
            if self.max_position_embeddings > 0:
                if position_ids is None:
                    position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
                position_embeddings = self.position_embeddings(position_ids)
                partition_dim = self.position_embeddings.embedding_dim
                embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
            if combine_batch_seqlen_dim:
                embeddings = rearrange(embeddings, 'b s d -> (b s) d')
            return reduce_scatter(embeddings, self.process_group)
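

if __name__ == "__main__":
    # Minimal usage sketch with arbitrary example sizes (not taken from any real model
    # config). ParallelGPT2Embeddings is not exercised here because it requires an
    # initialized torch.distributed process group.
    torch.manual_seed(0)
    batch_size, seqlen, embed_dim, vocab_size = 2, 16, 64, 1000
    input_ids = torch.randint(0, vocab_size, (batch_size, seqlen))

    gpt2_emb = GPT2Embeddings(embed_dim, vocab_size, max_position_embeddings=128)
    out = gpt2_emb(input_ids)
    assert out.shape == (batch_size, seqlen, embed_dim)

    bert_emb = BertEmbeddings(embed_dim, vocab_size, max_position_embeddings=128,
                              type_vocab_size=2)
    out = bert_emb(input_ids, token_type_ids=torch.zeros_like(input_ids))
    assert out.shape == (batch_size, seqlen, embed_dim)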