audio_adapter.py 18.7 KB
Newer Older
gushiqiao's avatar
gushiqiao committed
1
2
3
4
try:
    import flash_attn
except ModuleNotFoundError:
    flash_attn = None
PengGao's avatar
PengGao committed
5
import math
gushiqiao's avatar
gushiqiao committed
6
import os
PengGao's avatar
PengGao committed
7

gushiqiao's avatar
gushiqiao committed
8
import safetensors
wangshankun's avatar
wangshankun committed
9
import torch
10
import torch.distributed as dist
wangshankun's avatar
wangshankun committed
11
12
13
14
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
15
from loguru import logger
wangshankun's avatar
wangshankun committed
16
17
from transformers import AutoModel

18
19
from lightx2v.utils.envs import *

wangshankun's avatar
wangshankun committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55

def load_safetensors(in_path: str):
    """Load every tensor stored at ``in_path`` into a dict.

    ``in_path`` may be a directory (all of its ``.safetensors`` shards are
    merged) or a single file.

    Raises:
        ValueError: if ``in_path`` is neither an existing directory nor file.
    """
    if os.path.isdir(in_path):
        loader = load_safetensors_from_dir
    elif os.path.isfile(in_path):
        loader = load_safetensors_from_path
    else:
        raise ValueError(f"{in_path} does not exist")
    return loader(in_path)


def load_safetensors_from_path(in_path: str):
    """Read all tensors from one ``.safetensors`` file onto the CPU.

    Returns:
        dict mapping each tensor name stored in the file to its tensor.
    """
    with safetensors.safe_open(in_path, framework="pt", device="cpu") as handle:
        return {name: handle.get_tensor(name) for name in handle.keys()}


def load_safetensors_from_dir(in_dir: str):
    """Merge the tensors of every ``.safetensors`` file directly inside ``in_dir``.

    Shards are visited in sorted filename order, so when two shards define the
    same key the later file deterministically wins. Without sorting, the result
    would depend on the arbitrary order of ``os.listdir``. Entries that do not
    end in ``.safetensors`` are ignored. An empty dict is returned when no shard
    is present.
    """
    tensors = {}
    # Named ``shard_names`` so the ``safetensors`` module is not shadowed.
    shard_names = sorted(name for name in os.listdir(in_dir) if name.endswith(".safetensors"))
    for name in shard_names:
        tensors.update(load_safetensors_from_path(os.path.join(in_dir, name)))
    return tensors


def load_pt_safetensors(in_path: str):
    """Load a state dict from a PyTorch checkpoint or from safetensors.

    Paths whose extension is ``.pt``, ``.pth`` or ``.tar`` are read with
    ``torch.load(..., map_location="cpu", weights_only=True)``. The extension
    comparison is case-insensitive. Any other path (a ``.safetensors`` file or
    a directory of shards) is handed to ``load_safetensors``.

    Returns:
        dict mapping parameter names to tensors.
    """
    ext = os.path.splitext(in_path)[-1].lower()
    if ext in (".pt", ".pth", ".tar"):
        return torch.load(in_path, map_location="cpu", weights_only=True)
    return load_safetensors(in_path)


56
def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
    """Load checkpoint weights on one process and broadcast them to all others.

    The model is first moved to ``"cuda"``.

    - If ``torch.distributed`` is not initialized, this process simply loads
      the checkpoint.
    - Otherwise only global rank 0 reads ``in_path``. After a barrier, every
      parameter and buffer is broadcast from rank 0, so all ranks end up with
      identical weights.

    NOTE(review): if loading raises on rank 0, it never reaches the barrier
    and the remaining ranks will wait there.

    Args:
        model: module to populate.
        in_path: checkpoint path accepted by ``load_pt_safetensors``.
        strict: forwarded to ``model.load_state_dict``.

    Returns:
        The model, cast to ``GET_DTYPE()``.
    """
    model = model.to("cuda")
    distributed = dist.is_initialized()
    # The leader (rank 0, or the only process) is the one that reads the checkpoint.
    is_leader = not distributed or dist.get_rank() == 0

    if is_leader:
        logger.info(f"Loading model state from {in_path}")
        state_dict = load_pt_safetensors(in_path)
        model.load_state_dict(state_dict, strict=strict)

    # Synchronize the leader's weights to every other process.
    if distributed:
        dist.barrier(device_ids=[torch.cuda.current_device()])
        for param in model.parameters():
            dist.broadcast(param.data, src=0)
        for buffer in model.buffers():
            dist.broadcast(buffer.data, src=0)

    return model.to(dtype=GET_DTYPE())
