model.py 18.4 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
7
8
9
10
# Modified from transformers.models.t5.modeling_t5
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from .tokenizer import HuggingfaceTokenizer
root's avatar
root committed
11
from loguru import logger
12
13
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8

helloyongyang's avatar
helloyongyang committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

# Public names exported by ``from <module> import *``.
__all__ = ["T5Model", "T5Encoder", "T5Decoder", "T5EncoderModel"]


def fp16_clamp(x):
    if x.dtype == torch.float16 and torch.isinf(x).any():
        clamp = torch.finfo(x.dtype).max - 1000
        x = torch.clamp(x, min=-clamp, max=clamp)
    return x


def init_weights(m):
    """Initialise the parameters of one T5 sub-module in place.

    Meant to be used with ``nn.Module.apply``. Modules whose type is not
    handled below are left untouched.
    """
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
        return
    if isinstance(m, T5Model):
        nn.init.normal_(m.token_embedding.weight, std=1.0)
        return
    if isinstance(m, T5FeedForward):
        in_std = m.dim**-0.5
        nn.init.normal_(m.gate[0].weight, std=in_std)
        nn.init.normal_(m.fc1.weight, std=in_std)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
        return
    if isinstance(m, T5Attention):
        kv_std = m.dim**-0.5
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
        nn.init.normal_(m.k.weight, std=kv_std)
        nn.init.normal_(m.v.weight, std=kv_std)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
        return
    if isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)
helloyongyang's avatar
helloyongyang committed
46
47
48
49


class GELU(nn.Module):
    """GELU activation using the tanh approximation.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``.
    """

    def forward(self, x):
        cubic = x + 0.044715 * torch.pow(x, 3.0)
        squashed = torch.tanh(math.sqrt(2.0 / math.pi) * cubic)
        return 0.5 * x * (1.0 + squashed)
helloyongyang's avatar
helloyongyang committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67


class T5LayerNorm(nn.Module):
    """T5-style layer norm: divide by the root-mean-square over the last axis.

    There is no mean subtraction and no bias, only a learned per-feature scale.
    The statistic is computed in float32. If the scale is float16/bfloat16, the
    normalised tensor is cast to that dtype before scaling.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_sq = x.float().pow(2).mean(dim=-1, keepdim=True)
        normed = x * torch.rsqrt(mean_sq + self.eps)
        half_dtypes = (torch.float16, torch.bfloat16)
        if self.weight.dtype in half_dtypes:
            normed = normed.type_as(self.weight)
        return self.weight * normed


class T5Attention(nn.Module):
    """Multi-head attention as used by T5 (scores are not scaled by 1/sqrt(d)).

    Args:
        dim: input/output feature size.
        dim_attn: total width of the query/key/value projections; must be
            divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the projected output.
        quantized: if True, build the projections with a quantized linear class.
        quant_scheme: quantization scheme used when ``quantized`` is True; only
            ``"int8"`` is supported.

    Raises:
        ValueError: if ``quantized`` is True and ``quant_scheme`` is unsupported.
    """

    def __init__(self, dim, dim_attn, num_heads, dropout=0.1, quantized=False, quant_scheme=None):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # Pick the projection class. An unsupported scheme must fail loudly here
        # instead of leaving ``linear_cls`` unbound.
        if not quantized:
            linear_cls = nn.Linear
        elif quant_scheme == "int8":
            linear_cls = QuantLinearInt8
        else:
            raise ValueError(f"Unsupported T5 quant_scheme {quant_scheme!r}; supported: 'int8'")

        # layers
        self.q = linear_cls(dim, dim_attn, bias=False)
        self.k = linear_cls(dim, dim_attn, bias=False)
        self.v = linear_cls(dim, dim_attn, bias=False)
        self.o = linear_cls(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x:          [B, L1, C].
        context:    [B, L2, C] or None (self-attention over ``x``).
        mask:       [B, L2] or [B, L1, L2] or None; positions equal to 0 are masked out.
        pos_bias:   additive attention bias broadcastable to [B, N, L1, L2], or None.
        Returns a tensor of shape [B, L1, C].
        """
        # check inputs
        context = x if context is None else context
        b, n, c = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).view(b, -1, n, c)
        k = self.k(context).view(b, -1, n, c)
        v = self.v(context).view(b, -1, n, c)

        # attention bias
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # compute attention (T5 does not use scaling)
        attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
        # Softmax in float32 for stability, then cast to v's dtype so the second
        # einsum receives operands of matching dtype whatever the model dtype is
        # (a fixed bfloat16 cast only matched bfloat16 models).
        attn = F.softmax(attn.float(), dim=-1).to(v.dtype)
        x = torch.einsum("bnij,bjnc->binc", attn, v)

        # output
        x = x.reshape(b, -1, n * c)
        x = self.o(x)
        x = self.dropout(x)
        return x


