model.py 28.2 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
3
# 1. Standard-library imports
import gc
helloyongyang's avatar
helloyongyang committed
4
5
import math
import os
6
7
import sys
from pathlib import Path
PengGao's avatar
PengGao committed
8

9
# 2. Third-party imports
helloyongyang's avatar
helloyongyang committed
10
11
12
import torch
import torch.nn as nn
import torch.nn.functional as F
root's avatar
root committed
13
from loguru import logger
14

15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Make the repository root importable so the absolute `lightx2v.*` imports
# that follow resolve even when this file is executed outside an installed
# package. The root is six directory levels above this file's directory.
current_dir = Path(__file__).resolve().parent
project_root = current_dir
for _level in range(6):
    project_root = project_root.parent
del _level
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList  # noqa E402
from lightx2v.common.offload.manager import WeightAsyncStreamManager  # noqa E402
from lightx2v.common.ops import *  # noqa E402
from lightx2v.models.input_encoders.hf.q_linear import (  # noqa E402
    Q8FQuantLinearFp8,  # noqa E402
    Q8FQuantLinearInt8,  # noqa E402
    SglQuantLinearFp8,  # noqa E402
    TorchaoQuantLinearInt8,  # noqa E402
    VllmQuantLinearInt8,  # noqa E402
)
from lightx2v.models.input_encoders.hf.wan.t5.tokenizer import HuggingfaceTokenizer  # noqa E402
from lightx2v.utils.envs import *  # noqa E402
from lightx2v.utils.registry_factory import (  # noqa E402
    EMBEDDING_WEIGHT_REGISTER,  # noqa E402
    MM_WEIGHT_REGISTER,  # noqa E402
    RMS_WEIGHT_REGISTER,  # noqa E402
)
from lightx2v.utils.utils import load_weights  # noqa E402
helloyongyang's avatar
helloyongyang committed
38
39
40
41
42
43
44
45
46

# Names exported by `from <this module> import *`.
__all__ = ["T5Model", "T5Encoder", "T5Decoder", "T5EncoderModel"]


47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
class T5OffloadBlocksWeights(WeightModule):
    """Container for the offloadable weights of every T5 encoder block.

    Args:
        block_nums: number of encoder blocks to create.
        mm_type: matmul weight type key forwarded unchanged to each
            ``T5OffloadSelfAttention`` block.
    """

    def __init__(self, block_nums, mm_type):
        super().__init__()
        self.block_nums = block_nums
        per_block = [T5OffloadSelfAttention(idx, mm_type) for idx in range(block_nums)]
        self.blocks = WeightModuleList(per_block)
        self.add_module("blocks", self.blocks)


class T5OffloadSelfAttention(WeightModule):
    """Offloadable weights of one T5 encoder block.

    Registers the two RMS-norm weights, the block's relative-position
    embedding, and two compute phases (attention, then feed-forward). Weight
    keys follow ``{block_prefix}.{block_index}.<submodule>...weight``.

    Args:
        block_index: index of this block in the encoder.
        mm_type: matmul weight type key; ``None`` is replaced by ``"Default"``.
        block_prefix: prefix of the checkpoint keys.
    """

    def __init__(self, block_index, mm_type, block_prefix="blocks"):
        super().__init__()
        self.block_index = block_index
        self.mm_type = "Default" if mm_type is None else mm_type
        key_root = f"{block_prefix}.{block_index}"

        for norm_name in ("norm1", "norm2"):
            self.add_module(norm_name, RMS_WEIGHT_REGISTER["sgl-kernel"](f"{key_root}.{norm_name}.weight"))
        self.add_module(
            "pos_embedding",
            EMBEDDING_WEIGHT_REGISTER["Default"](f"{key_root}.pos_embedding.embedding.weight"),
        )

        # Phase 0 holds the attention weights and phase 1 the feed-forward
        # weights; the encoder's offload path indexes them in that order.
        phases = [
            T5OffloadAttention(block_index, block_prefix, self.mm_type),
            T5OffloadFeedForward(block_index, block_prefix, self.mm_type),
        ]
        self.compute_phases = WeightModuleList(phases)
        self.add_module("compute_phases", self.compute_phases)


class T5OffloadAttention(WeightModule):
    """Offloadable q/k/v/o projection weights of one encoder block.

    Each projection is registered as ``attn_<name>`` using
    ``MM_WEIGHT_REGISTER[mm_type]``. The second constructor argument is
    ``None`` (presumably "no bias" — confirm against the registry).
    """

    def __init__(self, block_index, block_prefix, mm_type):
        super().__init__()
        self.block_index = block_index
        self.mm_type = mm_type
        for proj in ("q", "k", "v", "o"):
            weight = MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{block_index}.attn.{proj}.weight", None)
            self.add_module(f"attn_{proj}", weight)