wangshankun's avatar
wangshankun committed
88
89
90
91
92
93
94
95
96


def linear_interpolation(features, output_len: int):
    """Linearly resample ``features`` along dim 1 to ``output_len`` steps.

    The input is laid out as (batch, length, channels). ``F.interpolate`` in
    ``"linear"`` mode works on (batch, channels, length), so the tensor is
    transposed on the way in and back on the way out.
    """
    channels_first = features.transpose(1, 2)
    resampled = F.interpolate(channels_first, size=output_len, mode="linear", align_corners=False)
    return resampled.transpose(1, 2)


def get_q_lens_audio_range(
    batchsize: int,
    n_tokens_per_rank: int,
    n_query_tokens: int,
    n_tokens_per_frame: int,
    sp_rank: int,
):
    """Split this rank's query tokens into chunks cut at frame boundaries.

    Rank ``sp_rank`` owns the token slice that starts at global index
    ``n_tokens_per_rank * sp_rank``. Only the first ``n_query_tokens`` tokens
    of that slice are used. They are cut every ``n_tokens_per_frame`` tokens,
    measured globally, giving one chunk per (possibly partial) frame.

    Returns:
        ``(q_lens, t0, t1)``:

        - ``q_lens`` lists the chunk lengths, repeated once per batch element.
        - ``[t0, t1)`` is the range of frame indices the chunks cover.
        - When ``n_query_tokens == 0``, ``q_lens`` is a single length-1 dummy
          chunk per batch element and the range is ``[0, 1)``.

    Example:
        >>> get_q_lens_audio_range(1, 10, 10, 4, 1)
        ([2, 4, 4], 2, 5)
    """
    if n_query_tokens == 0:
        # Dummy chunk so that a rank without query tokens still has something to attend with.
        q_lens = [1] * batchsize
        return q_lens, 0, 1

    idx0 = n_tokens_per_rank * sp_rank
    # Tokens remaining in the frame this rank's slice starts inside, capped by the query count.
    first_length = min(n_tokens_per_frame - idx0 % n_tokens_per_frame, n_query_tokens)
    n_frames = (n_query_tokens - first_length) // n_tokens_per_frame
    last_length = n_query_tokens - n_frames * n_tokens_per_frame - first_length

    q_lens = []
    if first_length > 0:
        q_lens.append(first_length)
    q_lens += [n_tokens_per_frame] * n_frames
    if last_length > 0:
        q_lens.append(last_length)

    t0 = idx0 // n_tokens_per_frame
    t1 = t0 + len(q_lens)
    return q_lens * batchsize, t0, t1