class T5FeedForward(nn.Module):
    """Gated-GELU feed-forward block: ``fc2(fc1(x) * GELU(gate(x)))`` with dropout.

    Args:
        dim: input/output feature size.
        dim_ffn: hidden feature size.
        dropout: dropout probability (applied after the gating and after ``fc2``).
        quantized: if True, build the projections with a quantized linear class.
        quant_scheme: quantization scheme used when ``quantized`` is True; only
            ``"int8"`` is supported.

    Raises:
        ValueError: if ``quantized`` is True and ``quant_scheme`` is unsupported.
    """

    def __init__(self, dim, dim_ffn, dropout=0.1, quantized=False, quant_scheme=None):
        super(T5FeedForward, self).__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # Pick the projection class. An unsupported scheme must fail loudly here
        # instead of leaving ``linear_cls`` unbound.
        if not quantized:
            linear_cls = nn.Linear
        elif quant_scheme == "int8":
            linear_cls = QuantLinearInt8
        else:
            raise ValueError(f"Unsupported T5 quant_scheme {quant_scheme!r}; supported: 'int8'")

        # layers
        self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False), GELU())
        self.fc1 = linear_cls(dim, dim_ffn, bias=False)
        self.fc2 = linear_cls(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x) * self.gate(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class T5SelfAttention(nn.Module):
    """Pre-norm T5 encoder layer: self-attention, then a gated feed-forward.

    Each sub-layer output is added residually and passed through ``fp16_clamp``.
    With ``shared_pos=True`` the caller supplies ``pos_bias``. Otherwise the
    layer owns a bidirectional ``T5RelativeEmbedding`` and computes its own bias.
    """

    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1, quantized=False, quant_scheme=None):
        super().__init__()
        # hyper-parameters
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # sub-modules
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme)
        if shared_pos:
            self.pos_embedding = None
        else:
            self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)

    def forward(self, x, mask=None, pos_bias=None):
        if self.shared_pos:
            bias = pos_bias
        else:
            seq_len = x.size(1)
            bias = self.pos_embedding(seq_len, seq_len)
        attn_out = self.attn(self.norm1(x), mask=mask, pos_bias=bias)
        x = fp16_clamp(x + attn_out)
        ffn_out = self.ffn(self.norm2(x))
        return fp16_clamp(x + ffn_out)


class T5CrossAttention(nn.Module):
    """Pre-norm T5 decoder layer with three residual sub-layers.

    The sub-layers run in this order:

    1. self-attention over ``x``, masked by ``mask``;
    2. attention over ``encoder_states``, masked by ``encoder_mask``;
    3. a gated feed-forward.

    Each residual sum is passed through ``fp16_clamp``. With ``shared_pos=True``
    the caller supplies ``pos_bias``. Otherwise the layer owns a unidirectional
    ``T5RelativeEmbedding``.
    """

    def __init__(
        self,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super().__init__()
        # hyper-parameters
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # sub-modules
        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        if shared_pos:
            self.pos_embedding = None
        else:
            self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)

    def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
        if self.shared_pos:
            bias = pos_bias
        else:
            seq_len = x.size(1)
            bias = self.pos_embedding(seq_len, seq_len)
        self_out = self.self_attn(self.norm1(x), mask=mask, pos_bias=bias)
        x = fp16_clamp(x + self_out)
        cross_out = self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)
        x = fp16_clamp(x + cross_out)
        ffn_out = self.ffn(self.norm3(x))
        return fp16_clamp(x + ffn_out)


