from transformers import T5EncoderModel, T5TokenizerFast
import torch
from diffusers import FluxTransformer2DModel
from torch import nn
import random
from typing import List, Optional, Union
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_density_for_timestep_sampling
import copy
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

from datasets import load_dataset, Audio
from math import pi
import inspect
import yaml


class StableAudioPositionalEmbedding(nn.Module):
    """Used for continuous time
    Adapted from Stable Audio Open.
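
    Example (illustrative sketch; the output width is ``dim + 1`` because the raw
    time value is concatenated with the sin/cos features)::

        emb = StableAudioPositionalEmbedding(dim=256)
        out = emb(torch.tensor([0.5]))  # shape: (1, 257)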
    """

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, times: torch.Tensor) -> torch.Tensor:
        times = times[..., None]
        freqs = times * self.weights[None] * 2 * pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        fouriered = torch.cat((times, fouriered), dim=-1)
        return fouriered


class DurationEmbedder(nn.Module):
    """
    A simple linear projection model to map numbers to a latent space.

    Code is adapted from
    https://github.com/Stability-AI/stable-audio-tools

    Args:
        number_embedding_dim (`int`):
            Dimensionality of the number embeddings.
        min_value (`int`):
            Minimum duration value (in seconds) accepted by the conditioning module.
        max_value (`int`):
            Maximum duration value (in seconds) accepted by the conditioning module.
        internal_dim (`int`):
            Dimensionality of the intermediate number hidden states.
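
    Example (illustrative sketch, assuming ``number_embedding_dim=1024`` and a
    10-second input duration)::

        embedder = DurationEmbedder(1024, min_value=0, max_value=30)
        emb = embedder(torch.tensor([10.0]))  # shape: (1, 1, 1024)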
    """

    def __init__(
        self,
        number_embedding_dim,
        min_value,
        max_value,
        internal_dim: Optional[int] = 256,
    ):
        super().__init__()
        self.time_positional_embedding = nn.Sequential(
            StableAudioPositionalEmbedding(internal_dim),
            nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
        )

        self.number_embedding_dim = number_embedding_dim
        self.min_value = min_value
        self.max_value = max_value
        self.dtype = torch.float32

    def forward(
        self,
        floats: torch.Tensor,
    ):
        floats = floats.clamp(self.min_value, self.max_value)

        normalized_floats = (floats - self.min_value) / (
            self.max_value - self.min_value
        )

        # Cast floats to same type as embedder
        embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
        normalized_floats = normalized_floats.to(embedder_dtype)

        embedding = self.time_positional_embedding(normalized_floats)
        float_embeds = embedding.view(-1, 1, self.number_embedding_dim)

        return float_embeds


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Call ``scheduler.set_timesteps`` and return the timesteps to iterate over.

    Supports schedulers whose ``set_timesteps`` accepts custom ``timesteps`` or
    ``sigmas`` schedules; at most one of the two may be passed. Returns
    ``(timesteps, num_inference_steps)``.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class TangoFlux(nn.Module):

    def __init__(self, config, text_encoder_dir=None, initialize_reference_model=False):

        super().__init__()

        self.num_layers = config.get("num_layers", 6)
        self.num_single_layers = config.get("num_single_layers", 18)
        self.in_channels = config.get("in_channels", 64)
        self.attention_head_dim = config.get("attention_head_dim", 128)
        self.joint_attention_dim = config.get("joint_attention_dim", 1024)
        self.num_attention_heads = config.get("num_attention_heads", 8)
        self.audio_seq_len = config.get("audio_seq_len", 645)
        self.max_duration = config.get("max_duration", 30)
        self.uncondition = config.get("uncondition", False)
        self.text_encoder_name = config.get("text_encoder_name", "google/flan-t5-large")

        self.noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
        self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
        self.max_text_seq_len = 64
        self.text_encoder = T5EncoderModel.from_pretrained(
            text_encoder_dir if text_encoder_dir is not None else self.text_encoder_name
        )
        self.tokenizer = T5TokenizerFast.from_pretrained(
            text_encoder_dir if text_encoder_dir is not None else self.text_encoder_name
        )
        self.text_embedding_dim = self.text_encoder.config.d_model

        self.fc = nn.Sequential(
            nn.Linear(self.text_embedding_dim, self.joint_attention_dim), nn.ReLU()
        )
        self.duration_emebdder = DurationEmbedder(
            self.text_embedding_dim, min_value=0, max_value=self.max_duration
        )

        self.transformer = FluxTransformer2DModel(
            in_channels=self.in_channels,
            num_layers=self.num_layers,
            num_single_layers=self.num_single_layers,
            attention_head_dim=self.attention_head_dim,
            num_attention_heads=self.num_attention_heads,
            joint_attention_dim=self.joint_attention_dim,
            pooled_projection_dim=self.text_embedding_dim,
            guidance_embeds=False,
        )

        if initialize_reference_model:
            # Frozen copy of the transformer used as the reference policy for DPO
            # training (forward() with sft=False reads self.ref_transformer).
            self.ref_transformer = copy.deepcopy(self.transformer)
            self.ref_transformer.requires_grad_(False)

        self.beta_dpo = 2000  ## DPO beta (temperature) used in the preference loss

    def get_sigmas(self, timesteps, n_dim=3, dtype=torch.float32):
        """Look up the scheduler sigmas for `timesteps` and reshape them to `n_dim` dims."""
        device = self.text_encoder.device
        sigmas = self.noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)

        schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device)
        timesteps = timesteps.to(device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    def encode_text_classifier_free(self, prompt: List[str], num_samples_per_prompt=1):
        device = self.text_encoder.device
        batch = self.tokenizer(
            prompt,
            max_length=self.tokenizer.model_max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
            device
        )

        with torch.no_grad():
            prompt_embeds = self.text_encoder(
                input_ids=input_ids, attention_mask=attention_mask
            )[0]

        prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
        attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)

        # get unconditional embeddings for classifier free guidance
        uncond_tokens = [""]

        max_length = prompt_embeds.shape[1]
        uncond_batch = self.tokenizer(
            uncond_tokens,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        uncond_input_ids = uncond_batch.input_ids.to(device)
        uncond_attention_mask = uncond_batch.attention_mask.to(device)

        with torch.no_grad():
            negative_prompt_embeds = self.text_encoder(
                input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
            )[0]

        negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(
            num_samples_per_prompt, 0
        )
        uncond_attention_mask = uncond_attention_mask.repeat_interleave(
            num_samples_per_prompt, 0
        )

        # For classifier-free guidance we concatenate the unconditional and text
        # embeddings into a single batch, so one forward pass covers both branches.

        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
        boolean_prompt_mask = (prompt_mask == 1).to(device)

        return prompt_embeds, boolean_prompt_mask

    @torch.no_grad()
    def encode_text(self, prompt):
        device = self.text_encoder.device
        batch = self.tokenizer(
            prompt,
            max_length=self.max_text_seq_len,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
            device
        )

        encoder_hidden_states = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask
        )[0]

        boolean_encoder_mask = (attention_mask == 1).to(device)

        return encoder_hidden_states, boolean_encoder_mask

    def encode_duration(self, duration):
        return self.duration_emebdder(duration)

    @torch.no_grad()
    def inference_flow(
        self,
        prompt,
        num_inference_steps=50,
        timesteps=None,
        guidance_scale=3,
        duration=10,
        seed=0,
        disable_progress=False,
        num_samples_per_prompt=1,
        callback_on_step_end=None,
    ):
        """Only tested for single inference. Haven't test for batch inference"""
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

        bsz = num_samples_per_prompt
        device = self.transformer.device
        scheduler = self.noise_scheduler

        if not isinstance(prompt, list):
            prompt = [prompt]
        if not isinstance(duration, torch.Tensor):
            duration = torch.tensor([duration], device=device)
        classifier_free_guidance = guidance_scale > 1.0
        duration_hidden_states = self.encode_duration(duration)
        if classifier_free_guidance:
            bsz = 2 * num_samples_per_prompt

            encoder_hidden_states, boolean_encoder_mask = (
                self.encode_text_classifier_free(
                    prompt, num_samples_per_prompt=num_samples_per_prompt
                )
            )
            duration_hidden_states = duration_hidden_states.repeat(bsz, 1, 1)

        else:

            # encode_text() takes no num_samples_per_prompt argument, so repeat here.
            encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
            encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_samples_per_prompt, 0)
            boolean_encoder_mask = boolean_encoder_mask.repeat_interleave(num_samples_per_prompt, 0)

        mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(
            encoder_hidden_states
        )
        masked_data = torch.where(
            mask_expanded, encoder_hidden_states, torch.tensor(float("nan"))
        )

        pooled = torch.nanmean(masked_data, dim=1)
        pooled_projection = self.fc(pooled)

        encoder_hidden_states = torch.cat(
            [encoder_hidden_states, duration_hidden_states], dim=1
        )  ## (bs,seq_len,dim)

        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler, num_inference_steps, device, timesteps, sigmas
        )

        latents = torch.randn(num_samples_per_prompt, self.audio_seq_len, self.in_channels)
        weight_dtype = latents.dtype

        progress_bar = tqdm(range(num_inference_steps), disable=disable_progress)

        txt_ids = torch.zeros(bsz, encoder_hidden_states.shape[1], 3).to(device)
        audio_ids = (
            torch.arange(self.audio_seq_len)
            .unsqueeze(0)
            .unsqueeze(-1)
            .repeat(bsz, 1, 3)
            .to(device)
        )

        timesteps = timesteps.to(device)
        latents = latents.to(device)
        encoder_hidden_states = encoder_hidden_states.to(device)

        for i, t in enumerate(timesteps):

            latents_input = (
                torch.cat([latents] * 2) if classifier_free_guidance else latents
            )

            noise_pred = self.transformer(
                hidden_states=latents_input,
                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
                timestep=torch.tensor([t / 1000], device=device),
                guidance=None,
                pooled_projections=pooled_projection,
                encoder_hidden_states=encoder_hidden_states,
                txt_ids=txt_ids,
                img_ids=audio_ids,
                return_dict=False,
            )[0]

            if classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            latents = scheduler.step(noise_pred, t, latents).prev_sample

            progress_bar.update(1)

            if callback_on_step_end is not None:
                callback_on_step_end()

        return latents

    def forward(self, latents, prompt, duration=torch.tensor([10]), sft=True):

        device = latents.device
        audio_seq_length = self.audio_seq_len
        bsz = latents.shape[0]

        encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
        duration_hidden_states = self.encode_duration(duration)

        mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(
            encoder_hidden_states
        )
        masked_data = torch.where(
            mask_expanded, encoder_hidden_states, torch.tensor(float("nan"))
        )
        pooled = torch.nanmean(masked_data, dim=1)
        pooled_projection = self.fc(pooled)

        ## Add duration hidden states to encoder hidden states
        encoder_hidden_states = torch.cat(
            [encoder_hidden_states, duration_hidden_states], dim=1
        )  ## (bs,seq_len,dim)

        txt_ids = torch.zeros(bsz, encoder_hidden_states.shape[1], 3).to(device)
        audio_ids = (
            torch.arange(audio_seq_length)
            .unsqueeze(0)
            .unsqueeze(-1)
            .repeat(bsz, 1, 3)
            .to(device)
        )

        if sft:

            if self.uncondition:
                mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
                if len(mask_indices) > 0:
                    encoder_hidden_states[mask_indices] = 0

            noise = torch.randn_like(latents)

            u = compute_density_for_timestep_sampling(
                weighting_scheme="logit_normal",
                batch_size=bsz,
                logit_mean=0,
                logit_std=1,
                mode_scale=None,
            )

            indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
            timesteps = self.noise_scheduler_copy.timesteps[indices].to(
                device=latents.device
            )
            sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)

            noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise

            model_pred = self.transformer(
                hidden_states=noisy_model_input,
                encoder_hidden_states=encoder_hidden_states,
                pooled_projections=pooled_projection,
                img_ids=audio_ids,
                txt_ids=txt_ids,
                guidance=None,
                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
                timestep=timesteps / 1000,
                return_dict=False,
            )[0]

            target = noise - latents
            loss = torch.mean(
                ((model_pred.float() - target.float()) ** 2).reshape(
                    target.shape[0], -1
                ),
                1,
            )
            loss = loss.mean()
            raw_model_loss, raw_ref_loss, implicit_acc = (
                0,
                0,
                0,
            )  ## default this to 0 if doing sft

        else:
            encoder_hidden_states = encoder_hidden_states.repeat(2, 1, 1)
            pooled_projection = pooled_projection.repeat(2, 1)
            noise = (
                torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1)
            )  ## Have to sample same noise for preferred and rejected
            u = compute_density_for_timestep_sampling(
                weighting_scheme="logit_normal",
                batch_size=bsz // 2,
                logit_mean=0,
                logit_std=1,
                mode_scale=None,
            )

            indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
            timesteps = self.noise_scheduler_copy.timesteps[indices].to(
                device=latents.device
            )
            timesteps = timesteps.repeat(2)
            sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)

            noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise

            model_pred = self.transformer(
                hidden_states=noisy_model_input,
                encoder_hidden_states=encoder_hidden_states,
                pooled_projections=pooled_projection,
                img_ids=audio_ids,
                txt_ids=txt_ids,
                guidance=None,
                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
                timestep=timesteps / 1000,
                return_dict=False,
            )[0]
            target = noise - latents

            model_losses = F.mse_loss(
                model_pred.float(), target.float(), reduction="none"
            )
            model_losses = model_losses.mean(
                dim=list(range(1, len(model_losses.shape)))
            )
            model_losses_w, model_losses_l = model_losses.chunk(2)
            model_diff = model_losses_w - model_losses_l
            raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())

            with torch.no_grad():
                ref_preds = self.ref_transformer(
                    hidden_states=noisy_model_input,
                    encoder_hidden_states=encoder_hidden_states,
                    pooled_projections=pooled_projection,
                    img_ids=audio_ids,
                    txt_ids=txt_ids,
                    guidance=None,
                    timestep=timesteps / 1000,
                    return_dict=False,
                )[0]

                ref_loss = F.mse_loss(
                    ref_preds.float(), target.float(), reduction="none"
                )
                ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))

                ref_losses_w, ref_losses_l = ref_loss.chunk(2)
                ref_diff = ref_losses_w - ref_losses_l
                raw_ref_loss = ref_loss.mean()

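            # DPO-style preference objective (summary of the lines below, with
            # beta = self.beta_dpo and *_w / *_l the per-sample MSE losses on the
            # preferred / rejected latents):
            #   loss = -log sigmoid(-beta/2 * [(model_w - model_l) - (ref_w - ref_l)])
            #          + model_losses_w.mean()
            # implicit_acc is the fraction of pairs where the policy favours the
            # preferred sample (relative to the rejected one) more strongly than the
            # frozen reference does.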
            scale_term = -0.5 * self.beta_dpo
            inside_term = scale_term * (model_diff - ref_diff)
            implicit_acc = (
                scale_term * (model_diff - ref_diff) > 0
            ).sum().float() / inside_term.size(0)
            loss = -1 * F.logsigmoid(inside_term).mean() + model_losses_w.mean()

        ## raw_model_loss, raw_ref_loss, and implicit_acc are returned to help analyze DPO behaviour.
        return loss, raw_model_loss, raw_ref_loss, implicit_acc
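

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): build a TangoFlux model from an
# in-line config using the default hyperparameters read in __init__, then sample
# audio latents with inference_flow. Note that this downloads the
# google/flan-t5-large text encoder; decoding the returned latents back to a
# waveform is outside the scope of this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_config = {
        "num_layers": 6,
        "num_single_layers": 18,
        "in_channels": 64,
        "attention_head_dim": 128,
        "num_attention_heads": 8,
        "joint_attention_dim": 1024,
        "audio_seq_len": 645,
        "max_duration": 30,
    }
    model = TangoFlux(example_config)
    model.eval()

    latents = model.inference_flow(
        "A dog barking in the distance",
        duration=10,
        num_inference_steps=25,
        guidance_scale=3,
        num_samples_per_prompt=1,
    )
    print(latents.shape)  # expected: (1, audio_seq_len, in_channels) == (1, 645, 64)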