class PerceiverAttentionCA(nn.Module):
    """Cross-attention from latent hidden states (queries) to audio tokens (keys/values).

    Query modulation depends on ``adaLN``:

    - ``adaLN=True``: the query LayerNorm has no affine parameters. A learnable
      ``shift_scale_gate`` offset is added to the timestep embedding and split
      into shift / scale / gate.
    - ``adaLN=False``: a fixed non-persistent buffer with shift=0, scale=0 and
      gate=1 is used instead.

    Attention is computed with ``flash_attn.flash_attn_varlen_func`` over
    packed variable-length segments.
    """

    def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False):
        super().__init__()
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads
        kv_dim = inner_dim if kv_dim is None else kv_dim
        self.norm_kv = nn.LayerNorm(kv_dim)
        self.norm_q = nn.LayerNorm(inner_dim, elementwise_affine=not adaLN)

        self.to_q = nn.Linear(inner_dim, inner_dim)
        self.to_kv = nn.Linear(kv_dim, inner_dim * 2)
        self.to_out = nn.Linear(inner_dim, inner_dim)
        if adaLN:
            # Learnable [shift, scale, gate] offsets added to the timestep embedding.
            self.shift_scale_gate = nn.Parameter(torch.randn(1, 3, inner_dim) / inner_dim**0.5)
        else:
            # Identity modulation: shift=0, scale=0 (x * (1 + 0)), gate=1.
            shift_scale_gate = torch.zeros((1, 3, inner_dim))
            shift_scale_gate[:, 2] = 1
            self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False)

    def forward(self, x, latents, t_emb, q_lens, k_lens):
        """Attend from ``latents`` to the audio tokens ``x``.

        Args:
            x: audio tokens of shape (batchsize, latent_frame,
                audio_tokens_per_latent, kv_dim).
            latents: query hidden states of shape (batchsize, length, inner_dim).
            t_emb: timestep modulation. It is added to ``shift_scale_gate``
                (1, 3, inner_dim) and split into shift / scale / gate along dim 1.
            q_lens: int32 tensor of packed query segment lengths.
            k_lens: int32 tensor of packed key/value segment lengths, one per
                query segment.

        Returns:
            Gated attention output of shape (batchsize, length, inner_dim).

        Raises:
            ImportError: if the optional ``flash_attn`` package is not installed.
        """
        if flash_attn is None:
            # The module-level import tolerates a missing flash_attn; fail here with a clear message.
            raise ImportError("PerceiverAttentionCA requires the 'flash_attn' package, which is not installed.")
        batchsize = len(x)
        x = self.norm_kv(x)
        shift, scale, gate = (t_emb + self.shift_scale_gate).chunk(3, dim=1)
        norm_q = self.norm_q(latents)
        if scale.shape[0] != norm_q.shape[0]:
            # Modulation whose leading axis is not the batch axis (presumably per-token
            # timestep embeddings) is transposed to line up with the queries,
            # e.g. to (1, 5070, 3072).
            scale = scale.transpose(0, 1)
            shift = shift.transpose(0, 1)
            gate = gate.transpose(0, 1)
        latents = norm_q * (1 + scale) + shift
        q = self.to_q(latents.to(GET_DTYPE()))
        k, v = self.to_kv(x).chunk(2, dim=-1)
        # Pack the batch into one token axis, as flash_attn's varlen kernel expects.
        q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads)
        k = rearrange(k, "B T L (H C) -> (B T L) H C", H=self.heads)
        v = rearrange(v, "B T L (H C) -> (B T L) H C", H=self.heads)
        out = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            # Segment offsets: exclusive prefix sums of the segment lengths.
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
            max_seqlen_q=q_lens.max().item(),
            max_seqlen_k=k_lens.max().item(),
            dropout_p=0.0,
            softmax_scale=None,
            causal=False,
            window_size=(-1, -1),
            deterministic=False,
        )
        out = rearrange(out, "(B L) H C -> B L (H C)", B=batchsize)
        return self.to_out(out) * gate