class T5RelativeEmbedding(nn.Module):
    """Learned per-head attention bias indexed by bucketed relative position.

    Args:
        num_buckets: number of position buckets (rows of the embedding table).
        num_heads: number of attention heads (columns of the embedding table).
        bidirectional: if True, the buckets are split between keys after the
            query and all other keys. If False, keys after the query all map
            to distance 0.
        max_dist: distance at which the log-spaced buckets saturate.
    """

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        """Return the bias for ``lq`` queries and ``lk`` keys as ``[1, N, Lq, Lk]``."""
        dev = self.embedding.weight.device
        # rel[i, j] = j - i: position of key j relative to query i
        key_pos = torch.arange(lk, device=dev)[None, :]
        query_pos = torch.arange(lq, device=dev)[:, None]
        buckets = self._relative_position_bucket(key_pos - query_pos)
        # [Lq, Lk, N] -> [1, N, Lq, Lk]
        return self.embedding(buckets).permute(2, 0, 1).unsqueeze(0).contiguous()

    def _relative_position_bucket(self, rel_pos):
        """Map integer relative positions to bucket indices in ``[0, num_buckets)``."""
        if self.bidirectional:
            # Buckets [half, 2*half) hold keys after the query, [0, half) the rest.
            n_buckets = self.num_buckets // 2
            offset = (rel_pos > 0).long() * n_buckets
            distance = rel_pos.abs()
        else:
            # Keys after the query collapse to distance 0.
            n_buckets = self.num_buckets
            offset = 0
            distance = torch.clamp(-rel_pos, min=0)

        # Distances below max_exact get one bucket each. Larger distances share
        # log-spaced buckets (scaled by max_dist), capped at the last bucket.
        # For distance 0 the log is -inf, but torch.where discards that lane.
        max_exact = n_buckets // 2
        log_ratio = torch.log(distance.float() / max_exact) / math.log(self.max_dist / max_exact)
        far = max_exact + (log_ratio * (n_buckets - max_exact)).long()
        far = far.clamp(max=n_buckets - 1)
        return offset + torch.where(distance < max_exact, distance, far)


class T5Encoder(nn.Module):
    """Stack of T5 self-attention layers mapping token ids to hidden states.

    Args:
        vocab: an existing ``nn.Embedding`` to reuse, or a vocabulary size for a
            new ``nn.Embedding(vocab, dim)``.
        dim, dim_attn, dim_ffn, num_heads, num_buckets: layer hyper-parameters
            forwarded to every ``T5SelfAttention``.
        num_layers: number of ``T5SelfAttention`` layers.
        shared_pos: if True, one relative position bias is computed per forward
            and passed to every layer. Otherwise each layer computes its own.
        dropout: dropout probability.
        cpu_offload: if True, ``forward`` moves each sub-module to CUDA just
            before it runs and back to CPU right after.
        quantized, quant_scheme: forwarded to the layers' linear projections.

    ``forward`` returns the hidden states cast to ``torch.bfloat16``.
    """

    def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1, cpu_offload=False, quantized=False, quant_scheme=None):
        super().__init__()
        self.cpu_offload = cpu_offload

        # hyper-parameters
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos
        self.quant_scheme = quant_scheme

        # layers
        if isinstance(vocab, nn.Embedding):
            self.token_embedding = vocab
        else:
            self.token_embedding = nn.Embedding(vocab, dim)
        if shared_pos:
            self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
        else:
            self.pos_embedding = None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme)
            for _ in range(num_layers)
        )
        self.norm = T5LayerNorm(dim)

        # NOTE: init_weights is deliberately not applied here. Parameters keep
        # PyTorch's default initialisation, presumably to be overwritten by a
        # loaded state dict.

    def forward(self, ids, mask=None):
        offload = self.cpu_offload

        # token embedding + dropout
        if offload:
            self.token_embedding = self.token_embedding.cuda()
        hidden = self.token_embedding(ids)
        if offload:
            self.token_embedding = self.token_embedding.cpu()
        hidden = self.dropout(hidden)

        # shared relative position bias (None when every layer owns its own)
        offload_pos = offload and self.pos_embedding is not None
        if offload_pos:
            self.pos_embedding = self.pos_embedding.cuda()
        seq_len = hidden.size(1)
        pos_bias = self.pos_embedding(seq_len, seq_len) if self.shared_pos else None
        if offload_pos:
            self.pos_embedding = self.pos_embedding.cpu()

        # layers; with offload, each one sits on the GPU only while it runs
        for layer in self.blocks:
            if offload:
                layer.cuda()
            hidden = layer(hidden, mask, pos_bias=pos_bias)
            if offload:
                layer.cpu()

        # final norm + dropout
        if offload:
            self.norm = self.norm.cuda()
        hidden = self.norm(hidden)
        if offload:
            self.norm = self.norm.cpu()
        hidden = self.dropout(hidden)
        return hidden.to(torch.bfloat16)
