from typing import List, Optional, Tuple, Union
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, \
LlamaConfig, LlamaModel, LlamaForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from ..blip3o_arch import blip3oMetaModel, blip3oMetaForCausalLM
from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX, DEFAULT_IM_END_TOKEN_IDX
import pdb
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import numpy_to_pil
import numpy as np
from diffusers.models import AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
class blip3oConfig(LlamaConfig):
model_type = "blip3o_llama"
class blip3oLlamaModel(blip3oMetaModel, LlamaModel):
config_class = blip3oConfig
def __init__(self, config: LlamaConfig):
super(blip3oLlamaModel, self).__init__(config)
class blip3oLlamaForCausalLM(LlamaForCausalLM, blip3oMetaForCausalLM):
config_class = blip3oConfig
def __init__(self, config):
super(LlamaForCausalLM, self).__init__(config)
self.model = blip3oLlamaModel(config)
self.pretraining_tp = config.pretraining_tp
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.dist = None
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
ids: Optional[list] = None,
i_s_pos: Optional[list] = None,
image_type: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
gen_image: Optional[torch.FloatTensor] = None,
und_image: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
latents
) = self.prepare_inputs_labels_for_multimodal(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
gen_image,
und_image,
i_s_pos,
image_sizes
)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
total_loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
# compute image loss
# target_img_embeds = torch.clone(inputs_embeds.detach())[:,1:,:] # get target image emb
img_loss_funct = torch.nn.MSELoss()
# img_hidden_states = self.get_model().down_projector(hidden_states[:,-self.get_n_query():,:])
img_hidden_states = []
for b in range(hidden_states.shape[0]):
img_hidden_states.append(hidden_states[b,i_s_pos[b]:i_s_pos[b]+64,:])
img_hidden_states = torch.stack(img_hidden_states,dim=0)
img_hidden_states = self.get_model().down_projector(img_hidden_states)
# img_loss = 0.0
if latents is None:
img_loss = img_loss_funct(img_hidden_states, torch.clone(img_hidden_states.detach()))
else:
bsz = latents.shape[0]
# device = latents.device
dtype = latents.dtype
noise = torch.randn_like(latents, device=latents.device)
u = torch.rand(size=(bsz,), device="cpu")
indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long()
timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device)
sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=dtype)
noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
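                # Rectified-flow (flow-matching) objective: the clean latent is linearly
                # interpolated with Gaussian noise, x_t = (1 - sigma_t) * x_0 + sigma_t * eps,
                # and the DiT is trained to predict the velocity eps - x_0 (the `target` below).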
noise_pred = self.get_model().dit(
x=noisy_latents,
timestep=timesteps,
z_latents=self.mask_drop(img_hidden_states),
)
target = noise - latents
img_loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
print(f"img loss {img_loss}, text loss {loss}")
total_loss = img_loss
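        # NOTE: only the image (flow-matching) loss is returned for backprop here;
        # the text cross-entropy loss above is computed for logging only.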
return CausalLMOutputWithPast(
loss=total_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if images is not None:
(
inputs,
position_ids,
attention_mask,
_,
inputs_embeds,
img_indicator,
_
) = self.prepare_inputs_labels_for_understanding(
inputs,
position_ids,
attention_mask,
None,
None,
images,
image_sizes=image_sizes
)
else:
inputs_embeds = self.get_model().embed_tokens(inputs)
return super().generate(
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
**kwargs
)
@torch.no_grad()
def generate_image(
self,
text: List[str],
tokenizer: AutoTokenizer,
image: Optional[torch.Tensor] = None,
max_var: Optional[float] = None,
# placeholder: str = DEFAULT_IMG_PLACEHOLDER,
):
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler")
vision_tower = self.get_vision_tower()
mm_projector = self.get_mm_projector()
N_QUERY = self.get_n_query()
if image is not None:
# image: [Batch, 3, 448, 448]
            prompt_image_embeds = vision_tower(image)
num_img, _, c = prompt_image_embeds.shape # [batch, 576, 1024]
all_image_embeds = torch.clone(prompt_image_embeds).detach()
prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
prompt_image_embeds = mm_projector(prompt_image_embeds)
inputs = tokenizer(text, padding="longest", return_tensors="pt")
device = self.get_model().device
attention_mask = inputs.attention_mask.to(device)
input_ids = inputs.input_ids.to(device) # B x N
input_ids = torch.cat([input_ids, torch.tensor([[198]]).to(device)], dim=1)
# breakpoint()
text_embeds = self.get_model().embed_tokens(input_ids)
latent_queries = self.get_model().latent_queries.repeat(text_embeds.shape[0], 1, 1)
text_embeds = torch.cat([text_embeds, latent_queries], dim=1)
attention_mask = torch.cat([attention_mask, torch.ones_like(latent_queries[:, :, 0])], dim=1)
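        # The learned latent queries appended above act as image "slots": their last-layer
        # hidden states are taken below, down-projected, and used to condition the DiT.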
outputs = self.model(
inputs_embeds=text_embeds,
# img_indicator=img_indicator,
# concept_indicator=concept_indicator if self.use_concept_token else None,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True,
)
hidden_states = outputs.hidden_states[-1][:,-N_QUERY:,:]
img_hidden_states = self.get_model().down_projector(hidden_states)
output_img = self.sample_images(img_hidden_states, scheduler)
output_img = output_img.view(1, 1792, -1).permute(0,2,1).contiguous()
return output_img
def sample_images(
self,
img_hidden_states,
scheduler,
guidance_scale: float = 3.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
num_inference_steps: int = 30,
num_images_per_prompt: int = 1,
return_tensor=False,
**kwargs,
):
device = img_hidden_states.device
dtype = img_hidden_states.dtype
img_hidden_states_null = torch.zeros_like(img_hidden_states, device=device, dtype=dtype)
img_hidden_states_input = torch.cat([img_hidden_states_null, img_hidden_states], 0)
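        # Classifier-free guidance: an all-zero conditioning tensor (the "null" prompt) is
        # concatenated with the real conditioning, so a single DiT forward pass produces
        # both the unconditional and the conditional prediction.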
batch_size = img_hidden_states.shape[0]
latent_size = self.get_model().dit.config.input_size
latent_channels = self.get_model().dit.config.in_channels
latents = randn_tensor(
shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size),
generator=generator,
device=device,
dtype=dtype,
)
# set step values
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
scheduler.set_timesteps(num_inference_steps, sigmas=sigmas)
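        # Linear sigma schedule from 1.0 down to 1/num_inference_steps for the
        # flow-matching Euler sampler.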
# Repeat z_latents and conditions for each image per prompt
img_hidden_states_input = img_hidden_states_input.repeat_interleave(num_images_per_prompt, dim=0)
for t in scheduler.timesteps:
latent_model_input = latents.repeat(2, 1, 1, 1)
if hasattr(scheduler, "scale_model_input"):
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# predict noise model_output
noise_pred = self.get_model().dit(
x=latent_model_input,
timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latent_model_input.device, torch.long),
z_latents=img_hidden_states_input,
)
# perform guidance
noise_pred_uncond, noise_pred = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
# compute previous image: x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents).prev_sample
# samples = self.decode_latents(latents, return_tensor=return_tensor)
return latents
def decode_latents(self, latents, normalize=True, return_tensor=False):
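        # Map latents back to pixel space with the VAE, first undoing the scaling/shift
        # normalisation applied when the latents were encoded.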
if isinstance(self.get_model().vae, AutoencoderKL):
latents = latents / self.get_model().vae.config.scaling_factor
if self.get_model().vae.config.shift_factor is not None:
latents = latents + self.get_model().vae.config.shift_factor
latents = latents.to(dtype=torch.float32)
samples = self.get_model().vae.decode(latents).sample
else:
samples = self.get_model().vae.decode(latents)
if normalize:
samples = (samples / 2 + 0.5).clamp(0, 1)
else:
samples = samples.clamp(-1, 1)
if return_tensor:
return samples
samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
samples = numpy_to_pil(samples)
return samples
def prepare_and_encode_inputs(
self,
inputs: List[str | Image.Image],
tokenizer: AutoTokenizer,
do_classifier_free_guidance: bool = False,
):
# pdb.set_trace()
device = self.get_model().device
dtype = self.get_model().dtype
has_image, has_text = False, False
text_prompt, image_prompt = "", []
img_processor = self.get_vision_tower().image_processor
negative_prompt = {}
for x in inputs:
if isinstance(x, str):
has_text = True
text_prompt += x
else:
has_image = True
text_prompt += DEFAULT_IMAGE_TOKEN
image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])
# pdb.set_trace()
if len(image_prompt) == 0:
image_prompt = None
else:
image_prompt = torch.cat(image_prompt)
image_prompt = image_prompt.type(dtype).to(device)
if has_image and not has_text:
prompt = self.encode_images(image_prompt)
# pdb.set_trace()
if do_classifier_free_guidance:
key = "[NULL_IMAGE]"
if key not in negative_prompt:
negative_image = torch.zeros_like(image_prompt)
negative_prompt[key] = self.encode_images(negative_image)
prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
else:
prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer)
if do_classifier_free_guidance:
key = ""
if key not in negative_prompt:
negative_prompt[key] = self.generate_image(text=[""], tokenizer=tokenizer)
prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
gen_pooling = self.get_gen_pooling()
n_query = self.get_n_query()
num_img, _, c = prompt.shape
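        # Optional 2D average pooling of the prompt tokens: the n_query tokens are reshaped
        # to a sqrt(n) x sqrt(n) grid and pooled with the stride parsed from the gen_pooling
        # string (e.g. a setting of the form "pool2d_2"; the exact naming is repo-specific).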
if 'pool2d' in gen_pooling and has_text and not 'early' in gen_pooling:
stride = int(gen_pooling.split('_')[1])
sqrt_n = int(n_query**0.5)
prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
prompt = prompt.reshape(num_img, c, -1).permute(0,2,1)
return prompt
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if images is not None:
inputs['images'] = images
if image_sizes is not None:
inputs['image_sizes'] = image_sizes
return inputs
AutoConfig.register("blip3o_llama", blip3oConfig)
AutoModelForCausalLM.register(blip3oConfig, blip3oLlamaForCausalLM)
from typing import List, Optional, Tuple, Union, Dict
import torch
import torch.nn as nn
from PIL import Image
import torch.nn.functional as F
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from blip3o.model.blip3o_arch import blip3oMetaModel, blip3oMetaForCausalLM
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration
from blip3o.constants import UND_IMAGE_TOKEN_IDX, DEFAULT_IMAGE_TOKEN
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import numpy_to_pil
import numpy as np
from diffusers.models import AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
class blip3oQwenConfig(Qwen2_5_VLConfig):
model_type = "blip3o_qwen"
class blip3oQwenModel(blip3oMetaModel, Qwen2_5_VLModel):
config_class = blip3oQwenConfig
def __init__(self, config: Qwen2_5_VLConfig):
super(blip3oQwenModel, self).__init__(config)
class blip3oQwenForCausalLM(Qwen2_5_VLForConditionalGeneration, blip3oMetaForCausalLM):
config_class = blip3oQwenConfig
def __init__(self, config):
Qwen2_5_VLForConditionalGeneration.__init__(self, config)
config.model_type = "blip3o_qwen"
self.model = blip3oQwenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
ids: Optional[list] = None,
i_s_pos: Optional[list] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
gen_image: Optional[torch.FloatTensor] = None,
und_image: Optional[torch.FloatTensor] = None,
grid_thw: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
latents
) = self.prepare_inputs_labels_for_multimodal(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
gen_image,
und_image,
grid_thw,
i_s_pos,
image_sizes
)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
total_loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
# compute image loss
# target_img_embeds = torch.clone(inputs_embeds.detach())[:,1:,:] # get target image emb
img_loss_funct = torch.nn.MSELoss()
# img_hidden_states = self.get_model().down_projector(hidden_states[:,-self.get_n_query():,:])
img_hidden_states = []
for b in range(hidden_states.shape[0]):
img_hidden_states.append(hidden_states[b,i_s_pos[b]:i_s_pos[b]+64,:])
img_hidden_states = torch.stack(img_hidden_states,dim=0)
img_hidden_states = self.get_model().down_projector(img_hidden_states)
# img_loss = 0.0
if latents is None:
img_loss = img_loss_funct(img_hidden_states, torch.clone(img_hidden_states.detach()))
else:
bsz = latents.shape[0]
# device = latents.device
dtype = latents.dtype
noise = torch.randn_like(latents, device=latents.device)
u = torch.rand(size=(bsz,), device="cpu")
indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long()
timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device)
sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=dtype)
noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
noise_pred = self.get_model().dit(
x=noisy_latents,
timestep=timesteps,
z_latents=self.mask_drop(img_hidden_states),
)
target = noise - latents
img_loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
print(f"img loss {img_loss}")
total_loss = img_loss
return CausalLMOutputWithPast(
loss=total_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if images is not None:
(
inputs,
position_ids,
attention_mask,
_,
inputs_embeds,
img_indicator,
_
) = self.prepare_inputs_labels_for_understanding(
inputs,
position_ids,
attention_mask,
None,
None,
images,
image_sizes=image_sizes
)
else:
inputs_embeds = self.get_model().embed_tokens(inputs)
return super().generate(
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
**kwargs
)
@torch.no_grad()
def generate_image(
self,
text: List[str],
tokenizer: AutoTokenizer,
pixel_values: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.Tensor] = None,
max_var: Optional[float] = None,
# placeholder: str = DEFAULT_IMG_PLACEHOLDER,
):
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler")
N_QUERY = self.get_n_query()
inputs = tokenizer(text, padding="longest", return_tensors="pt")
device = self.get_model().device
attention_mask = inputs.attention_mask.to(device)
input_ids = inputs.input_ids.to(device) # B x N
input_ids = torch.cat([input_ids, torch.tensor([[151665]]).to(device)], dim=1)
# breakpoint()
text_embeds = self.get_model().embed_tokens(input_ids)
latent_queries = self.get_model().latent_queries.repeat(text_embeds.shape[0], 1, 1)
if pixel_values is not None:
und_image_idx = (input_ids == UND_IMAGE_TOKEN_IDX)
pixel_values = pixel_values.type(self.visual.dtype)
und_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
text_embeds[und_image_idx] = und_image_embeds.to(text_embeds.device)[:und_image_idx.sum(), :]
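            # Scatter the visual-encoder features into the text embeddings at the positions
            # of the understanding-image placeholder tokens (UND_IMAGE_TOKEN_IDX).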
text_embeds = torch.cat([text_embeds, latent_queries], dim=1)
attention_mask = torch.cat([attention_mask, torch.ones_like(latent_queries[:, :, 0])], dim=1)
outputs = self.model(
inputs_embeds=text_embeds,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True,
)
hidden_states = outputs.hidden_states[-1][:,-N_QUERY:,:]
img_hidden_states = hidden_states
output_img = self.sample_images(img_hidden_states, scheduler)
output_img = output_img.view(1, 1792, -1).permute(0,2,1).contiguous()
return output_img
def sample_images(
self,
img_hidden_states,
scheduler,
guidance_scale: float = 3.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
num_inference_steps: int = 30,
num_images_per_prompt: int = 1,
return_tensor=False,
**kwargs,
):
device = img_hidden_states.device
dtype = img_hidden_states.dtype
img_hidden_states_null = torch.zeros_like(img_hidden_states, device=device, dtype=dtype)
img_hidden_states_input = torch.cat([img_hidden_states_null, img_hidden_states], 0)
batch_size = img_hidden_states.shape[0]
latent_size = self.get_model().dit.config.input_size
latent_channels = self.get_model().dit.config.in_channels
latents = randn_tensor(
shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size),
generator=generator,
device=device,
dtype=dtype,
)
# set step values
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
scheduler.set_timesteps(num_inference_steps, sigmas=sigmas)
# Repeat z_latents and conditions for each image per prompt
img_hidden_states_input = img_hidden_states_input.repeat_interleave(num_images_per_prompt, dim=0)
for t in scheduler.timesteps:
latent_model_input = latents.repeat(2, 1, 1, 1)
if hasattr(scheduler, "scale_model_input"):
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# predict noise model_output
noise_pred = self.get_model().dit(
x=latent_model_input,
timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latent_model_input.device, torch.long),
z_latents=img_hidden_states_input,
)
# perform guidance
noise_pred_uncond, noise_pred = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
# compute previous image: x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents).prev_sample
# samples = self.decode_latents(latents, return_tensor=return_tensor)
# breakpoint()
return latents
def decode_latents(self, latents, normalize=True, return_tensor=False):
if isinstance(self.get_model().vae, AutoencoderKL):
latents = latents / self.get_model().vae.config.scaling_factor
if self.get_model().vae.config.shift_factor is not None:
latents = latents + self.get_model().vae.config.shift_factor
latents = latents.to(dtype=torch.float32)
samples = self.get_model().vae.decode(latents).sample
else:
samples = self.get_model().vae.decode(latents)
if normalize:
samples = (samples / 2 + 0.5).clamp(0, 1)
else:
samples = samples.clamp(-1, 1)
if return_tensor:
return samples
samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
samples = numpy_to_pil(samples)
return samples
def prepare_and_encode_inputs(
self,
inputs: List[str | Image.Image],
tokenizer: AutoTokenizer,
do_classifier_free_guidance: bool = False,
):
# pdb.set_trace()
device = self.get_model().device
dtype = self.get_model().dtype
has_image, has_text = False, False
text_prompt, image_prompt = "", []
img_processor = self.get_vision_tower().image_processor
negative_prompt = {}
for x in inputs:
if isinstance(x, str):
has_text = True
text_prompt += x
else:
has_image = True
text_prompt += DEFAULT_IMAGE_TOKEN
image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])
# pdb.set_trace()
if len(image_prompt) == 0:
image_prompt = None
else:
image_prompt = torch.cat(image_prompt)
image_prompt = image_prompt.type(dtype).to(device)
if has_image and not has_text:
prompt = self.encode_images(image_prompt)
# pdb.set_trace()
if do_classifier_free_guidance:
key = "[NULL_IMAGE]"
if key not in negative_prompt:
negative_image = torch.zeros_like(image_prompt)
negative_prompt[key] = self.encode_images(negative_image)
prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
else:
            prompt = self.generate_image(text=[text_prompt], pixel_values=image_prompt, tokenizer=tokenizer)
if do_classifier_free_guidance:
key = ""
if key not in negative_prompt:
negative_prompt[key] = self.generate_image(text=[""], tokenizer=tokenizer)
prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
gen_pooling = self.get_gen_pooling()
n_query = self.get_n_query()
num_img, _, c = prompt.shape
if 'pool2d' in gen_pooling and has_text and not 'early' in gen_pooling:
stride = int(gen_pooling.split('_')[1])
sqrt_n = int(n_query**0.5)
prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
prompt = prompt.reshape(num_img, c, -1).permute(0,2,1)
return prompt
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if images is not None:
inputs['images'] = images
if image_sizes is not None:
inputs['image_sizes'] = image_sizes
return inputs
AutoConfig.register("blip3o_qwen", blip3oQwenConfig)
AutoModelForCausalLM.register(blip3oQwenConfig, blip3oQwenForCausalLM)
# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import LuminaFeedForward
from diffusers.models.attention_processor import Attention, LuminaAttnProcessor2_0
from diffusers.models.embeddings import LuminaCombinedTimestepCaptionEmbedding, LuminaPatchEmbed, PixArtAlphaTextProjection
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
from diffusers.utils import is_torch_version, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LuminaNextDiTBlock(nn.Module):
"""
A LuminaNextDiTBlock for LuminaNextDiT2DModel.
Parameters:
dim (`int`): Embedding dimension of the input features.
num_attention_heads (`int`): Number of attention heads.
        num_kv_heads (`int`):
            Number of key/value heads (for grouped-query attention); set to None to use the same number as the
            query heads.
        multiple_of (`int`): The value the hidden dimension of the feed-forward layer is rounded up to a multiple of.
        ffn_dim_multiplier (`float`): Multiplier applied to the feed-forward layer dimension.
        norm_eps (`float`): Epsilon used by the normalization layers.
        qk_norm (`bool`): Whether to apply normalization to the query and key projections.
        cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states.
        norm_elementwise_affine (`bool`, *optional*, defaults to True):
            Whether the normalization layers use learnable elementwise affine parameters.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
num_kv_heads: int,
multiple_of: int,
ffn_dim_multiplier: float,
norm_eps: float,
qk_norm: bool,
cross_attention_dim: int,
norm_elementwise_affine: bool = True,
) -> None:
super().__init__()
self.head_dim = dim // num_attention_heads
self.gate = nn.Parameter(torch.zeros([num_attention_heads]))
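        # Per-head gate for the cross-attention branch, initialised to zero so that
        # cross-attention contributes nothing at the start of training (tanh(0) = 0).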
# Self-attention
self.attn1 = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // num_attention_heads,
qk_norm="layer_norm_across_heads" if qk_norm else None,
heads=num_attention_heads,
kv_heads=num_kv_heads,
eps=1e-5,
bias=False,
out_bias=False,
processor=LuminaAttnProcessor2_0(),
)
self.attn1.to_out = nn.Identity()
# Cross-attention
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
dim_head=dim // num_attention_heads,
qk_norm="layer_norm_across_heads" if qk_norm else None,
heads=num_attention_heads,
kv_heads=num_kv_heads,
eps=1e-5,
bias=False,
out_bias=False,
processor=LuminaAttnProcessor2_0(),
)
self.feed_forward = LuminaFeedForward(
dim=dim,
inner_dim=4 * dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
)
self.norm1 = LuminaRMSNormZero(
embedding_dim=dim,
norm_eps=norm_eps,
norm_elementwise_affine=norm_elementwise_affine,
)
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
image_rotary_emb: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_mask: torch.Tensor,
temb: torch.Tensor,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Perform a forward pass through the LuminaNextDiTBlock.
Parameters:
            hidden_states (`torch.Tensor`): The input hidden_states for the LuminaNextDiTBlock.
            attention_mask (`torch.Tensor`): Attention mask corresponding to the input hidden_states.
            image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies.
            encoder_hidden_states (`torch.Tensor`): The hidden_states of the text prompt, processed by the Gemma encoder.
            encoder_mask (`torch.Tensor`): Attention mask of the text prompt.
            temb (`torch.Tensor`): Timestep embedding combined with the text prompt embedding.
cross_attention_kwargs (`Dict[str, Any]`): kwargs for cross attention.
"""
residual = hidden_states
# Self-attention
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
self_attn_output = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
query_rotary_emb=image_rotary_emb,
key_rotary_emb=image_rotary_emb,
**cross_attention_kwargs,
)
# Cross-attention
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
cross_attn_output = self.attn2(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=encoder_mask,
query_rotary_emb=image_rotary_emb,
key_rotary_emb=None,
**cross_attention_kwargs,
)
cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1)
mixed_attn_output = self_attn_output + cross_attn_output
mixed_attn_output = mixed_attn_output.flatten(-2)
# linear proj
hidden_states = self.attn2.to_out[0](mixed_attn_output)
hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
return hidden_states
class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
"""
LuminaNextDiT: Diffusion model with a Transformer backbone.
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
Parameters:
sample_size (`int`): The width of the latent images. This is fixed during training since
it is used to learn a number of position embeddings.
        patch_size (`int`, *optional*, defaults to 2):
The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
in_channels (`int`, *optional*, defaults to 4):
The number of input channels for the model. Typically, this matches the number of channels in the input
images.
hidden_size (`int`, *optional*, defaults to 4096):
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
hidden representations.
        num_layers (`int`, *optional*, defaults to 32):
The number of layers in the model. This defines the depth of the neural network.
num_attention_heads (`int`, *optional*, defaults to 32):
The number of attention heads in each attention layer. This parameter specifies how many separate attention
mechanisms are used.
num_kv_heads (`int`, *optional*, defaults to 8):
The number of key-value heads in the attention mechanism, if different from the number of attention heads.
If None, it defaults to num_attention_heads.
multiple_of (`int`, *optional*, defaults to 256):
A factor that the hidden size should be a multiple of. This can help optimize certain hardware
configurations.
ffn_dim_multiplier (`float`, *optional*):
A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
the model configuration.
norm_eps (`float`, *optional*, defaults to 1e-5):
A small value added to the denominator for numerical stability in normalization layers.
        learn_sigma (`bool`, *optional*, defaults to True):
            Whether the model also predicts a sigma (variance) channel; if True, the number of output channels is
            doubled (see `out_channels`).
qk_norm (`bool`, *optional*, defaults to True):
Indicates if the queries and keys in the attention mechanism should be normalized.
cross_attention_dim (`int`, *optional*, defaults to 2048):
The dimensionality of the text embeddings. This parameter defines the size of the text representations used
in the model.
scaling_factor (`float`, *optional*, defaults to 1.0):
A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
overall scale of the model's operations.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["LuminaNextDiTBlock"]
@register_to_config
def __init__(
self,
sample_size: int = 128,
patch_size: Optional[int] = 2,
in_channels: Optional[int] = 4,
hidden_size: Optional[int] = 2304,
num_layers: Optional[int] = 32, # 32
num_attention_heads: Optional[int] = 32, # 32
num_kv_heads: Optional[int] = None,
multiple_of: Optional[int] = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: Optional[float] = 1e-5,
learn_sigma: Optional[bool] = True,
qk_norm: Optional[bool] = True,
cross_attention_dim: Optional[int] = 2048,
scaling_factor: Optional[float] = 1.0,
) -> None:
super().__init__()
self.sample_size = sample_size
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.head_dim = hidden_size // num_attention_heads
self.scaling_factor = scaling_factor
self.gradient_checkpointing = False
self.caption_projection = PixArtAlphaTextProjection(in_features=cross_attention_dim, hidden_size=hidden_size)
self.patch_embedder = LuminaPatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True)
self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding(hidden_size=min(hidden_size, 1024), cross_attention_dim=hidden_size)
self.layers = nn.ModuleList(
[
LuminaNextDiTBlock(
hidden_size,
num_attention_heads,
num_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
hidden_size,
)
for _ in range(num_layers)
]
)
self.norm_out = LuminaLayerNormContinuous(
embedding_dim=hidden_size,
conditioning_embedding_dim=min(hidden_size, 1024),
elementwise_affine=False,
eps=1e-6,
bias=True,
out_dim=patch_size * patch_size * self.out_channels,
)
# self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)
assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_mask: torch.Tensor,
image_rotary_emb: torch.Tensor,
cross_attention_kwargs: Dict[str, Any] = None,
return_dict=True,
) -> torch.Tensor:
"""
Forward pass of LuminaNextDiT.
Parameters:
hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D).
encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L).
"""
hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
image_rotary_emb = image_rotary_emb.to(hidden_states.device)
# breakpoint()
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)
encoder_mask = encoder_mask.bool()
for layer in self.layers:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states,
mask,
image_rotary_emb,
encoder_hidden_states,
encoder_mask,
temb,
cross_attention_kwargs,
**ckpt_kwargs,
)
else:
hidden_states = layer(
hidden_states,
mask,
image_rotary_emb,
encoder_hidden_states,
encoder_mask,
temb=temb,
cross_attention_kwargs=cross_attention_kwargs,
)
hidden_states = self.norm_out(hidden_states, temb)
# unpatchify
height_tokens = width_tokens = self.patch_size
height, width = img_size[0]
batch_size = hidden_states.size(0)
sequence_length = (height // height_tokens) * (width // width_tokens)
hidden_states = hidden_states[:, :sequence_length].view(
batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
)
output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
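        # output: (batch_size, out_channels, height, width) after unpatchifying the token sequence.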
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
import argparse
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from blip3o.model.utils import auto_upgrade
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
print("Loading base model")
base = AutoModelForCausalLM.from_pretrained(
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
print("Loading target model")
auto_upgrade(target_model_path)
target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
print("Calculating delta")
for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
if name not in base.state_dict():
assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
continue
if param.data.shape == base.state_dict()[name].shape:
param.data -= base.state_dict()[name]
else:
assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
bparam = base.state_dict()[name]
param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
print("Saving delta")
if hub_repo_id:
kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
else:
kwargs = {}
target.save_pretrained(delta_path, **kwargs)
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
target_tokenizer.save_pretrained(delta_path, **kwargs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--base-model-path", type=str, required=True)
parser.add_argument("--target-model-path", type=str, required=True)
parser.add_argument("--delta-path", type=str, required=True)
parser.add_argument("--hub-repo-id", type=str, default=None)
args = parser.parse_args()
make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
import os
from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
from .imagebind import ImageBindWrapper
from .open_clip_encoder import OpenCLIPVisionTower
from .siglip_encoder import SigLipVisionTower
from .eva_clip.eva_clip_encoder import EvaClipVisionTower
from .dev_eva_clip.eva_vit import EvaViTWrapper
from blip3o.model.nextdit_crossattn import NextDiTCrossAttnConfig, NextDiTCrossAttn
from diffusers.models import AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
def build_vision_tower(vision_tower_cfg, **kwargs):
vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
is_absolute_path_exists = os.path.exists(vision_tower)
use_s2 = getattr(vision_tower_cfg, 's2', False)
if "siglip" in vision_tower:
return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
if "eva" in vision_tower:
return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
if use_s2:
return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
else:
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
raise ValueError(f'Unknown vision tower: {vision_tower}')
def build_gen_vision_tower(vision_tower_cfg, **kwargs):
vision_tower = getattr(vision_tower_cfg, 'gen_vision_tower')
is_absolute_path_exists = os.path.exists(vision_tower)
use_s2 = getattr(vision_tower_cfg, 's2', False)
if "siglip" in vision_tower:
return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
if "eva" in vision_tower:
return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
if use_s2:
return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
else:
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
raise ValueError(f'Unknown vision tower: {vision_tower}')
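# build_dit wires up the generation stack: the NextDiT cross-attention transformer,
# a frozen FLUX.1-dev VAE, and a flow-matching Euler scheduler reused as the training
# noise scheduler.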
def build_dit(vision_tower_cfg, **kwargs):
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
# vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
dit = NextDiTCrossAttn(NextDiTCrossAttnConfig())
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler")
# scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler")
vae.eval()
vae.requires_grad_(False)
return dit, vae, noise_scheduler