class AudioProjection(nn.Module):
    """Turn audio-encoder features into ``num_tokens`` tokens per video frame.

    ``forward`` proceeds in five steps:

    1. Resample the features to the video frame count.
    2. Optionally refine them with a ``TransformerDecoder`` whose memory is
       the original-rate features.
    3. Stack each frame with ``left`` / ``right`` neighbouring frames.
    4. Apply an MLP.
    5. Split the result into ``num_tokens`` tokens and layer-normalize them.
    """

    def __init__(
        self,
        audio_feature_dim: int = 768,
        n_neighbors: tuple = (2, 2),
        num_tokens: int = 32,
        mlp_dims: tuple = (1024, 1024, 32 * 768),
        transformer_layers: int = 4,
    ):
        super().__init__()
        self.left, self.right = n_neighbors
        # Window size: the frame itself plus its neighbours on both sides.
        self.audio_frames = sum(n_neighbors) + 1

        # Linear layers with a ReLU after every layer except the last.
        widths = [audio_feature_dim * self.audio_frames, *mlp_dims]
        last_idx = len(mlp_dims) - 1
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if idx < last_idx:
                layers.append(nn.ReLU())
        self.mlp = nn.Sequential(*layers)

        self.norm = nn.LayerNorm(mlp_dims[-1] // num_tokens)
        self.num_tokens = num_tokens

        if transformer_layers <= 0:
            self.transformer_decoder = None
        else:
            layer = nn.TransformerDecoderLayer(
                d_model=audio_feature_dim,
                nhead=audio_feature_dim // 64,
                dim_feedforward=4 * audio_feature_dim,
                dropout=0.0,
                batch_first=True,
            )
            self.transformer_decoder = nn.TransformerDecoder(layer, num_layers=transformer_layers)

    def forward(self, audio_feature, latent_frame):
        """Return (B, video_frame, num_tokens, C) audio tokens for ``latent_frame`` latents."""
        # Each latent after the first spans 4 video frames.
        video_frame = (latent_frame - 1) * 4 + 1
        source = audio_feature
        resampled = linear_interpolation(source, video_frame)
        if self.transformer_decoder is not None:
            # Target: resampled features; memory: the original-rate features.
            resampled = self.transformer_decoder(resampled, source)
        # Replicate-pad the time axis so every frame has a complete neighbour window.
        padded = F.pad(resampled, pad=(0, 0, self.left, self.right), mode="replicate")
        windows = padded.unfold(dimension=1, size=self.audio_frames, step=1)
        flat = rearrange(windows, "B T C W -> B T (W C)")
        projected = self.mlp(flat)  # (B, video_frame, C)
        tokens = rearrange(projected, "B T (N C) -> B T N C", N=self.num_tokens)
        return self.norm(tokens)


class TimeEmbedding(nn.Module):
    """Map diffusion timesteps to the modulation vector used by the audio cross-attention.

    The pipeline is:

    1. ``timesteps_proj`` (diffusers ``Timesteps``) produces features of width
       ``time_freq_dim``.
    2. ``time_embedder`` (diffusers ``TimestepEmbedding``) maps them to ``dim``.
    3. ``time_proj`` (applied after SiLU) maps to ``time_proj_dim``.
    """

    def __init__(self, dim, time_freq_dim, time_proj_dim):
        super().__init__()
        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)

    def forward(self, timestep: torch.Tensor):
        """Embed ``timestep`` and project it to ``time_proj_dim``.

        Args:
            timestep: 1-D tensor of timesteps, or a 2-D tensor whose leading
                dim is squeezed before projection and restored afterwards
                (assumed to be size 1 -- TODO confirm with callers).

        Returns:
            The projected embedding. A 3-D result has its leading dim
            squeezed (a no-op unless that dim has size 1).
        """
        # Project timestep
        if timestep.dim() == 2:
            timestep = self.timesteps_proj(timestep.squeeze(0)).unsqueeze(0)
        else:
            timestep = self.timesteps_proj(timestep)

        # Match the time_embedder's parameter dtype, except when that dtype is int8
        # (presumably a quantized embedder that must not receive int8 inputs).
        target_dtype = next(self.time_embedder.parameters()).dtype
        if timestep.dtype != target_dtype and target_dtype != torch.int8:
            timestep = timestep.to(target_dtype)

        # Time embedding projection
        temb = self.time_embedder(timestep)
        timestep_proj = self.time_proj(self.act_fn(temb))

        return timestep_proj.squeeze(0) if timestep_proj.dim() == 3 else timestep_proj
wangshankun's avatar
wangshankun committed
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300


