tokenizer.py 8.15 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# Copyright 2025 StepFun Inc. All Rights Reserved.
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import torch.nn as nn
import torch
from typing import List


class LLaMaEmbedding(nn.Module):
    """Token-embedding layer with optional fp32 lookup, followed by dropout.

    Arguments:
        cfg: configuration object. The attributes read from it are:
            hidden_size: width of each embedding vector
            padded_vocab_size: number of rows in the embedding table
            params_dtype: dtype the looked-up embeddings (and the table)
                          are cast to after an fp32 lookup
            fp32_residual_connection: when true, the output is cast to
                                      float32
            embedding_weights_in_fp32: when true, the table is cast to
                                       float32 for the lookup
            hidden_dropout: dropout probability applied to the output
    """

    def __init__(self, cfg):
        super().__init__()
        hidden_size = cfg.hidden_size

        self.hidden_size = hidden_size
        self.params_dtype = cfg.params_dtype
        self.fp32_residual_connection = cfg.fp32_residual_connection
        self.embedding_weights_in_fp32 = cfg.embedding_weights_in_fp32

        # Sub-modules are registered in this order: table first, then dropout.
        self.word_embeddings = nn.Embedding(cfg.padded_vocab_size, hidden_size)
        self.embedding_dropout = nn.Dropout(cfg.hidden_dropout)

    def forward(self, input_ids):
        """Embed ``input_ids`` and return the result with its first two axes swapped.

        Arguments:
            input_ids: integer tensor of token ids; the comment in the
                       original code describes it as [b s].

        Returns:
            The embeddings after the axis swap ([b s h] --> [s b h]), the
            optional float32 cast and dropout.
        """
        table = self.word_embeddings
        if self.embedding_weights_in_fp32:
            # Look up with the table held in fp32, then cast both the result
            # and the table to params_dtype. Module.to() converts in place
            # and returns the module itself.
            table = table.to(torch.float32)
            self.word_embeddings = table
            hidden = table(input_ids).to(self.params_dtype)
            self.word_embeddings = table.to(self.params_dtype)
        else:
            hidden = table(input_ids)

        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
        hidden = hidden.transpose(0, 1).contiguous()

        # Keep the residual stream in fp32 when the flag asks for it.
        if self.fp32_residual_connection:
            hidden = hidden.float()

        return self.embedding_dropout(hidden)



class StepChatTokenizer:
    """SentencePiece-backed tokenizer with chat/control token ids resolved up front."""

    def __init__(
        self, model_file, name="StepChatTokenizer",
        bot_token="<|BOT|>",  # Begin of Turn
        eot_token="<|EOT|>",  # End of Turn
        call_start_token="<|CALL_START|>",      # Call Start
        call_end_token="<|CALL_END|>",          # Call End
        think_start_token="<|THINK_START|>",    # Think Start
        think_end_token="<|THINK_END|>",        # Think End
        mask_start_token="<|MASK_1e69f|>",      # Mask start
        mask_end_token="<|UNMASK_1e69f|>",      # Mask end
    ):
        """Load the SentencePiece model and index its vocabulary.

        Arguments:
            model_file: passed to ``SentencePieceProcessor(model_file=...)``
            name: accepted but not read by this constructor
            bot_token .. think_end_token: pieces that must be in the
                vocabulary and be control/unknown pieces
            mask_start_token, mask_end_token: pieces that must be in the
                vocabulary

        Raises:
            AssertionError: if a required piece is missing, or one of the
                first six is not a special piece (only while asserts are
                enabled).
        """
        # Imported lazily so the module can be imported without sentencepiece.
        import sentencepiece

        sp_model = sentencepiece.SentencePieceProcessor(model_file=model_file)
        self._tokenizer = sp_model

        # piece <-> id maps for the whole vocabulary ...
        self._vocab = {}
        self._inv_vocab = {}
        # ... and for the control/unknown subset only.
        self._special_tokens = {}
        self._inv_special_tokens = {}

        self._t5_tokens = []

        for piece_id in range(sp_model.get_piece_size()):
            piece = sp_model.id_to_piece(piece_id)
            self._inv_vocab[piece_id] = piece
            self._vocab[piece] = piece_id

            is_special = sp_model.is_control(piece_id) or sp_model.is_unknown(piece_id)
            if not is_special:
                continue
            self._special_tokens[piece] = piece_id
            self._inv_special_tokens[piece_id] = piece

        self._unk_id = sp_model.unk_id()
        self._bos_id = sp_model.bos_id()
        self._eos_id = sp_model.eos_id()

        # Turn/call/think markers must exist and be special pieces.
        must_be_special = (
            bot_token, eot_token, call_start_token, call_end_token,
            think_start_token, think_end_token,
        )
        for token in must_be_special:
            assert token in self._vocab, f"Token '{token}' not found in tokenizer"
            assert token in self._special_tokens, f"Token '{token}' is not a special token"

        # Mask markers only need to exist.
        for token in (mask_start_token, mask_end_token):
            assert token in self._vocab, f"Token '{token}' not found in tokenizer"

        # Resolve each marker to its id, in the same order as the parameters.
        marker_attrs = (
            ("_bot_id", bot_token),
            ("_eot_id", eot_token),
            ("_call_start_id", call_start_token),
            ("_call_end_id", call_end_token),
            ("_think_start_id", think_start_token),
            ("_think_end_id", think_end_token),
            ("_mask_start_id", mask_start_token),
            ("_mask_end_id", mask_end_token),
        )
        for attr_name, token in marker_attrs:
            setattr(self, attr_name, sp_model.piece_to_id(token))

        # Id of the U+2581 piece (SentencePiece's whitespace marker).
        self._underline_id = sp_model.piece_to_id("\u2581")

    @property
    def vocab(self):
        """Mapping from piece text to id."""
        return self._vocab

    @property
    def inv_vocab(self):
        """Mapping from id to piece text."""
        return self._inv_vocab

    @property
    def vocab_size(self):
        """Vocabulary size as reported by the SentencePiece processor."""
        return self._tokenizer.vocab_size()

    def tokenize(self, text: str) -> List[int]:
        """Encode ``text`` into a list of piece ids."""
        return self._tokenizer.encode_as_ids(text)

    def detokenize(self, token_ids: List[int]) -> str:
        """Decode a list of piece ids back into text."""
        return self._tokenizer.decode_ids(token_ids)

    
