# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# Copyright (c) 2024 The Qwen Team and The HuggingFace Inc. team.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.

import math
from dataclasses import dataclass
from typing import Any

import numpy as np
import torch
from torch import nn
from torch.nn.attention.flex_attention import flex_attention
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2Attention,
    Qwen2MLP,
    Qwen2PreTrainedModel,
    Qwen2RMSNorm,
    Qwen2RotaryEmbedding,
)
from transformers.utils import ModelOutput
from vllm.transformers_utils.configs.bagel import BagelConfig
from vllm.vllm_flash_attn import flash_attn_varlen_func

from vllm_omni.diffusion.layers.rope import RotaryEmbedding


def patchify(imgs, p):
    """Split images into flattened, non-overlapping p x p patches.

    imgs: (N, 3, H, W) or (3, H, W)
    Returns: (N, L, patch_size**2 * 3) or (L, patch_size**2 * 3)
    """
    is_batch = imgs.ndim == 4
    if not is_batch:
        imgs = imgs.unsqueeze(0)

    # n: batch, c: channel, h: grid_h, p: patch_h, w: grid_w, q: patch_w
    x = imgs.reshape(imgs.shape[0], 3, imgs.shape[2] // p, p, imgs.shape[3] // p, p)
    # Permute to (n, grid_h, grid_w, c, patch_h, patch_w) to match Conv2d (c, h, w) flattening
    x = torch.einsum("nchpwq->nhwcpq", x)
    x = x.reshape(imgs.shape[0], -1, 3 * p**2)

    if not is_batch:
        x = x.squeeze(0)
    return x


class MLPconnector(nn.Module):
    def __init__(self, input_dim, output_dim, activation="gelu_pytorch_tanh"):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        if activation == "gelu":
            self.act = nn.GELU()
        elif activation == "gelu_pytorch_tanh":
            self.act = nn.GELU(approximate="tanh")
        else:
            self.act = nn.ReLU()
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


torch._dynamo.config.cache_size_limit = 512
torch._dynamo.config.accumulated_cache_size_limit = 4096
flex_attention = torch.compile(flex_attention)


class Qwen2MoTConfig(Qwen2Config):
    """Configuration for the Qwen2MoT (Mixture of Tokens) model.

    This is fundamentally different from Qwen2, hence the distinct name.
    """

    model_type = "qwen2_mot"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        is_causal=True,
        _attn_implementation="eager",
        qk_norm=True,
        layer_module="Qwen2MoTDecoderLayer",
        freeze_und=False,
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            use_sliding_window=use_sliding_window,
            sliding_window=sliding_window,
            max_window_layers=max_window_layers,
            attention_dropout=attention_dropout,
            is_causal=is_causal,
            _attn_implementation=_attn_implementation,
            **kwargs,
        )
        self.qk_norm = qk_norm
        self.layer_module = layer_module
        # Store the flag so it is not silently dropped when the config is serialized.
        self.freeze_und = freeze_und


class NaiveCache:
    def __init__(self, num_layers):
        self.key_cache = {k: None for k in range(num_layers)}
        self.value_cache = {k: None for k in range(num_layers)}

    @property
    def num_layers(self):
        return len(self.key_cache)

    @property
    def seq_lens(self):
        if self.key_cache[0] is not None:
            return self.key_cache[0].shape[0]
        else:
            return 0


@dataclass
class BaseNavitOutputWithPast(ModelOutput):
    packed_query_sequence: torch.FloatTensor = None
    past_key_values: NaiveCache | None = None


class PackedAttentionMoT(Qwen2Attention):
    def __init__(self, config, layer_idx: int | None = None):
        super().__init__(config, layer_idx)
        self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        head_dim = self.head_dim

        self.q_proj_moe_gen = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj_moe_gen = nn.Linear(config.hidden_size, config.num_key_value_heads * head_dim, bias=True)
        self.v_proj_moe_gen = nn.Linear(config.hidden_size, config.num_key_value_heads * head_dim, bias=True)
        self.o_proj_moe_gen = nn.Linear(config.num_attention_heads * head_dim, config.hidden_size, bias=False)

        self.rotary_op = RotaryEmbedding(is_neox_style=True)

    def forward(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: NaiveCache | None = None,
        key_values_lens: torch.Tensor | None = None,
        packed_key_value_indexes: torch.Tensor | None = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ):
        if mode == "und":
            packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
            packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
            packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
            packed_query_states = self.q_norm(packed_query_states)
            packed_key_states = self.k_norm(packed_key_states)
        elif mode == "gen":
            packed_query_sequence = packed_query_sequence.to(torch.bfloat16)
            packed_query_states = packed_query_sequence.new_zeros(
                (packed_query_sequence.shape[0], self.num_heads * self.head_dim)
            )
            packed_key_states = packed_query_sequence.new_zeros(
                (packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim)
            )
            packed_value_states = packed_query_sequence.new_zeros(
                (packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim)
            )

            packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
            packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]

            # Route text tokens through the shared projections and VAE tokens
            # through the generation-expert projections.
            packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence)
            packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence)
            packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence)
            packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence)
            packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence)
            packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence)

            packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
            packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
            packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)

            packed_query_states = packed_query_states.to(torch.float32)
            packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes])
            packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen(
                packed_query_states[packed_vae_token_indexes]
            )

            packed_key_states = packed_key_states.to(torch.float32)
            packed_key_states[packed_text_indexes] = self.k_norm(packed_key_states[packed_text_indexes])
            packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen(
                packed_key_states[packed_vae_token_indexes]
            )

        cos, sin = [x[..., : self.head_dim // 2] for x in packed_query_position_embeddings]
        packed_query_states = self.rotary_op(packed_query_states.to(cos.dtype).unsqueeze(0), cos, sin).squeeze(0)
        packed_key_states = self.rotary_op(packed_key_states.to(cos.dtype).unsqueeze(0), cos, sin).squeeze(0)

        packed_query_states = packed_query_states.to(torch.bfloat16)
        packed_key_states = packed_key_states.to(torch.bfloat16)
        packed_value_states = packed_value_states.to(torch.bfloat16)

        if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
            past_key_states = past_key_values.key_cache[self.layer_idx]
            past_value_states = past_key_values.value_cache[self.layer_idx]
            seqlens = sum(query_lens) + sum(key_values_lens)
            merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
            merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
            merged_key_states[packed_query_indexes] = packed_key_states
            merged_key_states[packed_key_value_indexes] = past_key_states
            merged_value_states[packed_query_indexes] = packed_value_states
            merged_value_states[packed_key_value_indexes] = past_value_states
            key_values_lens = key_values_lens + query_lens
        else:
            merged_key_states = packed_key_states
            merged_value_states = packed_value_states
            key_values_lens = query_lens
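
        # The packed (NaViT-style) layout feeds every sequence in the batch as one
        # long row; flash_attn_varlen_func recovers per-sequence boundaries from
        # cumulative-length offsets. A minimal sketch with illustrative values
        # (not taken from this file):
        #   query_lens   = tensor([3, 5])
        #   cu_seqlens_q = tensor([0, 3, 8])   # cumsum padded with a leading 0
        #   rows 0:3 belong to sequence 0, rows 3:8 to sequence 1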
        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
        packed_attn_output = flash_attn_varlen_func(
            q=packed_query_states,
            k=merged_key_states,
            v=merged_value_states,
            cu_seqlens_q=cu_seqlens_q.to(torch.int32),
            cu_seqlens_k=cu_seqlens_k.to(torch.int32),
            max_seqlen_q=max(query_lens).item(),
            max_seqlen_k=max(key_values_lens).item(),
            causal=is_causal,
        )
        packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)

        if mode == "und":
            packed_attn_output = self.o_proj(packed_attn_output)
        elif mode == "gen":
            packed_attn_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes])
            packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen(
                packed_attn_output[packed_vae_token_indexes]
            )

        if update_past_key_values:
            past_key_values.key_cache[self.layer_idx] = merged_key_states
            past_key_values.value_cache[self.layer_idx] = merged_value_states

        return packed_attn_output, past_key_values


class Qwen2MoTDecoderLayer(nn.Module):
    def __init__(
        self,
        config,
        layer_idx: int | None = None,
        attn_module: type[Qwen2Attention] = PackedAttentionMoT,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = attn_module(config, layer_idx)

        self.mlp = Qwen2MLP(config)
        self.mlp_moe_gen = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.input_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        packed_query_sequence: torch.Tensor | None = None,
        query_lens: torch.Tensor = None,
        packed_query_position_embeddings: torch.Tensor = None,
        packed_query_indexes: torch.Tensor = None,
        past_key_values: NaiveCache | None = None,
        key_values_lens: torch.Tensor | None = None,
        packed_key_value_indexes: torch.Tensor | None = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        if packed_query_sequence is None:
            packed_query_sequence = hidden_states

        residual = packed_query_sequence
        if mode == "und":
            packed_query_sequence = self.input_layernorm(packed_query_sequence)
        elif mode == "gen":
            packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
            packed_query_sequence_[packed_text_indexes] = self.input_layernorm(
                packed_query_sequence[packed_text_indexes]
            )
            packed_query_sequence_[packed_vae_token_indexes] = self.input_layernorm_moe_gen(
                packed_query_sequence[packed_vae_token_indexes]
            )
            packed_query_sequence = packed_query_sequence_

        # Self Attention
        packed_query_sequence, past_key_values = self.self_attn(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_embeddings=packed_query_position_embeddings,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
            mode=mode,
            packed_vae_token_indexes=packed_vae_token_indexes,
            packed_text_indexes=packed_text_indexes,
        )
        packed_query_sequence = residual + packed_query_sequence

        # Fully Connected
        residual = packed_query_sequence
        if mode == "und":
            packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
            packed_query_sequence = self.mlp(packed_query_sequence)
        elif mode == "gen":
            packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
            packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]

            packed_text_query_sequence = self.post_attention_layernorm(packed_text_query_sequence).to(torch.bfloat16)
            packed_vae_query_sequence = self.post_attention_layernorm_moe_gen(packed_vae_query_sequence).to(
                torch.bfloat16
            )

            packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16)
            packed_query_sequence_[packed_text_indexes] = self.mlp(packed_text_query_sequence)
            packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_vae_query_sequence)
            packed_query_sequence = packed_query_sequence_

        packed_query_sequence = residual + packed_query_sequence
        return packed_query_sequence, past_key_values


class Qwen2MoTModel(Qwen2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.use_moe = "Mo" in config.layer_module

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [
                Qwen2MoTDecoderLayer(config, layer_idx, attn_module=PackedAttentionMoT)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        if self.use_moe:
            self.norm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_ids: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: NaiveCache | None = None,
        key_values_lens: torch.Tensor | None = None,
        packed_key_value_indexes: torch.Tensor | None = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        # create position embeddings to be shared across the decoder layers
        cos, sin = self.rotary_emb(packed_query_sequence, packed_query_position_ids.unsqueeze(0))
        cos = cos.squeeze(0)
        sin = sin.squeeze(0)
        packed_query_position_embeddings = (cos, sin)

        extra_inputs = {}
        if self.use_moe:
            extra_inputs.update(mode=mode)
            if mode == "gen":
                assert packed_vae_token_indexes is not None
                assert packed_text_indexes is not None
                extra_inputs.update(
                    packed_vae_token_indexes=packed_vae_token_indexes,
                    packed_text_indexes=packed_text_indexes,
                )

        for layer_idx, decoder_layer in enumerate(self.layers):
            packed_query_sequence, past_key_values = decoder_layer(
                hidden_states=packed_query_sequence,
                encoder_hidden_states=None,
                query_lens=query_lens,
                packed_query_position_embeddings=packed_query_position_embeddings,
                packed_query_indexes=packed_query_indexes,
                past_key_values=past_key_values,
                key_values_lens=key_values_lens,
                packed_key_value_indexes=packed_key_value_indexes,
                update_past_key_values=update_past_key_values,
                is_causal=is_causal,
                **extra_inputs,
            )

        if self.use_moe:
            if mode == "und":
                packed_query_sequence = self.norm(packed_query_sequence)
            elif mode == "gen":
                packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
                packed_query_sequence_[packed_text_indexes] = self.norm(packed_query_sequence[packed_text_indexes])
                packed_query_sequence_[packed_vae_token_indexes] = self.norm_moe_gen(
                    packed_query_sequence[packed_vae_token_indexes]
                )
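                # Text and VAE tokens use different final RMSNorms, so each group is
                # normalized separately and written back into a scratch tensor at its
                # original packed positions, keeping a single packed sequence.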
                packed_query_sequence = packed_query_sequence_
        else:
            packed_query_sequence = self.norm(packed_query_sequence)

        return BaseNavitOutputWithPast(
            packed_query_sequence=packed_query_sequence,
            past_key_values=past_key_values,
        )


class Qwen2MoTForCausalLM(Qwen2PreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2MoTModel(config)
        self.vocab_size = config.vocab_size

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_ids: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: NaiveCache | None = None,
        key_values_lens: torch.Tensor | None = None,
        packed_key_value_indexes: torch.Tensor | None = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        outputs = self.model(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_ids=packed_query_position_ids,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
            mode=mode,
            packed_vae_token_indexes=packed_vae_token_indexes,
            packed_text_indexes=packed_text_indexes,
        )
        return outputs


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
            device=t.device
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class PositionEmbedding(nn.Module):
    def __init__(self, max_num_patch_per_side, hidden_size):
        super().__init__()
        self.max_num_patch_per_side = max_num_patch_per_side
        self.hidden_size = hidden_size
        self.pos_embed = nn.Parameter(torch.zeros(max_num_patch_per_side**2, hidden_size), requires_grad=False)
        self._init_weights()

    def _init_weights(self):
        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float())

    def forward(self, position_ids):
        return self.pos_embed[position_ids]


def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
    num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
    coords_h = torch.arange(0, num_patches_h)
    coords_w = torch.arange(0, num_patches_w)
    pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
    return pos_ids


class Bagel(torch.nn.Module):
    config_class = BagelConfig
    base_model_prefix = "bagel"

    def __init__(self, language_model, vit_model, config: BagelConfig):
        super().__init__()
        self.language_model = language_model
        self.hidden_size = config.llm_config.hidden_size
        self.use_moe = "Mo" in config.llm_config.layer_module
        self.num_heads = config.llm_config.num_attention_heads

        if config.visual_gen:
            self.latent_patch_size = config.latent_patch_size
            self.timestep_shift = config.timestep_shift
            self.latent_downsample = config.vae_config.downsample * config.latent_patch_size
            self.max_latent_size = config.max_latent_size
            self.latent_channel = config.vae_config.z_channels
            self.patch_latent_dim = self.latent_patch_size**2 * self.latent_channel
            self.time_embedder = TimestepEmbedder(self.hidden_size)
            self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
            self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
            self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)

        if config.visual_und:
            self.vit_model = vit_model
            self.vit_patch_size = config.vit_config.patch_size
            self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side
            self.vit_hidden_size = config.vit_config.hidden_size
            self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act)
            self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size)

        self.get_flattened_position_ids = get_flattened_position_ids_extrapolate

        self.config = config
        self._init_weights()

    def _init_weights(self):
        if self.config.visual_gen:
            nn.init.constant_(self.llm2vae.weight, 0)
            nn.init.constant_(self.llm2vae.bias, 0)

    def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids):
        packed_text_ids = list()
        packed_text_position_ids = list()
        text_token_lens = list()
        packed_text_indexes = list()
        packed_key_value_indexes = list()

        curr = 0
        newlens, new_rope = list(), list()
        for prompt, curr_kvlen, curr_position_id in zip(prompts, curr_kvlens, curr_rope):
            packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
            curr += curr_kvlen

            text_ids = tokenizer.encode(prompt)
            text_ids = [new_token_ids["bos_token_id"]] + text_ids + [new_token_ids["eos_token_id"]]
            text_token_lens.append(len(text_ids))
            packed_text_ids.extend(text_ids)
            packed_text_position_ids.extend(range(curr_position_id, curr_position_id + len(text_ids)))
            packed_text_indexes.extend(range(curr, curr + len(text_ids)))
            newlens.append(curr_kvlen + len(text_ids))
            new_rope.append(curr_position_id + len(text_ids))
            curr += len(text_ids)

        generation_input = {
            "text_token_lens": torch.tensor(text_token_lens, dtype=torch.int),
            "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
            "packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long),
            "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
            "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
            "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
        }

        return generation_input, newlens, new_rope

    def forward_cache_update_text(
        self,
        past_key_values: NaiveCache,
        packed_text_ids: torch.IntTensor,
        packed_text_position_ids: torch.LongTensor,
        text_token_lens: torch.LongTensor,
        packed_text_indexes: torch.LongTensor,
        packed_key_value_indexes: torch.LongTensor,
        key_values_lens: torch.IntTensor,
    ):
        packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)

        extra_inputs = {}
        if self.use_moe:
            extra_inputs = {"mode": "und"}

        output = self.language_model.forward(
            packed_query_sequence=packed_text_embedding,
            query_lens=text_token_lens,
            packed_query_position_ids=packed_text_position_ids,
            packed_query_indexes=packed_text_indexes,
            past_key_values=past_key_values,
            packed_key_value_indexes=packed_key_value_indexes,
            key_values_lens=key_values_lens,
            update_past_key_values=True,
            is_causal=True,
            **extra_inputs,
        )
        past_key_values = output.past_key_values

        return past_key_values

    def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0):
        patchified_vae_latent_shapes, packed_vae_position_ids = list(), list()
        packed_vae_token_indexes = list()
        packed_text_ids, packed_text_indexes = list(), list()
        packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
        packed_key_value_indexes = list()

        _curr = curr = 0
        vae_image_tensors = list()
        newlens, new_rope = list(), list()
        for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
            packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
            curr += curr_kvlen

            packed_text_ids.append(new_token_ids["start_of_image"])
            packed_text_indexes.append(_curr)
            packed_indexes.append(curr)
            curr += 1
            _curr += 1

            image_tensor = transforms(image)
            vae_image_tensors.append(image_tensor)
            vae_position_ids = self.get_flattened_position_ids(
                image_tensor.size(1),
                image_tensor.size(2),
                self.latent_downsample,
                max_num_patches_per_side=self.max_latent_size,
            )
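            # get_flattened_position_ids_extrapolate indexes each latent patch on a
            # fixed max_num_patches_per_side grid as row * max_side + col. Illustrative
            # sketch (values assumed, not from this file): a 2 x 3 patch grid with
            # max_num_patches_per_side=64 yields ids [0, 1, 2, 64, 65, 66].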
            packed_vae_position_ids.append(vae_position_ids)

            H, W = image_tensor.shape[1:]
            h = H // self.latent_downsample
            w = W // self.latent_downsample
            patchified_vae_latent_shapes.append((h, w))

            num_img_tokens = w * h
            packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens))
            packed_indexes.extend(range(curr, curr + num_img_tokens))
            curr += num_img_tokens
            _curr += num_img_tokens

            packed_text_ids.append(new_token_ids["end_of_image"])
            packed_text_indexes.append(_curr)
            packed_indexes.append(curr)
            curr += 1
            _curr += 1

            packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
            packed_seqlens.append(num_img_tokens + 2)
            newlens.append(curr_kvlen + num_img_tokens + 2)
            new_rope.append(curr_position_id + 1)

        image_sizes = [item.shape for item in vae_image_tensors]
        max_image_size = [max(item) for item in list(zip(*image_sizes))]
        padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size))
        for i, image_tensor in enumerate(vae_image_tensors):
            padded_images[i, :, : image_tensor.shape[1], : image_tensor.shape[2]] = image_tensor

        generation_input = {
            "padded_images": padded_images,
            "patchified_vae_latent_shapes": patchified_vae_latent_shapes,
            "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
            "packed_timesteps": torch.tensor([timestep]),
            "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
            "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
            "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
            "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
            "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
            "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
            "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
            "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
        }

        return generation_input, newlens, new_rope

    def forward_cache_update_vae(
        self,
        vae_model,
        past_key_values: NaiveCache,
        padded_images: torch.Tensor,
        patchified_vae_latent_shapes: list,
        packed_vae_position_ids: torch.LongTensor,
        packed_timesteps: torch.Tensor,
        packed_vae_token_indexes: torch.LongTensor,
        packed_text_ids: torch.LongTensor,
        packed_text_indexes: torch.LongTensor,
        packed_position_ids: torch.LongTensor,
        packed_seqlens: torch.IntTensor,
        packed_indexes: torch.LongTensor,
        key_values_lens: torch.IntTensor,
        packed_key_value_indexes: torch.Tensor,
    ):
        packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
        packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
        packed_sequence[packed_text_indexes] = packed_text_embedding

        padded_latent = vae_model.encode(padded_images)

        p = self.latent_patch_size
        packed_latent = list()
        for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
            latent = latent[:, : h * p, : w * p].reshape(self.latent_channel, h, p, w, p)
            latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
            packed_latent.append(latent)
        packed_latent = torch.cat(packed_latent, dim=0)
        packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
        packed_timestep_embeds = self.time_embedder(packed_timesteps)
        packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed
        if packed_latent.dtype != packed_sequence.dtype:
            packed_latent = packed_latent.to(packed_sequence.dtype)
        packed_sequence[packed_vae_token_indexes] = packed_latent

        extra_inputs = {}
        if self.use_moe:
            extra_inputs = {
                "mode": "gen",
                "packed_vae_token_indexes": packed_vae_token_indexes,
                "packed_text_indexes": packed_text_indexes,
            }

        output = self.language_model.forward(
            packed_query_sequence=packed_sequence,
            query_lens=packed_seqlens,
            packed_query_position_ids=packed_position_ids,
            packed_query_indexes=packed_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=True,
            is_causal=False,
            **extra_inputs,
        )
        past_key_values = output.past_key_values

        return past_key_values

    def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids):
        packed_vit_token_indexes = list()
        vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list()
        packed_text_ids, packed_text_indexes = list(), list()
        packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
        packed_key_value_indexes = list()

        _curr = curr = 0
        newlens, new_rope = list(), list()
        for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
            packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
            curr += curr_kvlen

            packed_text_ids.append(new_token_ids["start_of_image"])
            packed_text_indexes.append(_curr)
            packed_indexes.append(curr)
            curr += 1
            _curr += 1

            image_tensor = transforms(image)
            vit_position_ids = self.get_flattened_position_ids(
                image_tensor.size(1),
                image_tensor.size(2),
                self.vit_patch_size,
                max_num_patches_per_side=self.vit_max_num_patch_per_side,
            )
            vit_tokens = patchify(image_tensor, self.vit_patch_size)
            packed_vit_tokens.append(vit_tokens)
            num_img_tokens = vit_tokens.shape[0]
            packed_vit_position_ids.append(vit_position_ids)
            vit_token_seqlens.append(num_img_tokens)
            packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens))
            packed_indexes.extend(range(curr, curr + num_img_tokens))
            curr += num_img_tokens
            _curr += num_img_tokens

            packed_text_ids.append(new_token_ids["end_of_image"])
            packed_text_indexes.append(_curr)
            packed_indexes.append(curr)
            curr += 1
            _curr += 1

            packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
            packed_seqlens.append(num_img_tokens + 2)
            newlens.append(curr_kvlen + num_img_tokens + 2)
            new_rope.append(curr_position_id + 1)

        generation_input = {
            "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
            "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
            "vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int),
            "packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0),
            "packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0),
            "packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long),
            "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
            "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
            "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
            "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
            "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
        }

        return generation_input, newlens, new_rope

    def forward_cache_update_vit(
        self,
        past_key_values: NaiveCache,
        packed_text_ids: torch.LongTensor,
        packed_text_indexes: torch.LongTensor,
        packed_vit_tokens: torch.Tensor,
        packed_vit_token_indexes: torch.LongTensor,
        packed_vit_position_ids: torch.LongTensor,
        vit_token_seqlens: torch.IntTensor,
        packed_position_ids: torch.LongTensor,
        packed_seqlens: torch.IntTensor,
        packed_indexes: torch.LongTensor,
        packed_key_value_indexes: torch.LongTensor,
        key_values_lens: torch.IntTensor,
    ):
        packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
        packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
        packed_sequence[packed_text_indexes] = packed_text_embedding

        cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
        cu_seqlens = cu_seqlens.to(torch.int32)
        max_seqlen = torch.max(vit_token_seqlens).item()
        packed_vit_token_embed = self.vit_model(
            packed_pixel_values=packed_vit_tokens,
            packed_flattened_position_ids=packed_vit_position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        packed_vit_token_embed = self.connector(packed_vit_token_embed)
        pos_emb = self.vit_pos_embed(packed_vit_position_ids)
        packed_vit_token_embed = packed_vit_token_embed + pos_emb
        if packed_vit_token_embed.dtype != packed_sequence.dtype:
            packed_vit_token_embed = packed_vit_token_embed.to(packed_sequence.dtype)
        packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed

        extra_inputs = {}
        if self.use_moe:
            extra_inputs = {"mode": "und"}

        output = self.language_model.forward(
            packed_query_sequence=packed_sequence,
            query_lens=packed_seqlens,
            packed_query_position_ids=packed_position_ids,
            packed_query_indexes=packed_indexes,
            past_key_values=past_key_values,
            packed_key_value_indexes=packed_key_value_indexes,
            key_values_lens=key_values_lens,
            update_past_key_values=True,
            is_causal=False,
            **extra_inputs,
        )
        past_key_values = output.past_key_values

        return past_key_values

    def prepare_input(self, curr_kvlens, curr_rope, image_sizes, new_token_ids=None):
        packed_text_ids, packed_text_indexes = list(), list()
        packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list()
        packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list()
        packed_key_value_indexes = list()

        query_curr = curr = 0
        for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
            packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
            curr += curr_kvlen

            packed_text_ids.append(new_token_ids["start_of_image"])
            packed_text_indexes.append(query_curr)
            packed_indexes.append(curr)
            curr += 1
            query_curr += 1

            vae_position_ids = self.get_flattened_position_ids(
                H, W, self.latent_downsample, max_num_patches_per_side=self.max_latent_size
            )
            packed_vae_position_ids.append(vae_position_ids)

            h, w = H // self.latent_downsample, W // self.latent_downsample
            num_image_tokens = h * w
            packed_init_noises.append(torch.randn(num_image_tokens, self.latent_channel * self.latent_patch_size**2))
            packed_vae_token_indexes.extend(range(query_curr, query_curr + num_image_tokens))
            packed_seqlens.append(num_image_tokens + 2)
            packed_indexes.extend(range(curr, curr + num_image_tokens))
            curr += num_image_tokens
            query_curr += num_image_tokens

            packed_text_ids.append(new_token_ids["end_of_image"])
            packed_text_indexes.append(query_curr)
            packed_indexes.append(curr)
            curr += 1
            query_curr += 1

            packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))

        # Construct Output
        generation_input = {
            "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
            "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
            "packed_init_noises": torch.cat(packed_init_noises, dim=0),
            "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
            "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
            "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
            "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
            "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
            "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
            "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
        }

        return generation_input

    def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids):
        return self.prepare_input(curr_kvlens, curr_rope, image_sizes, new_token_ids)

    def generate_image(
        self,
        packed_text_ids: torch.LongTensor,
        packed_text_indexes: torch.LongTensor,
        packed_init_noises: torch.Tensor,
        packed_vae_position_ids: torch.LongTensor,
        packed_vae_token_indexes: torch.LongTensor,
        packed_seqlens: torch.IntTensor,
        packed_position_ids: torch.LongTensor,
        packed_indexes: torch.LongTensor,
        past_key_values: NaiveCache,
        key_values_lens: torch.IntTensor,
        packed_key_value_indexes: torch.LongTensor,
        num_timesteps: int = 24,
        timestep_shift: float = 1.0,
    ):
        model_pred_cache_dic, model_pred_current = None, None
        model_pred_text_cache_dic, model_pred_text_current = None, None
        model_pred_img_cache_dic, model_pred_img_current = None, None

        x_t = packed_init_noises

        timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device)
        timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)
        dts = timesteps[:-1] - timesteps[1:]
        timesteps = timesteps[:-1]

        for i, t in enumerate(timesteps):
            timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
            v_t = self._forward_flow(
                x_t=x_t,
                timestep=timestep,
                packed_vae_token_indexes=packed_vae_token_indexes,
                packed_vae_position_ids=packed_vae_position_ids,
                packed_text_ids=packed_text_ids,
                packed_text_indexes=packed_text_indexes,
                packed_position_ids=packed_position_ids,
                packed_indexes=packed_indexes,
                packed_seqlens=packed_seqlens,
                key_values_lens=key_values_lens,
                past_key_values=past_key_values,
                packed_key_value_indexes=packed_key_value_indexes,
                # cache
                model_pred_cache_dic=model_pred_cache_dic,
                model_pred_current=model_pred_current,
                model_pred_text_cache_dic=model_pred_text_cache_dic,
                model_pred_text_current=model_pred_text_current,
                model_pred_img_cache_dic=model_pred_img_cache_dic,
                model_pred_img_current=model_pred_img_current,
            )

            x_t = x_t - v_t.to(x_t.device) * dts[i]  # velocity pointing from data to noise

        unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
        return unpacked_latent

    def _forward_flow(
        self,
        x_t: torch.Tensor,
        timestep: torch.LongTensor,
        packed_vae_token_indexes: torch.LongTensor,
        packed_vae_position_ids: torch.LongTensor,
        packed_text_ids: torch.LongTensor,
        packed_text_indexes: torch.LongTensor,
        packed_indexes: torch.LongTensor,
        packed_position_ids: torch.LongTensor,
        packed_seqlens: torch.IntTensor,
        key_values_lens: torch.IntTensor,
        past_key_values: NaiveCache,
        packed_key_value_indexes: torch.LongTensor,
        # cache
        model_pred_cache_dic: dict[str, Any] | None = None,
        model_pred_current: int | None = None,
        model_pred_text_cache_dic: dict[str, Any] | None = None,
        model_pred_text_current: int | None = None,
        model_pred_img_cache_dic: dict[str, Any] | None = None,
        model_pred_img_current: int | None = None,
    ):
        packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
        packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
        packed_sequence[packed_text_indexes] = packed_text_embedding

        assert timestep.unique().shape[0] == 1
        packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
        packed_timestep_embeds = self.time_embedder(timestep)
        x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
        if x_t.dtype != packed_sequence.dtype:
            x_t = x_t.to(packed_sequence.dtype)
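
        # The noised latent tokens are scattered into the packed sequence alongside
        # the text embeddings; the LLM then predicts the flow-matching velocity,
        # which is read back out at the same VAE-token positions below.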
        packed_sequence[packed_vae_token_indexes] = x_t

        extra_inputs = {}
        if self.use_moe:
            extra_inputs = {
                "mode": "gen",
                "packed_vae_token_indexes": packed_vae_token_indexes,
                "packed_text_indexes": packed_text_indexes,
            }

        output = self.language_model.forward(
            packed_query_sequence=packed_sequence,
            query_lens=packed_seqlens,
            packed_query_position_ids=packed_position_ids,
            packed_query_indexes=packed_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=False,
            is_causal=False,
            **extra_inputs,
        )
        v_t = self.llm2vae(output.packed_query_sequence)
        v_t = v_t[packed_vae_token_indexes]

        return v_t
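

# A minimal, self-contained sketch (not part of the model) of the shifted
# flow-matching schedule used in `Bagel.generate_image`: a uniform [1, 0] grid is
# warped by `timestep_shift`, so step sizes are smaller near t = 1 (high noise)
# when shift > 1. The constants below are illustrative assumptions, not values
# taken from this file.
if __name__ == "__main__":
    num_timesteps, timestep_shift = 5, 3.0
    timesteps = torch.linspace(1, 0, num_timesteps)
    shifted = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)
    dts = shifted[:-1] - shifted[1:]
    print("shifted timesteps:", shifted.tolist())
    print("per-step dt:", dts.tolist())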