class AudioAdapter(nn.Module):
    """Audio conditioning for a video transformer via per-layer cross-attention.

    ``forward`` converts audio-encoder features into audio tokens. It returns,
    for every ``interval``-th transformer layer, a hook: ``modify_func`` plus
    the ``kwargs`` it needs. The hook adds an audio cross-attention residual
    to that layer's hidden states. The host transformer presumably calls it
    with ``hidden_states`` and ``grid_sizes`` for the layer.
    """

    def __init__(
        self,
        attention_head_dim=64,
        num_attention_heads=40,
        base_num_layers=30,
        interval=1,
        audio_feature_dim: int = 768,
        num_tokens: int = 32,
        mlp_dims: tuple = (1024, 1024, 32 * 768),
        time_freq_dim: int = 256,
        projection_transformer_layers: int = 4,
    ):
        super().__init__()
        self.audio_proj = AudioProjection(
            audio_feature_dim=audio_feature_dim,
            n_neighbors=(2, 2),
            num_tokens=num_tokens,
            mlp_dims=mlp_dims,
            transformer_layers=projection_transformer_layers,
        )
        # rearange_audio_features groups 4 video frames per latent frame, so each
        # latent frame carries 4 * num_tokens audio tokens.
        self.num_tokens_x4 = num_tokens * 4
        self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
        ca_num = math.ceil(base_num_layers / interval)
        self.base_num_layers = base_num_layers
        self.interval = interval
        self.ca = nn.ModuleList(
            [
                PerceiverAttentionCA(
                    dim_head=attention_head_dim,
                    heads=num_attention_heads,
                    kv_dim=mlp_dims[-1] // num_tokens,
                    adaLN=time_freq_dim > 0,
                )
                for _ in range(ca_num)
            ]
        )
        self.dim = attention_head_dim * num_attention_heads
        if time_freq_dim > 0:
            self.time_embedding = TimeEmbedding(self.dim, time_freq_dim, self.dim * 3)
        else:
            self.time_embedding = None

    def rearange_audio_features(self, audio_feature: torch.Tensor):
        """Group per-video-frame audio tokens into per-latent-frame tokens.

        Input is (B, video_frame, num_tokens, C). The first frame is repeated 4
        times so the frame count becomes 4 * latent_frame. Every 4 consecutive
        frames are then merged, giving (B, latent_frame, 4 * num_tokens, C).
        """
        audio_feature_0 = audio_feature[:, :1]
        audio_feature_0 = torch.repeat_interleave(audio_feature_0, repeats=4, dim=1)
        audio_feature = torch.cat([audio_feature_0, audio_feature[:, 1:]], dim=1)  # (B, 4 * latent_frame, num_tokens, C)
        audio_feature = rearrange(audio_feature, "B (T S) N C -> B T (S N) C", S=4)
        return audio_feature

    def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0, seq_p_group=None):
        """Build the per-layer audio cross-attention hooks.

        Args:
            audio_feat: audio-encoder features, passed to ``AudioProjection``.
            timestep: diffusion timestep(s) for the adaLN modulation.
            latent_frame: number of latent frames to align the audio with.
            weight: scale applied to every audio residual.
            seq_p_group: optional sequence-parallel process group.

        Returns:
            dict mapping a transformer layer index (0, interval, 2*interval, ...)
            to ``{"kwargs": {...}, "modify_func": modify_hidden_states}``. The
            caller supplies ``hidden_states`` and ``grid_sizes``.
        """

        def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight, seq_p_group):
            """Add the audio cross-attention residual to one layer's hidden states.

            ``grid_sizes[0]`` holds (t, h, w): the latent frame count, height
            and width after patchification. ``t`` does not include the
            reference images, so the audio and the hidden states are strictly
            aligned frame by frame.

            Under sequence parallelism each rank holds ``n_tokens_per_rank``
            tokens. Only tokens within the first ``t * h * w`` global positions
            receive a non-zero residual.
            """
            if len(hidden_states.shape) == 2:  # add a batch dim (bs = 1)
                hidden_states = hidden_states.unsqueeze(0)
            t, h, w = grid_sizes[0].tolist()
            n_tokens = t * h * w
            ori_dtype = hidden_states.dtype
            device = hidden_states.device
            bs, n_tokens_per_rank = hidden_states.shape[:2]

            if seq_p_group is not None:
                sp_size = dist.get_world_size(seq_p_group)
                sp_rank = dist.get_rank(seq_p_group)
            else:
                sp_size = 1
                sp_rank = 0

            # Sharded positions beyond the n_tokens real tokens (presumably padding).
            tail_length = n_tokens_per_rank * sp_size - n_tokens
            # Trailing ranks whose whole slice lies beyond the real tokens.
            n_unused_ranks = tail_length // n_tokens_per_rank
            if sp_rank > sp_size - n_unused_ranks - 1:
                n_query_tokens = 0
            elif sp_rank == sp_size - n_unused_ranks - 1:
                n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
            else:
                n_query_tokens = n_tokens_per_rank

            if n_query_tokens > 0:
                hidden_states_aligned = hidden_states[:, :n_query_tokens]
                hidden_states_tail = hidden_states[:, n_query_tokens:]
            else:
                # For ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
                hidden_states_aligned = hidden_states[:, :1]
                hidden_states_tail = hidden_states[:, 1:]

            q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=sp_rank)

            q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
            # Processing of the audio features under sequence parallelism could be moved out of this hook.
            x = x[:, t0:t1]
            x = x.to(dtype)
            # Each covered frame attends to all num_tokens_x4 audio tokens of that frame.
            k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32)
            assert q_lens.shape == k_lens.shape
            # ca_block is the PerceiverAttentionCA cross-attention module for this layer.
            residual = ca_block(x, hidden_states_aligned, t_emb, q_lens, k_lens) * weight

            # The audio cross-attention output is injected as a residual.
            residual = residual.to(ori_dtype)
            if n_query_tokens == 0:
                residual = residual * 0.0
            hidden_states = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=1)

            # NOTE(review): this squeezes dim 0 whenever it has size 1, even if the caller passed a
            # 3-D (1, L, C) tensor, so 3-D bs=1 input comes back 2-D. Kept as-is to preserve behavior.
            if len(hidden_states.shape) == 3:
                hidden_states = hidden_states.squeeze(0)  # bs = 1
            return hidden_states

        x = self.audio_proj(audio_feat, latent_frame)
        x = self.rearange_audio_features(x)
        x = x + self.audio_pe
        if self.time_embedding is not None:
            t_emb = self.time_embedding(timestep).unflatten(1, (3, -1))
        else:
            t_emb = torch.zeros((len(x), 3, self.dim), device=x.device, dtype=x.dtype)
        ret_dict = {}
        for block_idx, base_idx in enumerate(range(0, self.base_num_layers, self.interval)):
            block_dict = {
                "kwargs": {
                    "ca_block": self.ca[block_idx],
                    "x": x,
                    "weight": weight,
                    "t_emb": t_emb,
                    "dtype": x.dtype,
                    "seq_p_group": seq_p_group,
                },
                "modify_func": modify_hidden_states,
            }
            ret_dict[base_idx] = block_dict
        return ret_dict

    @classmethod
    def from_transformer(
        cls,
        transformer,
        audio_feature_dim: int = 1024,
        interval: int = 1,
        time_freq_dim: int = 256,
        projection_transformer_layers: int = 4,
    ):
        """Build an adapter sized to match ``transformer.config``.

        Reads ``num_heads``, ``num_layers`` and ``dim`` from the config. Uses
        ``cls`` so that subclasses construct instances of themselves.
        """
        num_attention_heads = transformer.config["num_heads"]
        base_num_layers = transformer.config["num_layers"]
        attention_head_dim = transformer.config["dim"] // num_attention_heads

        audio_adapter = cls(
            attention_head_dim,
            num_attention_heads,
            base_num_layers,
            interval=interval,
            audio_feature_dim=audio_feature_dim,
            time_freq_dim=time_freq_dim,
            projection_transformer_layers=projection_transformer_layers,
            mlp_dims=(1024, 1024, 32 * audio_feature_dim),
        )
        return audio_adapter

    def get_fsdp_wrap_module_list(
        self,
    ):
        """Return the cross-attention blocks as the units to wrap with FSDP."""
        ret_list = list(self.ca)
        return ret_list

    def enable_gradient_checkpointing(
        self,
    ):
        """No-op: gradient checkpointing is not implemented for this adapter."""
        pass