class Tokens:
    """Plain container for a tokenized batch.

    Holds the padded ids and mask plus the packed ids, cumulative sequence
    lengths and the longest sequence length, as built by
    ``Wrapped_StepChatTokenizer.__call__``.
    """

    def __init__(self, input_ids, cu_input_ids, attention_mask, cu_seqlens, max_seq_len) -> None:
        self.input_ids, self.attention_mask = input_ids, attention_mask
        self.cu_input_ids, self.cu_seqlens = cu_input_ids, cu_seqlens
        # Not moved by to(); it is the only field left untouched there.
        self.max_seq_len = max_seq_len

    def to(self, device):
        """Move the four tensor fields with ``.to(device)`` and return ``self``."""
        for field in ("input_ids", "attention_mask", "cu_input_ids", "cu_seqlens"):
            moved = getattr(self, field).to(device)
            setattr(self, field, moved)
        return self
    
class Wrapped_StepChatTokenizer(StepChatTokenizer):
    """Callable tokenizer that returns a padded and a packed view of a batch."""

    # Fixed ids used to frame and pad every row: [bos, ..., eos, pad, ..., pad].
    # PAD shares the EOS id, so padding is only distinguishable via the mask.
    # NOTE(review): these are hard-coded rather than read from the
    # SentencePiece model (self._bos_id / self._eos_id) -- confirm they match.
    BOS = 1
    EOS = 2
    PAD = 2

    def __call__(self, text, max_length=320, padding="max_length", truncation=True, return_tensors="pt"):
        """Tokenize ``text`` into fixed-width rows and build a ``Tokens`` batch.

        Arguments:
            text: a sequence of strings (one row each) or a single string
                  (one row). An empty sequence or empty string yields one
                  row holding only BOS and EOS.
            max_length: row width. Each row keeps at most ``max_length - 2``
                        content ids so that BOS and EOS always fit.
            padding, truncation, return_tensors: accepted for call
                  compatibility; not read. Rows are always truncated, padded
                  to ``max_length`` and returned as torch tensors.

        Returns:
            Tokens with
              input_ids:      long tensor, one padded row per input
              attention_mask: long tensor, 1 on real ids and 0 on padding
              cu_input_ids:   the real ids of all rows concatenated, with a
                              leading axis of size 1
              cu_seqlens:     int32 cumulative row lengths, starting at 0
              max_seq_len:    longest unpadded row length (python int)
        """
        if len(text) == 0:
            # Nothing to encode: emit a single BOS/EOS-only row.
            token_lists = [[]]
        elif isinstance(text, str):
            # A bare string is one sample; iterating it would tokenize each
            # character as its own row.
            token_lists = [self.tokenize(text)]
        else:
            token_lists = [self.tokenize(part) for part in text]

        # Leave 2 slots for BOS and EOS. Clamp at 0 so a tiny max_length
        # cannot turn into a negative slice bound that drops tail tokens.
        content_budget = max(max_length - 2, 0)

        out_tokens = []
        attn_mask = []
        for tokens in token_lists:
            row = [self.BOS] + tokens[:content_budget] + [self.EOS]
            valid_size = len(row)
            pad_size = max(max_length - valid_size, 0)
            out_tokens.append(row + [self.PAD] * pad_size)
            attn_mask.append([1] * valid_size + [0] * pad_size)

        out_tokens = torch.tensor(out_tokens, dtype=torch.long)
        attn_mask = torch.tensor(attn_mask, dtype=torch.long)

        # Packed view: drop padding and concatenate all rows, keeping the
        # per-row boundaries as cumulative lengths.
        cu_out_tokens = out_tokens.masked_select(attn_mask != 0).unsqueeze(0)
        seqlen = attn_mask.sum(dim=1).tolist()
        cu_seqlens = torch.cumsum(torch.tensor([0] + seqlen), 0).to(
            device=out_tokens.device, dtype=torch.int32
        )
        max_seq_len = max(seqlen)
        return Tokens(out_tokens, cu_out_tokens, attn_mask, cu_seqlens, max_seq_len)