class T5OffloadFeedForward(WeightModule):
    """Offloadable gated feed-forward weights of one encoder block.

    Registers ``ffn_fc1``, ``ffn_fc2`` and ``ffn_gate_0``. It also keeps a
    ``GELU`` module, which the encoder applies to the gate branch.
    """

    def __init__(self, block_index, block_prefix, mm_type):
        super().__init__()
        self.block_index = block_index
        self.mm_type = mm_type
        # (attribute name, checkpoint key fragment), in registration order.
        layout = (
            ("ffn_fc1", "ffn.fc1"),
            ("ffn_fc2", "ffn.fc2"),
            ("ffn_gate_0", "ffn.gate.0"),
        )
        for attr_name, key_part in layout:
            weight = MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{block_index}.{key_part}.weight", None)
            self.add_module(attr_name, weight)
        self.gelu = GELU()


helloyongyang's avatar
helloyongyang committed
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def fp16_clamp(x):
    if x.dtype == torch.float16 and torch.isinf(x).any():
        clamp = torch.finfo(x.dtype).max - 1000
        x = torch.clamp(x, min=-clamp, max=clamp)
    return x


def init_weights(m):
    """Initialise the parameters of one T5 submodule in place.

    Intended for ``nn.Module.apply``. Modules of any other type are left
    untouched. Normal-init standard deviations are derived from the module's
    own dimensions.
    """
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
        return
    if isinstance(m, T5Model):
        nn.init.normal_(m.token_embedding.weight, std=1.0)
        return
    if isinstance(m, T5FeedForward):
        in_std = m.dim**-0.5
        nn.init.normal_(m.gate[0].weight, std=in_std)
        nn.init.normal_(m.fc1.weight, std=in_std)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
        return
    if isinstance(m, T5Attention):
        qkv_std = m.dim**-0.5
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
        nn.init.normal_(m.k.weight, std=qkv_std)
        nn.init.normal_(m.v.weight, std=qkv_std)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
        return
    if isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)
helloyongyang's avatar
helloyongyang committed
167
168
169
170


class GELU(nn.Module):
    """Tanh approximation of GELU.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """

    def forward(self, x):
        cubic = torch.pow(x, 3.0)
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * cubic)
        return 0.5 * x * (1.0 + torch.tanh(inner))
helloyongyang's avatar
helloyongyang committed
172
173
174


class T5LayerNorm(nn.Module):
    """T5-style RMS normalisation: no mean subtraction and no bias.

    Args:
        dim: size of the normalised last dimension.
        eps: added to the mean square before the reciprocal square root.
        dtype: dtype of the learnable scale (float16 by default).
    """

    def __init__(self, dim, eps=1e-6, dtype=torch.float16):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=dtype))

    def forward(self, x):
        # The mean-square statistic is computed in float32.
        mean_sq = x.float().pow(2).mean(dim=-1, keepdim=True)
        normed = x * torch.rsqrt(mean_sq + self.eps)
        # Only half-precision scales force a cast; otherwise keep the promoted dtype.
        half_types = [torch.float16, torch.bfloat16]
        if self.weight.dtype in half_types:
            normed = normed.type_as(self.weight)
        return self.weight * normed


