stepllm.py 9.5 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
# Copyright 2025 StepFun Inc. All Rights Reserved.
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import os
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from stepvideo.text_encoder.flashattention import FlashSelfAttention
from stepvideo.modules.normalization import RMSNorm
from stepvideo.text_encoder.tokenizer import LLaMaEmbedding, Wrapped_StepChatTokenizer
from stepvideo.utils import with_empty_init
from safetensors.torch import load_file
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from einops import rearrange
import json


    
def safediv(n, d):
    """Return ``n // d``, insisting that ``d`` divides ``n`` exactly.

    Args:
        n: Dividend.
        d: Divisor.

    Returns:
        The exact quotient of ``n`` by ``d``.

    Raises:
        AssertionError: If ``n`` is not a multiple of ``d``.
        ZeroDivisionError: If ``d`` is zero (propagated from ``divmod``).
    """
    q, r = divmod(n, d)
    # Raise explicitly rather than via ``assert`` so the divisibility check
    # is still enforced when Python runs with ``-O`` (which strips asserts).
    # The exception type is kept as AssertionError for existing callers.
    if not r == 0:
        raise AssertionError(f"{n} is not evenly divisible by {d}")
    return q


class MultiQueryAttention(nn.Module):
    """Self-attention with a fused QKV projection and grouped key/value heads.

    The fused ``wqkv`` projection emits ``hidden_size`` query features followed
    by ``2 * head_dim`` features (key and value) for each of the
    ``num_attention_groups`` key/value groups. Key/value heads are replicated
    so that every query head has a matching key/value head before the
    attention kernel is invoked.
    """

    def __init__(self, cfg, layer_id=None):
        """Build the projections and the attention kernel from ``cfg``.

        Args:
            cfg: Config object providing ``hidden_size``,
                ``num_attention_heads``, ``seq_length``, ``use_flash_attn``,
                ``num_attention_groups`` and ``attention_dropout``.
            layer_id: Optional identifier stored on the module.
        """
        super().__init__()

        self.head_dim = cfg.hidden_size // cfg.num_attention_heads
        self.max_seq_len = cfg.seq_length
        self.use_flash_attention = cfg.use_flash_attn
        assert self.use_flash_attention, 'FlashAttention is required!'

        # No tensor parallelism here: local head/group counts equal the
        # global ones.
        self.n_groups = cfg.num_attention_groups
        self.tp_size = 1
        self.n_local_heads = cfg.num_attention_heads
        self.n_local_groups = self.n_groups

        # Queries take hidden_size features; each KV group adds one key and
        # one value slice of head_dim features.
        qkv_features = cfg.hidden_size + self.head_dim * 2 * self.n_groups
        self.wqkv = nn.Linear(cfg.hidden_size, qkv_features, bias=False)
        self.wo = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)

        assert self.use_flash_attention, 'non-Flash attention not supported yet.'
        self.core_attention = FlashSelfAttention(attention_dropout=cfg.attention_dropout)

        self.layer_id = layer_id

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        cu_seqlens: Optional[torch.Tensor],
        max_seq_len: Optional[torch.Tensor],
    ):
        """Apply attention to ``x``.

        Args:
            x: Rank-3 input unpacked as ``(seqlen, bsz, dim)``.
            mask: Only forwarded to the attention kernel on the non-flash
                path.
            cu_seqlens: Forwarded to the flash-attention kernel.
            max_seq_len: Forwarded to the flash-attention kernel.

        Returns:
            The output of the ``wo`` projection applied to the attention
            result.
        """
        seqlen, bsz, dim = x.shape
        fused = self.wqkv(x)

        # Split the fused projection into the query part and the packed
        # key/value part.
        q_width = dim // self.tp_size
        kv_width = self.head_dim * 2 * self.n_groups // self.tp_size
        xq, xkv = torch.split(fused, (q_width, kv_width), dim=-1)

        # Expose the head dimension; keys and values are packed side by side
        # within each group and are separated by chunking the last axis.
        xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim)
        xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim)
        xk, xv = xkv.chunk(2, -1)

        # The attention kernel is fed batch-first tensors.
        to_batch_first = "s b h d -> b s h d"
        xq = rearrange(xq, to_batch_first)
        xk = rearrange(xk, to_batch_first)
        xv = rearrange(xv, to_batch_first)

        q_per_kv = self.n_local_heads // self.n_local_groups
        if q_per_kv > 1:
            b, s, h, d = xk.size()
            if h == 1:
                # A single KV head can be broadcast without copying.
                xk = xk.expand(b, s, q_per_kv, d)
                xv = xv.expand(b, s, q_per_kv, d)
            else:
                # For h > 1 this is equivalent to
                #     xk = xk.repeat_interleave(q_per_kv, dim=-2)
                #     xv = xv.repeat_interleave(q_per_kv, dim=-2)
                # but avoids calling aten::item(), which involves the CPU.
                # The tiled copy is reordered so that each head's replicas
                # end up adjacent.
                order = torch.arange(q_per_kv * h, device=xk.device)
                order = order.reshape(q_per_kv, -1).permute(1, 0).flatten()
                tiled_k = xk.repeat(1, 1, q_per_kv, 1)
                tiled_v = xv.repeat(1, 1, q_per_kv, 1)
                xk = torch.index_select(tiled_k, 2, order).contiguous()
                xv = torch.index_select(tiled_v, 2, order).contiguous()

        if self.use_flash_attention:
            attended = self.core_attention(
                xq, xk, xv, cu_seqlens=cu_seqlens, max_seq_len=max_seq_len
            )
            # reduce-scatter only supports the first dimension for now, so go
            # back to sequence-first and merge the heads.
            attended = rearrange(attended, "b s h d -> s b (h d)").contiguous()
        else:
            xq, xk, xv = [
                rearrange(t, "b s ... -> s b ...").contiguous()
                for t in (xq, xk, xv)
            ]
            attended = self.core_attention(xq, xk, xv, mask)
        return self.wo(attended)