helloyongyang's avatar
helloyongyang committed
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323


class T5Decoder(nn.Module):
    """Stack of T5 cross-attention layers decoding token ids against encoder states.

    Args:
        vocab: an existing ``nn.Embedding`` to reuse, or a vocabulary size for a
            new ``nn.Embedding(vocab, dim)``.
        dim, dim_attn, dim_ffn, num_heads, num_buckets: layer hyper-parameters.
        num_layers: number of ``T5CrossAttention`` layers.
        shared_pos: if True, one unidirectional relative position bias is
            computed per forward and passed to every layer.
        dropout: dropout probability.
    """

    def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1):
        super().__init__()
        # hyper-parameters
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        if isinstance(vocab, nn.Embedding):
            self.token_embedding = vocab
        else:
            self.token_embedding = nn.Embedding(vocab, dim)
        if shared_pos:
            self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
        else:
            self.pos_embedding = None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
            for _ in range(num_layers)
        )
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
        # ``ids`` must be 2-D: [B, S]
        _, seq_len = ids.size()

        # Build a lower-triangular self-attention mask. It is all ones when
        # ``mask`` is None; a 2-D [B, S] mask is expanded to [B, S, S] first.
        if mask is None:
            mask = torch.ones(1, seq_len, seq_len).to(ids.device).tril()
        elif mask.ndim == 2:
            mask = mask.unsqueeze(1).expand(-1, seq_len, -1).tril()

        hidden = self.dropout(self.token_embedding(ids))
        length = hidden.size(1)
        pos_bias = self.pos_embedding(length, length) if self.shared_pos else None
        for layer in self.blocks:
            hidden = layer(hidden, mask, encoder_states, encoder_mask, pos_bias=pos_bias)
        return self.dropout(self.norm(hidden))


class T5Model(nn.Module):
    """Full T5 encoder-decoder with a linear vocabulary head.

    A single ``nn.Embedding`` is created here and passed to both the encoder
    and the decoder.
    """

    def __init__(self, vocab_size, dim, dim_attn, dim_ffn, num_heads, encoder_layers, decoder_layers, num_buckets, shared_pos=True, dropout=0.1):
        super().__init__()
        # hyper-parameters
        self.vocab_size = vocab_size
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.num_buckets = num_buckets

        # layers
        self.token_embedding = nn.Embedding(vocab_size, dim)
        widths = (dim, dim_attn, dim_ffn, num_heads)
        self.encoder = T5Encoder(self.token_embedding, *widths, encoder_layers, num_buckets, shared_pos, dropout)
        self.decoder = T5Decoder(self.token_embedding, *widths, decoder_layers, num_buckets, shared_pos, dropout)
        self.head = nn.Linear(dim, vocab_size, bias=False)

        # initialize weights
        self.apply(init_weights)

    def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
        encoded = self.encoder(encoder_ids, encoder_mask)
        decoded = self.decoder(decoder_ids, decoder_mask, encoded, encoder_mask)
        return self.head(decoded)


def _t5(
    name,
    encoder_only=False,
    decoder_only=False,
    return_tokenizer=False,
    tokenizer_kwargs={},
    dtype=torch.float32,
    device="cpu",
    **kwargs,
):
    # sanity check
    assert not (encoder_only and decoder_only)

    # params
    if encoder_only:
        model_cls = T5Encoder
        kwargs["vocab"] = kwargs.pop("vocab_size")
        kwargs["num_layers"] = kwargs.pop("encoder_layers")
        _ = kwargs.pop("decoder_layers")
    elif decoder_only:
        model_cls = T5Decoder
        kwargs["vocab"] = kwargs.pop("vocab_size")
        kwargs["num_layers"] = kwargs.pop("decoder_layers")
        _ = kwargs.pop("encoder_layers")
    else:
        model_cls = T5Model

    # init model
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)
446
    return model
helloyongyang's avatar
helloyongyang committed
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474


def umt5_xxl(**kwargs):
    """Build a UMT5-XXL model via ``_t5``; ``kwargs`` override the default config."""
    defaults = {
        "vocab_size": 256384,
        "dim": 4096,
        "dim_attn": 4096,
        "dim_ffn": 10240,
        "num_heads": 64,
        "encoder_layers": 24,
        "decoder_layers": 24,
        "num_buckets": 32,
        "shared_pos": False,
        "dropout": 0.1,
    }
    return _t5("umt5-xxl", **{**defaults, **kwargs})