class T5Attention(nn.Module):
    """Multi-head T5 attention with optional quantized projections.

    T5 does not scale the attention logits by ``1/sqrt(head_dim)``.

    Args:
        dim: input/output width.
        dim_attn: total width of the q/k/v projections; must be divisible by
            ``num_heads``.
        num_heads: number of attention heads.
        dropout: probability for ``self.dropout`` (not applied in ``forward``).
        quantized: when True, the projections use the quantized linear class
            selected by ``quant_scheme``.
        quant_scheme: one of ``"int8"``, ``"int8-vllm"``, ``"fp8"``,
            ``"fp8-sgl"``, ``"int8-torchao"``, ``"int8-q8f"``, ``"fp8-q8f"``.
        dtype: parameter dtype of the projections.

    Raises:
        NotImplementedError: ``quantized`` is True but ``quant_scheme`` is not
            supported.
    """

    def __init__(
        self,
        dim,
        dim_attn,
        num_heads,
        dropout=0.1,
        quantized=False,
        quant_scheme=None,
        dtype=torch.bfloat16,
    ):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        linear_cls = self._select_linear_cls(quantized, quant_scheme)

        # layers
        self.q = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
        self.k = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
        self.v = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
        self.o = linear_cls(dim_attn, dim, bias=False, dtype=dtype)
        # NOTE: forward() does not apply this module.
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _select_linear_cls(quantized, quant_scheme):
        """Return the linear layer class for the requested quantization scheme."""
        if not quantized:
            return nn.Linear
        if quant_scheme in ["int8", "int8-vllm"]:
            return VllmQuantLinearInt8
        if quant_scheme in ["fp8", "fp8-sgl"]:
            return SglQuantLinearFp8
        if quant_scheme == "int8-torchao":
            return TorchaoQuantLinearInt8
        if quant_scheme == "int8-q8f":
            return Q8FQuantLinearInt8
        if quant_scheme == "fp8-q8f":
            return Q8FQuantLinearFp8
        # Previously the exception was constructed but never raised, which
        # left `linear_cls` unbound and produced an UnboundLocalError.
        raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        x:          [B, L1, C].
        context:    [B, L2, C] or None (self-attention).
        mask:       [B, L2] or [B, L1, L2] or None; zero entries are masked out.
        pos_bias:   added to the [B, N, L1, L2] logits; must broadcast to that shape.

        Returns a [B, L1, C] tensor.
        """
        # check inputs
        context = x if context is None else context
        b, n, c = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).view(b, -1, n, c)
        k = self.k(context).view(b, -1, n, c)
        v = self.v(context).view(b, -1, n, c)

        # attention bias
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # compute attention (T5 does not use scaling); softmax runs in float32
        attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        x = torch.einsum("bnij,bjnc->binc", attn, v)
        x = x.reshape(b, -1, n * c)
        x = self.o(x)
        return x


class T5FeedForward(nn.Module):
    """Gated-GELU feed-forward: ``fc2(fc1(x) * GELU(gate(x)))`` with dropout.

    Args:
        dim: input/output width.
        dim_ffn: hidden width.
        dropout: probability for the dropout applied after the gating product
            and after ``fc2``.
        quantized: when True, the linear layers use the quantized class
            selected by ``quant_scheme``.
        quant_scheme: one of ``"int8"``, ``"int8-vllm"``, ``"fp8"``,
            ``"fp8-sgl"``, ``"int8-torchao"``, ``"int8-q8f"``, ``"fp8-q8f"``.
        dtype: parameter dtype of the linear layers.

    Raises:
        NotImplementedError: ``quantized`` is True but ``quant_scheme`` is not
            supported.
    """

    def __init__(
        self,
        dim,
        dim_ffn,
        dropout=0.1,
        quantized=False,
        quant_scheme=None,
        dtype=torch.bfloat16,
    ):
        super(T5FeedForward, self).__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        linear_cls = self._select_linear_cls(quantized, quant_scheme)

        # layers
        self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False, dtype=dtype), GELU())
        self.fc1 = linear_cls(dim, dim_ffn, bias=False, dtype=dtype)
        self.fc2 = linear_cls(dim_ffn, dim, bias=False, dtype=dtype)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _select_linear_cls(quantized, quant_scheme):
        """Return the linear layer class for the requested quantization scheme."""
        if not quantized:
            return nn.Linear
        if quant_scheme in ["int8", "int8-vllm"]:
            return VllmQuantLinearInt8
        if quant_scheme in ["fp8", "fp8-sgl"]:
            return SglQuantLinearFp8
        if quant_scheme == "int8-torchao":
            return TorchaoQuantLinearInt8
        if quant_scheme == "int8-q8f":
            return Q8FQuantLinearInt8
        if quant_scheme == "fp8-q8f":
            return Q8FQuantLinearFp8
        # Previously the exception was constructed but never raised, which
        # left `linear_cls` unbound and produced an UnboundLocalError.
        raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")

    def forward(self, x):
        """Apply the gated feed-forward to the last dimension of ``x``."""
        x = self.fc1(x) * self.gate(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class T5SelfAttention(nn.Module):
    """One pre-norm T5 encoder layer: self-attention, then gated feed-forward.

    Each sub-layer is residual (``x + f(norm(x))``), and each sum is passed
    through ``fp16_clamp``. When ``shared_pos`` is False the layer owns a
    bidirectional ``T5RelativeEmbedding``. Otherwise the caller supplies the
    positional bias through ``pos_bias``.
    """

    def __init__(
        self,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
        quantized=False,
        quant_scheme=None,
        dtype=torch.bfloat16,
    ):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # Sub-layers, created in a fixed order (this also fixes RNG consumption).
        self.norm1 = T5LayerNorm(dim, dtype=dtype)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme, dtype)
        self.norm2 = T5LayerNorm(dim, dtype=dtype)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme, dtype=dtype)
        if shared_pos:
            self.pos_embedding = None
        else:
            self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype)

    def forward(self, x, mask=None, pos_bias=None):
        if self.shared_pos:
            bias = pos_bias
        else:
            seq_len = x.size(1)
            bias = self.pos_embedding(seq_len, seq_len)
        attn_out = self.attn(self.norm1(x), mask=mask, pos_bias=bias)
        x = fp16_clamp(x + attn_out)
        ffn_out = self.ffn(self.norm2(x))
        return fp16_clamp(x + ffn_out)


class T5CrossAttention(nn.Module):
    """One pre-norm T5 decoder layer: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is residual, and each sum is passed through
    ``fp16_clamp``. When ``shared_pos`` is False the layer owns a
    unidirectional ``T5RelativeEmbedding``. Otherwise ``pos_bias`` must be
    supplied.
    """

    def __init__(
        self,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # Sub-layers, created in a fixed order (this also fixes RNG consumption).
        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        if shared_pos:
            self.pos_embedding = None
        else:
            self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)

    def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
        if self.shared_pos:
            bias = pos_bias
        else:
            seq_len = x.size(1)
            bias = self.pos_embedding(seq_len, seq_len)
        self_out = self.self_attn(self.norm1(x), mask=mask, pos_bias=bias)
        x = fp16_clamp(x + self_out)
        cross_out = self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)
        x = fp16_clamp(x + cross_out)
        ffn_out = self.ffn(self.norm3(x))
        return fp16_clamp(x + ffn_out)


class T5RelativeEmbedding(nn.Module):
    """Bucketed relative-position bias with one learned scalar per head.

    Args:
        num_buckets: number of embedding rows. When bidirectional, half are
            used for each sign of the offset.
        num_heads: embedding width, i.e. one bias value per attention head.
        bidirectional: whether positive and negative offsets get separate buckets.
        dtype: dtype of the embedding table.
        max_dist: distance at which the logarithmic buckets saturate.
    """

    def __init__(self, num_buckets, num_heads, bidirectional, dtype=torch.bfloat16, max_dist=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist
        self.embedding = nn.Embedding(num_buckets, num_heads, dtype=dtype)

    def forward(self, lq, lk):
        """Return a contiguous bias tensor of shape ``[1, num_heads, lq, lk]``."""
        dev = self.embedding.weight.device
        key_pos = torch.arange(lk, device=dev)[None, :]
        query_pos = torch.arange(lq, device=dev)[:, None]
        buckets = self._relative_position_bucket(key_pos - query_pos)
        # [lq, lk, heads] -> [1, heads, lq, lk]
        return self.embedding(buckets).permute(2, 0, 1).unsqueeze(0).contiguous()

    def _relative_position_bucket(self, rel_pos):
        """Map signed offsets (key index - query index) to embedding rows."""
        if self.bidirectional:
            n_buckets = self.num_buckets // 2
            # Positive offsets use the upper half of the table.
            offset = (rel_pos > 0).long() * n_buckets
            distance = rel_pos.abs()
        else:
            n_buckets = self.num_buckets
            offset = 0
            # Positive offsets collapse to distance 0.
            distance = (-rel_pos).clamp(min=0)

        exact_limit = n_buckets // 2
        # Distances >= exact_limit are binned logarithmically up to max_dist.
        log_ratio = torch.log(distance.float() / exact_limit) / math.log(self.max_dist / exact_limit)
        log_bucket = exact_limit + (log_ratio * (n_buckets - exact_limit)).long()
        log_bucket = log_bucket.clamp(max=n_buckets - 1)
        return offset + torch.where(distance < exact_limit, distance, log_bucket)


class T5Encoder(nn.Module):
    """T5 encoder stack with an optional block-wise weight-offload forward path.

    Args:
        dtype: dtype for the embeddings, the final norm and (non-offload) blocks.
        vocab: an ``nn.Embedding`` to reuse (converted with ``.to(dtype)``), or
            a vocabulary size for a new embedding table.
        dim, dim_attn, dim_ffn, num_heads, num_buckets: layer hyper-parameters.
        num_layers: number of encoder blocks.
        shared_pos: when True, one relative-position bias is computed once
            and shared by all blocks. Otherwise each block has its own.
        dropout: dropout probability.
        cpu_offload: when True, block weights are held in
            ``T5OffloadBlocksWeights``. During ``forward`` they are prefetched
            and swapped block by block through ``WeightAsyncStreamManager``.
        quantized, quant_scheme: quantized-linear selection for the resident
            blocks. ``quant_scheme`` is also the matmul weight type key of the
            offload blocks.
    """

    # Saturation distance of the log-spaced relative-position buckets used by
    # the offload path. It equals the ``max_dist`` default of ``T5RelativeEmbedding``.
    _OFFLOAD_REL_MAX_DIST = 128

    def __init__(
        self,
        dtype,
        vocab,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_layers,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
        cpu_offload=False,
        quantized=False,
        quant_scheme=None,
    ):
        super(T5Encoder, self).__init__()
        self.cpu_offload = cpu_offload
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos
        self.quant_scheme = quant_scheme

        # layers
        self.token_embedding = vocab.to(dtype) if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim, dtype=dtype)
        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype) if shared_pos else None
        self.dropout = nn.Dropout(dropout)

        if cpu_offload:
            self.weights_stream_mgr = WeightAsyncStreamManager(blocks_num=num_layers)
            self.blocks_weights = T5OffloadBlocksWeights(num_layers, quant_scheme)
            self.blocks = self.blocks_weights.blocks
        else:
            self.blocks = nn.ModuleList(
                [
                    T5SelfAttention(
                        dim,
                        dim_attn,
                        dim_ffn,
                        num_heads,
                        num_buckets,
                        shared_pos,
                        dropout,
                        quantized,
                        quant_scheme,
                        dtype,
                    )
                    for _ in range(num_layers)
                ]
            )

        self.norm = T5LayerNorm(dim, dtype=dtype)

    def forward_without_offload(self, ids, mask=None):
        """Encode ``ids`` with every block resident; returns output cast to ``GET_DTYPE()``."""
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None

        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x.to(GET_DTYPE())

    def forward_attn_with_offload(self, x, attn_phase, context=None, mask=None, pos_bias=None):
        """Attention sub-layer computed from the offloaded ``attn_q/k/v/o`` weights.

        Mirrors ``T5Attention.forward`` (no logit scaling, float32 softmax).
        NOTE(review): ``squeeze(0)`` only drops the batch dim when it is 1,
        so this path presumably supports batch size 1 only. TODO confirm.
        """
        context = x if context is None else context
        b, n, c = x.size(0), self.num_heads, self.dim_attn // self.num_heads
        # compute query, key, value
        q = attn_phase.attn_q.apply(x.squeeze(0)).view(b, -1, n, c)
        k = attn_phase.attn_k.apply(context.squeeze(0)).view(b, -1, n, c)
        v = attn_phase.attn_v.apply(context.squeeze(0)).view(b, -1, n, c)
        # attention bias
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # compute attention (T5 does not use scaling)
        attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        x = torch.einsum("bnij,bjnc->binc", attn, v)
        x = x.reshape(b, -1, n * c)
        x = attn_phase.attn_o.apply(x.squeeze(0)).unsqueeze(0)
        return x

    # Backward-compatible alias for the original, misspelled method name.
    forword_attn_with_offload = forward_attn_with_offload

    def forward_ffn_with_offload(self, x, ffn_phase):
        """Gated-GELU feed-forward computed from the offloaded ``ffn_*`` weights."""
        x = x.squeeze(0)
        x = ffn_phase.ffn_fc1.apply(x) * ffn_phase.gelu(ffn_phase.ffn_gate_0.apply(x))
        x = ffn_phase.ffn_fc2.apply(x)
        return x.unsqueeze(0)

    def _offload_pos_bias(self, pos_embedding, seq_len, device):
        """Build a block's own bidirectional relative-position bias in the offload path.

        Re-implements the bidirectional bucketing of ``T5RelativeEmbedding``
        on top of the offloaded embedding weight. The result is permuted to
        put heads first, presumably ``[1, heads, seq_len, seq_len]``.
        """
        rel_pos = torch.arange(seq_len, device=device).unsqueeze(0) - torch.arange(seq_len, device=device).unsqueeze(1)
        num_buckets = pos_embedding.weight.shape[0] // 2
        rel_buckets = (rel_pos > 0).long() * num_buckets
        rel_pos = torch.abs(rel_pos)
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / math.log(self._OFFLOAD_REL_MAX_DIST / max_exact) * (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return pos_embedding.apply(rel_buckets).permute(2, 0, 1).unsqueeze(0).contiguous()

    def forward_block_with_offload(self, block, x, mask=None, pos_bias=None):
        """Run one encoder block whose weights are laid out as ``T5OffloadSelfAttention``."""
        if self.shared_pos:
            e = pos_bias
        else:
            # Build the bias on the activations' device rather than a hard-coded "cuda".
            e = self._offload_pos_bias(block.pos_embedding, x.size(1), x.device)

        norm1_out = block.norm1.apply(x)
        x = fp16_clamp(x + self.forward_attn_with_offload(norm1_out, block.compute_phases[0], mask=mask, pos_bias=e))
        norm2_out = block.norm2.apply(x)
        x = fp16_clamp(x + self.forward_ffn_with_offload(norm2_out, block.compute_phases[1]))
        return x

    def forward_with_offload(self, ids, mask=None):
        """Encode ``ids`` while streaming block weights to the GPU one block at a time.

        The token/position embeddings and the final norm are moved to
        ``"cuda"`` and stay there. For each block, the next block's weights
        are prefetched, the current block runs on the manager's compute
        stream, and then ``swap_weights`` is called.
        Returns the output cast to ``GET_DTYPE()``.
        """
        self.token_embedding = self.token_embedding.to("cuda")
        self.pos_embedding = self.pos_embedding.to("cuda") if self.pos_embedding is not None else None

        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
        self.norm = self.norm.to("cuda")

        num_blocks = len(self.blocks)
        for block_idx in range(num_blocks):
            self.block_idx = block_idx
            if block_idx == 0:
                # Nothing has been prefetched for the first block; move it explicitly.
                self.weights_stream_mgr.active_weights[0] = self.blocks[0]
                self.weights_stream_mgr.active_weights[0].to_cuda()

            if block_idx < num_blocks - 1:
                self.weights_stream_mgr.prefetch_weights(block_idx + 1, self.blocks)

            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                x = self.forward_block_with_offload(self.blocks[block_idx], x, mask, pos_bias=e)
            self.weights_stream_mgr.swap_weights()

        x = self.norm(x)
        x = self.dropout(x)
        return x.to(GET_DTYPE())

    def forward(self, ids, mask=None):
        """Dispatch to the offload or the resident path according to ``cpu_offload``."""
        if self.cpu_offload:
            return self.forward_with_offload(ids, mask)
        return self.forward_without_offload(ids, mask)

helloyongyang's avatar
helloyongyang committed
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597

class T5Decoder(nn.Module):
    """T5 decoder stack: causal self-attention plus cross-attention in every layer.

    Args:
        vocab: an ``nn.Embedding`` to share, or a vocabulary size for a new one.
        dim, dim_attn, dim_ffn, num_heads, num_buckets: layer hyper-parameters.
        num_layers: number of ``T5CrossAttention`` layers.
        shared_pos: when True, one unidirectional relative-position bias is
            computed once and shared by all layers.
        dropout: dropout probability.
    """

    def __init__(
        self,
        vocab,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_layers,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        if isinstance(vocab, nn.Embedding):
            self.token_embedding = vocab
        else:
            self.token_embedding = nn.Embedding(vocab, dim)
        if shared_pos:
            self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
        else:
            self.pos_embedding = None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
            for _ in range(num_layers)
        )
        self.norm = T5LayerNorm(dim)

        self.apply(init_weights)

    def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
        """Decode 2-D token ``ids`` against ``encoder_states``.

        ``mask`` may be None (a full lower-triangular mask is built), a 2-D
        padding mask (expanded and made causal), or an explicit mask that is
        used as given.
        """
        _, seq_len = ids.size()

        if mask is None:
            mask = torch.tril(torch.ones(1, seq_len, seq_len).to(ids.device))
        elif mask.ndim == 2:
            mask = torch.tril(mask.unsqueeze(1).expand(-1, seq_len, -1))

        hidden = self.dropout(self.token_embedding(ids))
        bias = self.pos_embedding(hidden.size(1), hidden.size(1)) if self.shared_pos else None
        for layer in self.blocks:
            hidden = layer(hidden, mask, encoder_states, encoder_mask, pos_bias=bias)
        hidden = self.norm(hidden)
        return self.dropout(hidden)


class T5Model(nn.Module):
    """Full T5 encoder-decoder sharing one token embedding, with an LM head.

    Args:
        vocab_size: size of the shared embedding and of the output head.
        dim, dim_attn, dim_ffn, num_heads, num_buckets: layer hyper-parameters.
        encoder_layers, decoder_layers: number of encoder / decoder blocks.
        shared_pos: share one relative-position bias across blocks.
        dropout: dropout probability.
        dtype: parameter dtype of the whole model. ``_t5`` always passes
            ``dtype=`` to the model class, so it must be accepted here.
    """

    def __init__(
        self,
        vocab_size,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        encoder_layers,
        decoder_layers,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
        dtype=torch.float32,
    ):
        super(T5Model, self).__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.num_buckets = num_buckets

        # layers
        self.token_embedding = nn.Embedding(vocab_size, dim, dtype=dtype)
        # T5Encoder's first positional parameter is `dtype`. Passing by keyword
        # avoids shifting every argument by one, which made construction crash.
        self.encoder = T5Encoder(
            dtype=dtype,
            vocab=self.token_embedding,
            dim=dim,
            dim_attn=dim_attn,
            dim_ffn=dim_ffn,
            num_heads=num_heads,
            num_layers=encoder_layers,
            num_buckets=num_buckets,
            shared_pos=shared_pos,
            dropout=dropout,
        )
        self.decoder = T5Decoder(
            self.token_embedding,
            dim,
            dim_attn,
            dim_ffn,
            num_heads,
            decoder_layers,
            num_buckets,
            shared_pos,
            dropout,
        )
        self.head = nn.Linear(dim, vocab_size, bias=False)

        # initialize weights
        self.apply(init_weights)
        # Decoder submodules are built with their classes' own default dtypes
        # (float16/bfloat16 mixes). Unify everything so the layers compose.
        self.to(dtype)

    def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
        """Return LM-head logits for ``decoder_ids`` conditioned on ``encoder_ids``."""
        x = self.encoder(encoder_ids, encoder_mask)
        # The encoder casts its output to GET_DTYPE(); bring it back to the
        # model's parameter dtype before the decoder's cross-attention.
        x = x.to(self.head.weight.dtype)
        x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
        x = self.head(x)
        return x


def _t5(
    name,
    encoder_only=False,
    decoder_only=False,
    return_tokenizer=False,
    tokenizer_kwargs=None,
    dtype=torch.float32,
    device="cpu",
    **kwargs,
):
    """Build a T5 encoder, decoder or full model from config keyword arguments.

    Args:
        name: config name (not used by this function).
        encoder_only: build only a ``T5Encoder``. ``vocab_size`` and
            ``encoder_layers`` are renamed to ``vocab`` / ``num_layers``, and
            ``decoder_layers`` is dropped.
        decoder_only: build only a ``T5Decoder``, with the mirrored renaming.
        return_tokenizer, tokenizer_kwargs: not used by this function.
            ``tokenizer_kwargs`` now defaults to ``None`` instead of a shared
            mutable ``{}``.
        dtype: passed to the model constructor as ``dtype=``.
        device: used both as the construction device context and for the
            final ``.to(device=...)``.
        **kwargs: remaining constructor arguments.

    Returns:
        The constructed model.

    Raises:
        ValueError: if both ``encoder_only`` and ``decoder_only`` are set.
    """
    # sanity check (an assert would be stripped under `python -O`)
    if encoder_only and decoder_only:
        raise ValueError("encoder_only and decoder_only are mutually exclusive")

    # params
    if encoder_only:
        model_cls = T5Encoder
        kwargs["vocab"] = kwargs.pop("vocab_size")
        kwargs["num_layers"] = kwargs.pop("encoder_layers")
        kwargs.pop("decoder_layers")
    elif decoder_only:
        model_cls = T5Decoder
        kwargs["vocab"] = kwargs.pop("vocab_size")
        kwargs["num_layers"] = kwargs.pop("decoder_layers")
        kwargs.pop("encoder_layers")
    else:
        model_cls = T5Model

    # init model
    with torch.device(device):
        model = model_cls(dtype=dtype, **kwargs)

    # set device
    model = model.to(device=device)
    return model
helloyongyang's avatar
helloyongyang committed
721
722


723
724
725
726
727
728
729
730
731
def split_block_weights(weights):
    """Move every ``blocks.*`` entry out of ``weights`` into a new dict.

    ``weights`` is modified in place: the matching entries are removed. The
    removed entries are returned in their original iteration order.
    """
    block_keys = [key for key in weights if key.startswith("blocks.")]
    return {key: weights.pop(key) for key in block_keys}


helloyongyang's avatar
helloyongyang committed
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
def umt5_xxl(**kwargs):
    """Build the umt5-xxl variant through ``_t5``.

    Keyword arguments override the matching default hyperparameter below, or
    add new ones. The merged configuration is forwarded to ``_t5`` together
    with the variant name ``"umt5-xxl"``.
    """
    defaults = {
        "vocab_size": 256384,
        "dim": 4096,
        "dim_attn": 4096,
        "dim_ffn": 10240,
        "num_heads": 64,
        "encoder_layers": 24,
        "decoder_layers": 24,
        "num_buckets": 32,
        "shared_pos": False,
        "dropout": 0.1,
    }
    config = {**defaults, **kwargs}
    return _t5("umt5-xxl", **config)


class T5EncoderModel:
    """Text encoder wrapper: builds a umt5-xxl encoder, loads weights, tokenizes and encodes text."""

    def __init__(
        self,
        text_len,
        dtype=torch.bfloat16,
        device=None,
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
        cpu_offload=False,
        t5_quantized=False,
        t5_quantized_ckpt=None,
        quant_scheme=None,
        load_from_rank0=False,
    ):
        """Build the encoder, load its checkpoint and create the tokenizer.

        Args:
            text_len: Tokenizer sequence length (``seq_len``).
            dtype: Forwarded to the model constructor.
            device: Device for construction and, when ``shard_fn`` is None, for
                the final model. ``None`` (default) resolves to
                ``torch.cuda.current_device()`` at call time.
            checkpoint_path: Weights file used when quantized weights are not selected.
            tokenizer_path: Name/path passed to ``HuggingfaceTokenizer``.
            shard_fn: Optional callable ``shard_fn(model, sync_module_states=False)``
                returning the model to use. When given, the model is not moved to ``device`` here.
            cpu_offload: Forwarded to the model and ``load_weights``. When set,
                ``blocks.*`` weights are loaded through ``model.blocks_weights.load``.
            t5_quantized: Forwarded to the model as ``quantized``.
            t5_quantized_ckpt: Weights file used instead of ``checkpoint_path``
                when ``t5_quantized`` is true.
            quant_scheme: Forwarded to the model.
            load_from_rank0: Forwarded to ``load_weights``.
        """
        # Resolve the default lazily. Putting torch.cuda.current_device() in the
        # signature evaluated it once at import time, which fails on CPU-only hosts.
        if device is None:
            device = torch.cuda.current_device()
        self.text_len = text_len
        self.dtype = dtype
        self.device = device

        # Use the quantized checkpoint only if quantization is on and a path was given.
        if t5_quantized_ckpt is not None and t5_quantized:
            self.checkpoint_path = t5_quantized_ckpt
        else:
            self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # sync cpu offload
        self.cpu_offload = cpu_offload

        # Inference-only model: eval mode, no gradients.
        model = (
            umt5_xxl(
                encoder_only=True,
                return_tokenizer=False,
                dtype=dtype,
                device=device,
                cpu_offload=cpu_offload,
                quantized=t5_quantized,
                quant_scheme=quant_scheme,
            )
            .eval()
            .requires_grad_(False)
        )

        weights_dict = load_weights(
            self.checkpoint_path,
            cpu_offload=cpu_offload,
            load_from_rank0=load_from_rank0,
        )

        if cpu_offload:
            # "blocks.*" entries are removed from weights_dict and handed to the
            # offload-managed block weights; the rest go through load_state_dict below.
            block_weights_dict = split_block_weights(weights_dict)
            model.blocks_weights.load(block_weights_dict)
            del block_weights_dict
            gc.collect()

        model.load_state_dict(weights_dict)
        # Drop the loaded dict eagerly to reduce peak host memory.
        del weights_dict
        gc.collect()

        self.model = model
        if shard_fn is not None:
            self.model = shard_fn(self.model, sync_module_states=False)
        else:
            self.model.to(self.device)
        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")

    def infer(self, texts):
        """Encode ``texts`` and return one output slice per input, trimmed to its token count.

        Token ids and masks are always moved to the current CUDA device,
        regardless of ``self.device``. The per-sample length is the number of
        mask entries greater than 0.

        Args:
            texts: Input text(s) passed to the tokenizer.

        Returns:
            A list with ``context[i][:len_i]`` for each sample ``i``.
        """
        ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
        ids = ids.cuda()
        mask = mask.cuda()
        seq_lens = mask.gt(0).sum(dim=1).long()

        with torch.no_grad():
            context = self.model(ids, mask)

        return [u[:v] for u, v in zip(context, seq_lens)]


if __name__ == "__main__":
    import time

    # Smoke test / micro-benchmark: encode one prompt and log latency and output.
    checkpoint_dir = ""
    t5_checkpoint = "./models_t5_umt5-xxl-enc-bf16.pth"
    t5_tokenizer = "./google/umt5-xxl"

    cpu_offload = True
    # With CPU offload the model is constructed on CPU; otherwise directly on CUDA.
    device = torch.device("cpu") if cpu_offload else torch.device("cuda")

    model = T5EncoderModel(
        text_len=512,
        dtype=torch.bfloat16,
        device=device,
        checkpoint_path=os.path.join(checkpoint_dir, t5_checkpoint),
        tokenizer_path=os.path.join(checkpoint_dir, t5_tokenizer),
        shard_fn=None,
        cpu_offload=cpu_offload,
    )
    text = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."

    # Synchronize around the call so the interval covers queued GPU work.
    # perf_counter is monotonic and high-resolution, unlike the wall clock time.time().
    torch.cuda.synchronize()
    s_t = time.perf_counter()
    outputs = model.infer(text)
    torch.cuda.synchronize()
    e_t = time.perf_counter()

    logger.info(e_t - s_t)
    logger.info(outputs)