# Modified from transformers.models.t5.modeling_t5
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from .tokenizer import HuggingfaceTokenizer
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8, Q8FQuantLinearInt8, Q8FQuantLinearFp8


__all__ = [
    "T5Model",
    "T5Encoder",
    "T5Decoder",
    "T5EncoderModel",
]


def fp16_clamp(x):
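    """Clamp fp16 activations back into finite range if they overflowed to inf."""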
    if x.dtype == torch.float16 and torch.isinf(x).any():
        clamp = torch.finfo(x.dtype).max - 1000
        x = torch.clamp(x, min=-clamp, max=clamp)
    return x


def optimize_memory_usage():
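    """Release cached CUDA memory and trigger Python garbage collection."""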
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    import gc

    gc.collect()


def init_weights(m):
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
    elif isinstance(m, T5Model):
        nn.init.normal_(m.token_embedding.weight, std=1.0)
    elif isinstance(m, T5FeedForward):
        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
    elif isinstance(m, T5Attention):
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)


class GELU(nn.Module):
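    """Tanh approximation of the GELU activation."""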
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


class T5LayerNorm(nn.Module):
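    """T5-style RMS layer norm: scale-only, with no mean subtraction and no bias."""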
    def __init__(self, dim, eps=1e-6, dtype=torch.float16):
        super(T5LayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=dtype))

    def forward(self, x):
        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            x = x.type_as(self.weight)
        return self.weight * x


class T5Attention(nn.Module):
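    """Multi-head attention with optional T5 relative-position bias.

    If ``quantized`` is True, the q/k/v/o projections use the quantized
    linear layer selected by ``quant_scheme``; otherwise plain ``nn.Linear``.
    """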
    def __init__(self, dim, dim_attn, num_heads, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        if quantized:
            if quant_scheme == "int8":
                linear_cls = VllmQuantLinearInt8
            elif quant_scheme == "fp8":
                linear_cls = VllmQuantLinearFp8
            elif quant_scheme == "int8-torchao":
                linear_cls = TorchaoQuantLinearInt8
            elif quant_scheme == "int8-q8f":
                linear_cls = Q8FQuantLinearInt8
            elif quant_scheme == "fp8-q8f":
                linear_cls = Q8FQuantLinearFp8
            else:
                raise ValueError(f"Unsupported quant_scheme: {quant_scheme}")
        else:
            linear_cls = nn.Linear

        # layers
        self.q = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
        self.k = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
        self.v = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
        self.o = linear_cls(dim_attn, dim, bias=False, dtype=dtype)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x:          [B, L1, C].
        context:    [B, L2, C] or None.
        mask:       [B, L2] or [B, L1, L2] or None.
        """
        # check inputs
        context = x if context is None else context
        b, n, c = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).view(b, -1, n, c)
        k = self.k(context).view(b, -1, n, c)
        v = self.v(context).view(b, -1, n, c)

        # attention bias
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # compute attention (T5 does not use scaling)
        attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias

        if hasattr(self, "cpu_offload") and self.cpu_offload:
            del attn_bias
        attn = F.softmax(attn.float(), dim=-1).to(torch.bfloat16)
        x = torch.einsum("bnij,bjnc->binc", attn, v)

        if hasattr(self, "cpu_offload") and self.cpu_offload:
            del attn
        x = x.reshape(b, -1, n * c)
        x = self.o(x)
        x = self.dropout(x)
        return x


class T5FeedForward(nn.Module):
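    """Gated feed-forward block: projects up with ``fc1(x) * gate(x)`` (GELU gate) and back down with ``fc2``."""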
    def __init__(self, dim, dim_ffn, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
        super(T5FeedForward, self).__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        if quantized:
            if quant_scheme == "int8":
                linear_cls = VllmQuantLinearInt8
            elif quant_scheme == "fp8":
                linear_cls = VllmQuantLinearFp8
            elif quant_scheme == "int8-torchao":
                linear_cls = TorchaoQuantLinearInt8
            elif quant_scheme == "int8-q8f":
                linear_cls = Q8FQuantLinearInt8
            elif quant_scheme == "fp8-q8f":
                linear_cls = Q8FQuantLinearFp8
            else:
                raise ValueError(f"Unsupported quant_scheme: {quant_scheme}")
        else:
            linear_cls = nn.Linear
        # layers
        self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False, dtype=dtype), GELU())
        self.fc1 = linear_cls(dim, dim_ffn, bias=False, dtype=dtype)
        self.fc2 = linear_cls(dim_ffn, dim, bias=False, dtype=dtype)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        if hasattr(self, "cpu_offload") and self.cpu_offload:
            gate_out = self.gate(x)
            fc1_out = self.fc1(x)
            x = fc1_out * gate_out
            del gate_out, fc1_out
        else:
            x = self.fc1(x) * self.gate(x)

        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class T5SelfAttention(nn.Module):
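    """Encoder block: pre-norm self-attention and feed-forward, each with a residual connection and fp16 clamping."""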
    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
        super(T5SelfAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim, dtype=dtype)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme, dtype)
        self.norm2 = T5LayerNorm(dim, dtype=dtype)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme, dtype=dtype)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype)

    def forward(self, x, mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))

        if hasattr(self, "cpu_offload") and self.cpu_offload:
            attn_out = self.attn(self.norm1(x), mask=mask, pos_bias=e)
            x = fp16_clamp(x + attn_out)
            del attn_out

            ffn_out = self.ffn(self.norm2(x))
            x = fp16_clamp(x + ffn_out)
            del ffn_out
        else:
            x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
            x = fp16_clamp(x + self.ffn(self.norm2(x)))

        return x


class T5CrossAttention(nn.Module):
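    """Decoder block: pre-norm self-attention, cross-attention over encoder states, and feed-forward, each with a residual connection."""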
    def __init__(
        self,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5CrossAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)

    def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask))
        x = fp16_clamp(x + self.ffn(self.norm3(x)))
        return x


class T5RelativeEmbedding(nn.Module):
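    """Bucketed relative-position bias, returned as a per-head tensor of shape [1, N, Lq, Lk]."""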
    def __init__(self, num_buckets, num_heads, bidirectional, dtype=torch.bfloat16, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads, dtype=dtype)

    def forward(self, lq, lk):
        device = self.embedding.weight.device
        # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
        #     torch.arange(lq).unsqueeze(1).to(device)
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0)  # [1, N, Lq, Lk]
        return rel_pos_embeds.contiguous()

    def _relative_position_bucket(self, rel_pos):
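        """Map signed relative positions to bucket indices.

        Small offsets get their own buckets; larger offsets share
        log-spaced buckets up to ``max_dist`` (T5-style bucketing).
        """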
        # preprocess
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # embeddings for small and large positions
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets


class T5Encoder(nn.Module):
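    """Stack of T5 self-attention blocks that produces text embeddings.

    With ``cpu_offload=True``, the embedding, each block, and the final norm
    are moved to the GPU only while they run and back to CPU afterwards to
    keep peak GPU memory low.
    """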
    def __init__(self, dtype, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1, cpu_offload=False, quantized=False, quant_scheme=None):
        super(T5Encoder, self).__init__()

        self.cpu_offload = cpu_offload
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos
        self.quant_scheme = quant_scheme

        # layers
        self.token_embedding = vocab.to(dtype) if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim, dtype=dtype)
        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme, dtype) for _ in range(num_layers)])

        if cpu_offload:
            for block in self.blocks:
                block.cpu_offload = cpu_offload
                block.attn.cpu_offload = cpu_offload
                block.ffn.cpu_offload = cpu_offload
        self.norm = T5LayerNorm(dim, dtype=dtype)

        # initialize weights
        # self.apply(init_weights)

    def forward(self, ids, mask=None):
        if self.cpu_offload:
            self.token_embedding = self.token_embedding.cuda()
        x = self.token_embedding(ids)
        if self.cpu_offload:
            self.token_embedding = self.token_embedding.cpu()
            optimize_memory_usage()
        x = self.dropout(x)

        if self.cpu_offload and self.pos_embedding is not None:
            self.pos_embedding = self.pos_embedding.cuda()
        e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
        if self.cpu_offload and self.pos_embedding is not None:
            self.pos_embedding = self.pos_embedding.cpu()
            optimize_memory_usage()

        for block in self.blocks:
            if self.cpu_offload:
                block = block.cuda()
            x = block(x, mask, pos_bias=e)
            if self.cpu_offload:
                block = block.cpu()
                del block
                optimize_memory_usage()

        if self.cpu_offload:
            self.norm = self.norm.cuda()
        x = self.norm(x)
        if self.cpu_offload:
            self.norm = self.norm.cpu()
            optimize_memory_usage()

        x = self.dropout(x)
        return x.to(torch.bfloat16)


class T5Decoder(nn.Module):
    def __init__(
        self,
        vocab,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_layers,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5Decoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)])
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
        b, s = ids.size()

        # causal mask
        if mask is None:
            mask = torch.tril(torch.ones(1, s, s).to(ids.device))
        elif mask.ndim == 2:
            mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))

        # layers
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Model(nn.Module):
    def __init__(
        self,
        vocab_size,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        encoder_layers,
        decoder_layers,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5Model, self).__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.num_buckets = num_buckets

        # layers
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.encoder = T5Encoder(
            self.token_embedding,
            dim,
            dim_attn,
            dim_ffn,
            num_heads,
            encoder_layers,
            num_buckets,
            shared_pos,
            dropout,
        )
        self.decoder = T5Decoder(
            self.token_embedding,
            dim,
            dim_attn,
            dim_ffn,
            num_heads,
            decoder_layers,
            num_buckets,
            shared_pos,
            dropout,
        )
        self.head = nn.Linear(dim, vocab_size, bias=False)

        # initialize weights
        self.apply(init_weights)

    def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
        x = self.encoder(encoder_ids, encoder_mask)
        x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
        x = self.head(x)
        return x


def _t5(
    name,
    encoder_only=False,
    decoder_only=False,
    return_tokenizer=False,
    tokenizer_kwargs={},
    dtype=torch.float32,
    device="cpu",
    **kwargs,
):
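    """Build a T5Encoder, T5Decoder, or full T5Model from a shared config dict."""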
    # sanity check
    assert not (encoder_only and decoder_only)

    # params
    if encoder_only:
        model_cls = T5Encoder
        kwargs["vocab"] = kwargs.pop("vocab_size")
        kwargs["num_layers"] = kwargs.pop("encoder_layers")
        _ = kwargs.pop("decoder_layers")
    elif decoder_only:
        model_cls = T5Decoder
        kwargs["vocab"] = kwargs.pop("vocab_size")
        kwargs["num_layers"] = kwargs.pop("decoder_layers")
        _ = kwargs.pop("encoder_layers")
    else:
        model_cls = T5Model

    # init model
    with torch.device(device):
        model = model_cls(dtype=dtype, **kwargs)

    # set device
    model = model.to(device=device)
    return model


def umt5_xxl(**kwargs):
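    """UMT5-XXL configuration; keyword arguments override the defaults below."""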
    cfg = dict(
        vocab_size=256384,
        dim=4096,
        dim_attn=4096,
        dim_ffn=10240,
        num_heads=64,
        encoder_layers=24,
        decoder_layers=24,
        num_buckets=32,
        shared_pos=False,
        dropout=0.1,
    )
    cfg.update(**kwargs)
    return _t5("umt5-xxl", **cfg)


class T5EncoderModel:
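    """Wrapper pairing the UMT5-XXL text encoder with its tokenizer.

    Loads weights from ``checkpoint_path`` (or a quantized checkpoint when
    ``t5_quantized`` is set) and supports CPU offload at ``"model"`` or
    ``"block"`` granularity. ``infer`` returns one embedding tensor per input
    text, truncated to its valid token length.
    """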
    def __init__(
        self,
        text_len,
        dtype=torch.bfloat16,
        device=torch.cuda.current_device(),
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
        cpu_offload=False,
        offload_granularity="model",
        t5_quantized=False,
        t5_quantized_ckpt=None,
        quant_scheme=None,
    ):
        self.text_len = text_len
        self.dtype = dtype
        self.device = device
        if t5_quantized_ckpt is not None and t5_quantized:
            self.checkpoint_path = t5_quantized_ckpt
        else:
            self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path
        self.offload_granularity = offload_granularity

        # sync cpu offload
        self.cpu_offload = cpu_offload
        if self.cpu_offload:
            assert self.offload_granularity in ["block", "model"]

        model = (
            umt5_xxl(
                encoder_only=True,
                return_tokenizer=False,
                dtype=dtype,
                device=device,
                cpu_offload=(cpu_offload if self.offload_granularity == "block" else False),
                quantized=t5_quantized,
                quant_scheme=quant_scheme,
            )
            .eval()
            .requires_grad_(False)
        )

        logger.info(f"Start Loading weights from {self.checkpoint_path}")
        model.load_state_dict(torch.load(self.checkpoint_path, map_location="cpu", weights_only=True))
        logger.info(f"End Loading weights from {self.checkpoint_path}")

        self.model = model
        if shard_fn is not None:
            self.model = shard_fn(self.model, sync_module_states=False)
        else:
            self.model.to(self.device)
        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")

    def to_cpu(self):
        self.model = self.model.to("cpu")

    def to_cuda(self):
        self.model = self.model.to("cuda")

    def optimize_memory(self):
        """优化内存使用"""
        optimize_memory_usage()

    def infer(self, texts):
        if self.cpu_offload and self.offload_granularity == "model":
            self.to_cuda()

        ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
        ids = ids.cuda()
        mask = mask.cuda()
        seq_lens = mask.gt(0).sum(dim=1).long()

        with torch.no_grad():
            context = self.model(ids, mask)

        if self.cpu_offload and self.offload_granularity == "model":
            self.to_cpu()
            optimize_memory_usage()

        del ids, mask
        if self.cpu_offload:
            optimize_memory_usage()

        return [u[:v] for u, v in zip(context, seq_lens)]


if __name__ == "__main__":
    checkpoint_dir = ""
    t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth"
    t5_tokenizer = "google/umt5-xxl"
    model = T5EncoderModel(
        text_len=512,
        dtype=torch.bfloat16,
        device=torch.device("cuda"),
        checkpoint_path=os.path.join(checkpoint_dir, t5_checkpoint),
        tokenizer_path=os.path.join(checkpoint_dir, t5_tokenizer),
        shard_fn=None,
    )
    text = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
    outputs = model.infer(text)
    logger.info(outputs)