Commit f0877951 authored by hungchiayu1

update gitignore

parent 337829fe
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# PyPI configuration file
.pypirc
from diffusers import AutoencoderOobleck
import torch
from transformers import T5EncoderModel,T5TokenizerFast
from diffusers import FluxTransformer2DModel
from torch import nn
from typing import List
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_density_for_timestep_sampling
import copy
import torch.nn.functional as F
import numpy as np
from model import TangoFlux
from huggingface_hub import snapshot_download
from tqdm import tqdm
from typing import Optional,Union,List
from datasets import load_dataset, Audio
from math import pi
import json
import inspect
import yaml
from safetensors.torch import load_file
class TangoFluxInference:
def __init__(self,name='declare-lab/TangoFlux',device="cuda"):
self.vae = AutoencoderOobleck.from_pretrained("stabilityai/stable-audio-open-1.0",subfolder='vae')
paths = snapshot_download(repo_id=name)
weights = load_file("{}/tangoflux.safetensors".format(paths))
with open('{}/config.json'.format(paths),'r') as f:
config = json.load(f)
self.model = TangoFlux(config)
self.model.load_state_dict(weights,strict=False)
# _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'], unexpected_keys=[]) this behaviour is expected
self.vae.to(device)
self.model.to(device)
def generate(self,prompt,steps=25,duration=10,guidance_scale=4.5):
with torch.no_grad():
latents = self.model.inference_flow(prompt,
duration=duration,
num_inference_steps=steps,
guidance_scale=guidance_scale)
wave = self.vae.decode(latents.transpose(2,1)).sample.cpu()[0]
return wave
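# Usage sketch (illustrative; assumes a CUDA device, that torchaudio is installed,
# and that the Stable Audio Open VAE decodes audio at 44.1 kHz):
# if __name__ == "__main__":
#     import torchaudio
#     tangoflux = TangoFluxInference(name='declare-lab/TangoFlux', device="cuda")
#     audio = tangoflux.generate("Hammer slowly hitting the wooden table",
#                                steps=50, duration=10, guidance_scale=4.5)
#     torchaudio.save("output.wav", audio, sample_rate=44100)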
from transformers import T5EncoderModel,T5TokenizerFast
import torch
from diffusers import FluxTransformer2DModel
from torch import nn
from typing import List
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_density_for_timestep_sampling
import copy
import random
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from typing import Optional,Union,List
from datasets import load_dataset, Audio
from math import pi
import inspect
import yaml
class StableAudioPositionalEmbedding(nn.Module):
"""Used for continuous time
Adapted from stable audio open.
"""
def __init__(self, dim: int):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim))
def forward(self, times: torch.Tensor) -> torch.Tensor:
times = times[..., None]
freqs = times * self.weights[None] * 2 * pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
fouriered = torch.cat((times, fouriered), dim=-1)
return fouriered
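# Shape sketch (illustrative): for a batch of scalar times of shape (batch,), the forward
# pass concatenates the raw time with its sin/cos Fourier features and returns
# (batch, dim + 1); this is why DurationEmbedder's Linear below uses in_features=internal_dim + 1.
# pe = StableAudioPositionalEmbedding(dim=256)
# pe(torch.rand(4)).shape   # torch.Size([4, 257])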
class DurationEmbedder(nn.Module):
"""
A simple linear projection model to map numbers to a latent space.
Code is adapted from
https://github.com/Stability-AI/stable-audio-tools
Args:
number_embedding_dim (`int`):
Dimensionality of the number embeddings.
min_value (`int`):
The minimum value of the seconds number conditioning modules.
max_value (`int`):
The maximum value of the seconds number conditioning modules.
internal_dim (`int`):
Dimensionality of the intermediate number hidden states.
"""
def __init__(
self,
number_embedding_dim,
min_value,
max_value,
internal_dim: Optional[int] = 256,
):
super().__init__()
self.time_positional_embedding = nn.Sequential(
StableAudioPositionalEmbedding(internal_dim),
nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
)
self.number_embedding_dim = number_embedding_dim
self.min_value = min_value
self.max_value = max_value
self.dtype = torch.float32
def forward(
self,
floats: torch.Tensor,
):
floats = floats.clamp(self.min_value, self.max_value)
normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value)
# Cast floats to same type as embedder
embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
normalized_floats = normalized_floats.to(embedder_dtype)
embedding = self.time_positional_embedding(normalized_floats)
float_embeds = embedding.view(-1, 1, self.number_embedding_dim)
return float_embeds
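# Usage sketch (illustrative values): durations in seconds are clamped to
# [min_value, max_value], normalised to [0, 1] and projected to a single extra
# "duration token" of shape (batch, 1, number_embedding_dim), which TangoFlux
# later concatenates to the text hidden states.
# emb = DurationEmbedder(number_embedding_dim=1024, min_value=0, max_value=30)
# emb(torch.tensor([10.0, 25.0])).shape   # torch.Size([2, 1, 1024])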
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
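# Usage sketch mirroring how inference_flow calls this helper further down: a custom
# sigma schedule from 1.0 down to 1/num_steps is forwarded to scheduler.set_timesteps
# (assumes a diffusers version whose FlowMatchEulerDiscreteScheduler.set_timesteps
# accepts `sigmas`).
# sched = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
# custom_sigmas = list(np.linspace(1.0, 1 / 50, 50))
# ts, n = retrieve_timesteps(sched, num_inference_steps=50, device="cpu", sigmas=custom_sigmas)
# assert n == 50 and len(ts) == 50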
class TangoFlux(nn.Module):
def __init__(self,config,initialize_reference_model=False):
super().__init__()
self.num_layers = config.get('num_layers', 6)
self.num_single_layers = config.get('num_single_layers', 18)
self.in_channels = config.get('in_channels', 64)
self.attention_head_dim = config.get('attention_head_dim', 128)
self.joint_attention_dim = config.get('joint_attention_dim', 1024)
self.num_attention_heads = config.get('num_attention_heads', 8)
self.audio_seq_len = config.get('audio_seq_len', 645)
self.max_duration = config.get('max_duration', 30)
self.uncondition = config.get('uncondition', False)
self.text_encoder_name = config.get('text_encoder_name', "google/flan-t5-large")
self.noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
self.max_text_seq_len = 64
self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
self.tokenizer = T5TokenizerFast.from_pretrained(self.text_encoder_name)
self.text_embedding_dim = self.text_encoder.config.d_model
self.fc = nn.Sequential(nn.Linear(self.text_embedding_dim,self.joint_attention_dim),nn.ReLU())
self.duration_emebdder = DurationEmbedder(self.text_embedding_dim,min_value=0,max_value=self.max_duration)
self.transformer = FluxTransformer2DModel(
in_channels=self.in_channels,
num_layers=self.num_layers,
num_single_layers=self.num_single_layers,
attention_head_dim=self.attention_head_dim,
num_attention_heads=self.num_attention_heads,
joint_attention_dim=self.joint_attention_dim,
pooled_projection_dim=self.text_embedding_dim,
guidance_embeds=False)
self.beta_dpo = 2000  ## beta coefficient used for DPO training
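# Note: the DPO branch of forward() (sft=False) reads self.ref_transformer, which is
# not created in this constructor; a frozen reference copy of the transformer is
# expected to be attached to the module externally before DPO training.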
def get_sigmas(self,timesteps, n_dim=3, dtype=torch.float32):
device = self.text_encoder.device
sigmas = self.noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def encode_text_classifier_free(self, prompt: List[str], num_samples_per_prompt=1):
device = self.text_encoder.device
batch = self.tokenizer(
prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
)
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
with torch.no_grad():
prompt_embeds = self.text_encoder(
input_ids=input_ids, attention_mask=attention_mask
)[0]
prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
# get unconditional embeddings for classifier free guidance
uncond_tokens = [""]
max_length = prompt_embeds.shape[1]
uncond_batch = self.tokenizer(
uncond_tokens, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt",
)
uncond_input_ids = uncond_batch.input_ids.to(device)
uncond_attention_mask = uncond_batch.attention_mask.to(device)
with torch.no_grad():
negative_prompt_embeds = self.text_encoder(
input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
)[0]
negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)
# For classifier free guidance, we need to do two forward passes.
# We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
boolean_prompt_mask = (prompt_mask == 1).to(device)
return prompt_embeds, boolean_prompt_mask
@torch.no_grad()
def encode_text(self, prompt):
device = self.text_encoder.device
batch = self.tokenizer(
prompt, max_length=self.max_text_seq_len, padding=True, truncation=True, return_tensors="pt")
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
encoder_hidden_states = self.text_encoder(
input_ids=input_ids, attention_mask=attention_mask)[0]
boolean_encoder_mask = (attention_mask == 1).to(device)
return encoder_hidden_states, boolean_encoder_mask
def encode_duration(self,duration):
return self.duration_emebdder(duration)
@torch.no_grad()
def inference_flow(self, prompt,
num_inference_steps=50,
timesteps=None,
guidance_scale=3,
duration=10,
disable_progress=False,
num_samples_per_prompt=1):
'''Only tested for single-sample inference; batch inference has not been tested.'''
bsz = num_samples_per_prompt
device = self.transformer.device
scheduler = self.noise_scheduler
if not isinstance(prompt,list):
prompt = [prompt]
if not isinstance(duration,torch.Tensor):
duration = torch.tensor([duration],device=device)
classifier_free_guidance = guidance_scale > 1.0
duration_hidden_states = self.encode_duration(duration)
if classifier_free_guidance:
bsz = 2 * num_samples_per_prompt
encoder_hidden_states, boolean_encoder_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt=num_samples_per_prompt)
duration_hidden_states = duration_hidden_states.repeat(bsz,1,1)
else:
encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)  # encode_text takes no num_samples_per_prompt argument
mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(encoder_hidden_states)
masked_data = torch.where(mask_expanded, encoder_hidden_states, torch.tensor(float('nan')))
pooled = torch.nanmean(masked_data, dim=1)
pooled_projection = self.fc(pooled)
encoder_hidden_states = torch.cat([encoder_hidden_states,duration_hidden_states],dim=1) ## (bs,seq_len,dim)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps,
device,
timesteps,
sigmas
)
latents = torch.randn(num_samples_per_prompt,self.audio_seq_len,64)
weight_dtype = latents.dtype
progress_bar = tqdm(range(num_inference_steps), disable=disable_progress)
txt_ids = torch.zeros(bsz,encoder_hidden_states.shape[1],3).to(device)
audio_ids = torch.arange(self.audio_seq_len).unsqueeze(0).unsqueeze(-1).repeat(bsz,1,3).to(device)
timesteps = timesteps.to(device)
latents = latents.to(device)
encoder_hidden_states = encoder_hidden_states.to(device)
for i, t in enumerate(timesteps):
latents_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
noise_pred = self.transformer(
hidden_states=latents_input,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=torch.tensor([t/1000],device=device),
guidance=None,
pooled_projections=pooled_projection,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=audio_ids,
return_dict=False,
)[0]
if classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = scheduler.step(noise_pred, t, latents).prev_sample
return latents
def forward(self,
latents,
prompt,
duration=torch.tensor([10]),
sft=True
):
device = latents.device
audio_seq_length = self.audio_seq_len
bsz = latents.shape[0]
encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
duration_hidden_states = self.encode_duration(duration)
mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(encoder_hidden_states)
masked_data = torch.where(mask_expanded, encoder_hidden_states, torch.tensor(float('nan')))
pooled = torch.nanmean(masked_data, dim=1)
pooled_projection = self.fc(pooled)
## Add duration hidden states to encoder hidden states
encoder_hidden_states = torch.cat([encoder_hidden_states,duration_hidden_states],dim=1) ## (bs,seq_len,dim)
txt_ids = torch.zeros(bsz,encoder_hidden_states.shape[1],3).to(device)
audio_ids = torch.arange(audio_seq_length).unsqueeze(0).unsqueeze(-1).repeat(bsz,1,3).to(device)
if sft:
if self.uncondition:
mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
if len(mask_indices) > 0:
encoder_hidden_states[mask_indices] = 0
noise = torch.randn_like(latents)
u = compute_density_for_timestep_sampling(
weighting_scheme='logit_normal',
batch_size=bsz,
logit_mean=0,
logit_std=1,
mode_scale=None,
)
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
model_pred = self.transformer(
hidden_states=noisy_model_input,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projection,
img_ids=audio_ids,
txt_ids=txt_ids,
guidance=None,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps/1000,
return_dict=False)[0]
target = noise - latents
loss = torch.mean(
( (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1,
)
loss = loss.mean()
raw_model_loss, raw_ref_loss, implicit_acc, epsilon_diff = 0, 0, 0, 0  ## default these to 0 when doing SFT
else:
encoder_hidden_states = encoder_hidden_states.repeat(2, 1, 1)
pooled_projection = pooled_projection.repeat(2,1)
noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1) ## Have to sample same noise for preferred and rejected
u = compute_density_for_timestep_sampling(
weighting_scheme='logit_normal',
batch_size=bsz//2,
logit_mean=0,
logit_std=1,
mode_scale=None,
)
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
timesteps = timesteps.repeat(2)
sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
model_pred = self.transformer(
hidden_states=noisy_model_input,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projection,
img_ids=audio_ids,
txt_ids=txt_ids,
guidance=None,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps/1000,
return_dict=False)[0]
target = noise - latents
model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
model_losses_w, model_losses_l = model_losses.chunk(2)
model_diff = model_losses_w - model_losses_l
raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
with torch.no_grad():
ref_preds = self.ref_transformer(
hidden_states=noisy_model_input,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projection,
img_ids=audio_ids,
txt_ids=txt_ids,
guidance=None,
timestep=timesteps/1000,
return_dict=False)[0]
ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none")
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
ref_diff = ref_losses_w - ref_losses_l
raw_ref_loss = ref_loss.mean()
epsilon_diff = torch.max(torch.zeros_like(model_losses_w),
ref_losses_w-model_losses_w).mean()
scale_term = -0.5 * self.beta_dpo
inside_term = scale_term * (model_diff - ref_diff)
implicit_acc = (scale_term * (model_diff - ref_diff) > 0).sum().float() / inside_term.size(0)
loss = -1 * F.logsigmoid(inside_term).mean() + model_losses_w.mean()
return loss, raw_model_loss, raw_ref_loss, implicit_acc, epsilon_diff
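# Training-step sketch (illustrative; the config values and prompts are assumptions):
# latents are expected as (batch, audio_seq_len, 64) VAE latents, e.g. obtained by
# encoding 44.1 kHz audio with AutoencoderOobleck and transposing to channels-last.
# With sft=True the call returns the flow-matching MSE loss and zeroed DPO statistics.
# model = TangoFlux(config={"audio_seq_len": 645})
# latents = torch.randn(2, 645, 64)
# loss, raw_model_loss, raw_ref_loss, implicit_acc, epsilon_diff = model(
#     latents,
#     prompt=["a dog barking in the distance", "rain falling on a tin roof"],
#     duration=torch.tensor([10, 10]),
#     sft=True,
# )
# loss.backward()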