class FeedForward(nn.Module):
    """Bias-free SwiGLU feed-forward block.

    ``w1`` projects the input to ``2 * hidden_dim`` features, which are split
    in half along the last axis; the first half is passed through SiLU and
    multiplied by the second half, and ``w2`` projects the product back to
    ``dim`` features.
    """

    def __init__(
        self,
        cfg,
        dim: int,
        hidden_dim: int,
        layer_id: int,
        multiple_of: int=256,
    ):
        """Create the two projections.

        Args:
            cfg: Accepted for interface compatibility; not read here.
            dim: Input and output feature size.
            hidden_dim: Requested inner size, rounded up to the next multiple
                of ``multiple_of``.
            layer_id: Accepted for interface compatibility; not read here.
            multiple_of: Granularity the inner size is rounded up to.
        """
        super().__init__()

        # Round the inner width up to the next multiple of ``multiple_of``.
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(
            dim,
            2 * hidden_dim,
            bias=False,
        )
        self.w2 = nn.Linear(
            hidden_dim,
            dim,
            bias=False,
        )

    @staticmethod
    def swiglu(x):
        """Split ``x`` in two along the last axis and return ``silu(a) * b``.

        Defined on the class rather than as a closure stored on the instance
        so that the module remains picklable; ``self.swiglu`` still resolves
        to this function.
        """
        halves = torch.chunk(x, 2, dim=-1)
        return F.silu(halves[0]) * halves[1]

    def forward(self, x):
        """Return ``w2(swiglu(w1(x)))``."""
        x = self.swiglu(self.w1(x))
        output = self.w2(x)
        return output