class AudioAdapterPipe:
    """Encode preprocessed audio and run it through an ``AudioAdapter``.

    The audio encoder is loaded with ``AutoModel.from_pretrained`` and run in
    float16. Its ``last_hidden_state`` is cast to ``dtype`` and passed to the
    adapter, whose per-layer hook dictionary is returned.
    """

    def __init__(
        self,
        audio_adapter: AudioAdapter,
        audio_encoder_repo: str = "microsoft/wavlm-base-plus",
        dtype=torch.float32,
        device="cuda",
        tgt_fps: int = 15,
        weight: float = 1.0,
        cpu_offload: bool = False,
        seq_p_group=None,
    ) -> None:
        self.seq_p_group = seq_p_group
        self.audio_adapter = audio_adapter
        self.dtype = dtype
        self.audio_encoder_dtype = torch.float16
        self.cpu_offload = cpu_offload
        # Device the encoder runs on; also used when re-loading it after a CPU offload.
        self.device = device
        # Audio encoder
        self.audio_encoder = AutoModel.from_pretrained(audio_encoder_repo)

        self.audio_encoder.eval()
        self.audio_encoder.to(device, self.audio_encoder_dtype)
        self.tgt_fps = tgt_fps
        self.weight = weight
        # Feature width is inferred from the repo name: "base" checkpoints -> 768, others -> 1024.
        if "base" in audio_encoder_repo:
            self.audio_feature_dim = 768
        else:
            self.audio_feature_dim = 1024

    def update_model(self, audio_adapter):
        """Swap in a different ``AudioAdapter`` instance."""
        self.audio_adapter = audio_adapter

    def __call__(self, audio_input_feat, timestep, latent_shape: tuple, dropout_cond: callable = None):
        """Encode ``audio_input_feat`` and build the adapter's per-layer hooks.

        Args:
            audio_input_feat: encoder input, presumably produced by an
                AudioPreprocessor. A 1-D tensor gets a batch dim added.
            timestep: diffusion timestep(s), forwarded to the adapter.
            latent_shape: latent tensor shape. The frame count is read from
                index 2, or from index 1 when ``audio_input_feat`` is 1-D.
            dropout_cond: optional callable applied to the encoded features.

        Returns:
            The dict produced by ``AudioAdapter.forward``.
        """
        latent_frame = latent_shape[2]
        if len(audio_input_feat.shape) == 1:  # add a batch dim (bs = 1)
            audio_input_feat = audio_input_feat.unsqueeze(0)
            latent_frame = latent_shape[1]

        video_frame = (latent_frame - 1) * 4 + 1
        # Expected encoder sequence length for the clip, assuming 50 feature frames
        # per second of audio. Only used by the fallback below.
        audio_length = int(50 / self.tgt_fps * video_frame)

        with torch.no_grad():
            try:
                if self.cpu_offload:
                    self.audio_encoder = self.audio_encoder.to(self.device)
                # Move the input onto the encoder's device; a CPU input would otherwise always fail.
                encoder_input = audio_input_feat.to(device=self.device, dtype=self.audio_encoder_dtype)
                audio_feat = self.audio_encoder(encoder_input, return_dict=True).last_hidden_state
            except Exception:
                # Deliberate best-effort fallback, but make it loud: the output will be
                # conditioned on noise, not on the input audio.
                logger.exception("Audio encoder failed; falling back to RANDOM audio features")
                audio_feat = torch.rand(1, audio_length, self.audio_feature_dim).to(self.device)
            finally:
                # Always release GPU memory when offloading, even if encoding failed.
                if self.cpu_offload:
                    self.audio_encoder = self.audio_encoder.to("cpu")
            audio_feat = audio_feat.to(self.dtype)
            if dropout_cond is not None:
                audio_feat = dropout_cond(audio_feat)

        return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight, seq_p_group=self.seq_p_group)