class T5EncoderModel:
    """Inference wrapper around the UMT5-XXL text encoder and its tokenizer.

    Args:
        text_len: sequence length passed to the tokenizer (``seq_len``).
        dtype: dtype the encoder is built in.
        device: target device. Defaults to ``torch.cuda.current_device()``,
            resolved at construction time rather than at import time.
        checkpoint_path: state-dict file loaded with ``torch.load``.
        tokenizer_path: name/path handed to ``HuggingfaceTokenizer``.
        shard_fn: optional callable that wraps the model; when given, the model
            is not moved to ``device`` here.
        cpu_offload: enable CPU offloading.
        offload_granularity: ``"model"`` moves the whole encoder to CUDA only for
            the duration of ``infer``. ``"block"`` makes the encoder offload
            sub-module by sub-module.
        t5_quantized: build quantized linear layers.
        t5_quantized_ckpt: checkpoint used instead of ``checkpoint_path`` when
            ``t5_quantized`` is True.
        quant_scheme: quantization scheme forwarded to the encoder.
    """

    def __init__(
        self,
        text_len,
        dtype=torch.bfloat16,
        device=None,
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
        cpu_offload=False,
        offload_granularity="model",
        t5_quantized=False,
        t5_quantized_ckpt=None,
        quant_scheme=None,
    ):
        # Resolve the default lazily. A signature default of
        # torch.cuda.current_device() would run at import time and fail
        # on machines without CUDA.
        if device is None:
            device = torch.cuda.current_device()
        self.text_len = text_len
        self.dtype = dtype
        self.device = device
        if t5_quantized_ckpt is not None and t5_quantized:
            self.checkpoint_path = t5_quantized_ckpt
        else:
            self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path
        self.offload_granularity = offload_granularity

        # sync cpu offload
        self.cpu_offload = cpu_offload
        if self.cpu_offload:
            assert self.offload_granularity in ["block", "model"]

        model = (
            umt5_xxl(
                encoder_only=True,
                return_tokenizer=False,
                dtype=dtype,
                device=device,
                # block-level offload is handled inside the encoder itself
                cpu_offload=(cpu_offload if self.offload_granularity == "block" else False),
                quantized=t5_quantized,
                quant_scheme=quant_scheme,
            )
            .eval()
            .requires_grad_(False)
        )

        logger.info(f"Loading weights from {self.checkpoint_path}")

        model.load_state_dict(torch.load(self.checkpoint_path, map_location="cpu", weights_only=True))

        self.model = model
        if shard_fn is not None:
            self.model = shard_fn(self.model, sync_module_states=False)
        else:
            self.model.to(self.device)
        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")

    def to_cpu(self):
        """Move the wrapped model to CPU."""
        self.model = self.model.to("cpu")

    def to_cuda(self):
        """Move the wrapped model to the current CUDA device."""
        self.model = self.model.to("cuda")

    def infer(self, texts):
        """Encode ``texts`` and return one tensor per input.

        Each tensor is trimmed to the number of positions where the tokenizer
        mask is > 0. With model-level offload, the encoder is moved to CUDA for
        the call and always moved back to CPU afterwards, including when an
        error occurs.
        """
        offload_model = self.cpu_offload and self.offload_granularity == "model"
        if offload_model:
            self.to_cuda()
        try:
            ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
            ids = ids.cuda()
            mask = mask.cuda()
            seq_lens = mask.gt(0).sum(dim=1).long()
            context = self.model(ids, mask)
        finally:
            if offload_model:
                self.to_cpu()
        return [u[:v] for u, v in zip(context, seq_lens)]


if __name__ == "__main__":
    # Smoke run: encode one prompt with the bf16 UMT5-XXL encoder on CUDA.
    checkpoint_dir = ""  # directory containing the checkpoint and tokenizer
    ckpt_file = os.path.join(checkpoint_dir, "models_t5_umt5-xxl-enc-bf16.pth")
    tokenizer_dir = os.path.join(checkpoint_dir, "google/umt5-xxl")
    encoder = T5EncoderModel(
        text_len=512,
        dtype=torch.bfloat16,
        device=torch.device("cuda"),
        checkpoint_path=ckpt_file,
        tokenizer_path=tokenizer_dir,
        shard_fn=None,
    )
    prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
    outputs = encoder.infer(prompt)
    logger.info(outputs)