class TransformerBlock(nn.Module):
    """One transformer layer: normalized attention and feed-forward sub-layers.

    Each sub-layer receives an RMS-normalized copy of its input and its output
    is added back onto the un-normalized input.
    """

    def __init__(
        self, cfg, layer_id: int
    ):
        """Build the sub-layers from ``cfg``.

        Args:
            cfg: Config object providing ``num_attention_heads``,
                ``hidden_size``, ``ffn_hidden_size`` and
                ``layernorm_epsilon`` (plus whatever
                ``MultiQueryAttention`` reads).
            layer_id: Identifier stored on this block and passed to the
                sub-layers.
        """
        super().__init__()

        self.n_heads = cfg.num_attention_heads
        self.dim = cfg.hidden_size
        self.head_dim = cfg.hidden_size // cfg.num_attention_heads
        self.attention = MultiQueryAttention(
            cfg,
            layer_id=layer_id,
        )

        self.feed_forward = FeedForward(
            cfg,
            dim=cfg.hidden_size,
            hidden_dim=cfg.ffn_hidden_size,
            layer_id=layer_id,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(
            cfg.hidden_size,
            eps=cfg.layernorm_epsilon,
        )
        self.ffn_norm = RMSNorm(
            cfg.hidden_size,
            eps=cfg.layernorm_epsilon,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        cu_seqlens: Optional[torch.Tensor],
        max_seq_len: Optional[torch.Tensor],
    ):
        """Run the attention and feed-forward sub-layers with residual adds.

        Args:
            x: Input hidden states.
            mask: Passed through to the attention sub-layer.
            cu_seqlens: Passed through to the attention sub-layer.
            max_seq_len: Passed through to the attention sub-layer.

        Returns:
            Hidden states after both residual sub-layers.
        """
        # Call the sub-modules themselves rather than ``.forward`` so that
        # any forward hooks registered on them are honoured.
        residual = self.attention(
            self.attention_norm(x), mask,
            cu_seqlens, max_seq_len
        )
        h = x + residual
        ffn_res = self.feed_forward(self.ffn_norm(h))
        out = h + ffn_res
        return out


class Transformer(nn.Module):
    """A stack of ``config.num_layers`` ``TransformerBlock`` layers applied in order."""

    def __init__(
        self,
        config,
        max_seq_size=8192,
    ):
        """Build the layer stack.

        Args:
            config: Config object providing ``num_layers`` and the fields
                each ``TransformerBlock`` reads.
            max_seq_size: Accepted for interface compatibility; not read here.
        """
        super().__init__()
        self.num_layers = config.num_layers
        self.layers = self._build_layers(config)

    def _build_layers(self, config):
        """Return a ``ModuleList`` of blocks whose ``layer_id`` starts at 1."""
        blocks = [
            TransformerBlock(config, layer_id=index + 1)
            for index in range(self.num_layers)
        ]
        return torch.nn.ModuleList(blocks)

    def forward(
        self,
        hidden_states,
        attention_mask,
        cu_seqlens=None,
        max_seq_len=None,
    ):
        """Pass ``hidden_states`` through every layer in sequence.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Forwarded unchanged to every layer.
            cu_seqlens: Forwarded unchanged to every layer.
            max_seq_len: Forwarded to every layer; a non-tensor value is
                first wrapped in an int32 CPU tensor.

        Returns:
            The output of the last layer (the input itself when there are no
            layers).
        """
        already_tensor = isinstance(max_seq_len, torch.Tensor)
        if max_seq_len is not None and not already_tensor:
            max_seq_len = torch.tensor(max_seq_len, dtype=torch.int32, device="cpu")

        for block in self.layers:
            hidden_states = block(hidden_states, attention_mask, cu_seqlens, max_seq_len)
        return hidden_states


class Step1Model(PreTrainedModel):
    """Token embedding followed by the transformer stack."""

    config_class = PretrainedConfig

    @with_empty_init
    def __init__(self, config):
        """Create the embedding layer and the transformer from ``config``."""
        super().__init__(config)
        self.tok_embeddings = LLaMaEmbedding(config)
        self.transformer = Transformer(config)

    def forward(self, input_ids=None, attention_mask=None):
        """Embed ``input_ids`` and run the result through the transformer.

        Args:
            input_ids: Token ids handed to the embedding layer.
            attention_mask: Forwarded to the transformer as its mask argument.

        Returns:
            The transformer's output hidden states.
        """
        embedded = self.tok_embeddings(input_ids)
        return self.transformer(embedded, attention_mask)
    
    

class STEP1TextEncoder(torch.nn.Module):
    """Tokenizes prompts and encodes them with a pretrained ``Step1Model``."""

    def __init__(self, model_dir, max_length=320):
        """Load the tokenizer and the encoder weights from ``model_dir``.

        Args:
            model_dir: Directory containing ``step1_chat_tokenizer.model`` and
                the pretrained ``Step1Model`` files.
            max_length: Default token length prompts are padded/truncated to.
        """
        super().__init__()
        self.max_length = max_length
        self.text_tokenizer = Wrapped_StepChatTokenizer(os.path.join(model_dir, 'step1_chat_tokenizer.model'))
        text_encoder = Step1Model.from_pretrained(model_dir)
        # Inference only: eval mode, weights cast to bfloat16.
        self.text_encoder = text_encoder.eval().to(torch.bfloat16)

    # ``torch.no_grad()`` is instantiated explicitly: the bare ``@torch.no_grad``
    # form is not accepted by every PyTorch release.
    @torch.no_grad()
    def forward(self, prompts, with_mask=True, max_length=None):
        """Encode one prompt or a list of prompts.

        Args:
            prompts: A single string or a list of strings.
            with_mask: When true, the tokenizer's attention mask is passed to
                the encoder; otherwise the encoder receives ``None``.
            max_length: Optional per-call override of ``self.max_length``.

        Returns:
            A tuple ``(y, y_mask)`` where ``y`` is the encoder output with its
            first two axes swapped and ``y_mask`` is the tokenizer's attention
            mask (returned as produced by the tokenizer, not moved to the
            model device, and returned even when ``with_mask`` is false).
        """
        # Follow the encoder weights to whatever device they currently live on.
        self.device = next(self.text_encoder.parameters()).device
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            # isinstance also covers str subclasses, which would otherwise be
            # iterated character by character as if they were a prompt list.
            if isinstance(prompts, str):
                prompts = [prompts]

            txt_tokens = self.text_tokenizer(
                prompts, max_length=max_length or self.max_length, padding="max_length", truncation=True, return_tensors="pt"
            )
            y = self.text_encoder(
                txt_tokens.input_ids.to(self.device),
                attention_mask=txt_tokens.attention_mask.to(self.device) if with_mask else None
            )
            y_mask = txt_tokens.attention_mask
        return y.transpose(0,1), y_mask