import gc
import math
import os
from copy import deepcopy
from typing import Optional

import safetensors.torch
import torch
from PIL import Image
from torch.nn import functional as F
from lightx2v.models.networks.bagel.data_utils import add_special_tokens
from lightx2v.models.networks.bagel.infer.post_infer import BagelPostInfer
from lightx2v.models.networks.bagel.infer.pre_infer import BagelPreInfer
from lightx2v.models.networks.bagel.infer.transformer_infer import BagelTransformerInfer
from lightx2v.models.networks.bagel.model_io import BagelInputs, NaiveCache, cache_init
from lightx2v.models.networks.bagel.modeling_utils import PositionEmbedding
from lightx2v.models.networks.bagel.tokenization_qwen2 import Qwen2Tokenizer
from lightx2v.models.networks.bagel.weights.post_weights import Qwen2PostWeights
from lightx2v.models.networks.bagel.weights.pre_weights import Qwen2PreWeights
from lightx2v.models.networks.bagel.weights.transformer_weights import Qwen2TransformerWeights
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *
VLM_THINK_SYSTEM_PROMPT = """You should first think about the reasoning process in the mind and then provide the user with the answer.
The reasoning process is enclosed within tags, i.e. reasoning process here answer here"""
GEN_THINK_SYSTEM_PROMPT = """You should first think about the planning process in the mind and then generate the image.
The planning process is enclosed within tags, i.e. planning process here image here"""
class BagelModel:
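    """Bagel model wrapper that packs text tokens and VAE latent tokens into a single
    flat sequence and runs them through pre-infer, transformer-infer and post-infer
    stages, sharing autoregressive state through a NaiveCache KV cache."""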
pre_weight_class = Qwen2PreWeights
transformer_weight_class = Qwen2TransformerWeights
post_weight_class = Qwen2PostWeights
def __init__(self, config):
self.config = config
self.model_path = config["model_path"]
# init llm config
llm_config = self.config["llm_config"]
with self.config.temporarily_unlocked():
llm_config.update(self.config["llm_config_update"])
self.llm_config = llm_config
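        # Layer module names containing "Mo" select the mixture-of-experts (MoT) code path.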
self.use_moe = "Mo" in self.llm_config["layer_module"]
self.num_heads = self.llm_config["num_attention_heads"]
self.hidden_size = self.llm_config["hidden_size"]
self.think = config.get("think", False)
self.understanding_output = config.get("understanding_output", False)
self.inference_hyper = config["inference_hyper"]
self.do_sample = config.get("do_sample", False)
self.text_temperature = config.get("text_temperature", 0.3)
self.max_think_token_n = config.get("max_think_token_n", 1000)
self.enable_taylorseer = False
self.cpu_offload = config.get("cpu_offload", False)
self.offload_granularity = self.config.get("offload_granularity", "block")
self.device = torch.device("cpu") if self.cpu_offload else torch.device(AI_DEVICE)
self._init_infer_class()
self._init_weights()
self._init_infer()
self._init_modules()
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.pre_infer.set_scheduler(scheduler)
self.transformer_infer.set_scheduler(scheduler)
self.post_infer.set_scheduler(scheduler)
def _init_infer_class(self):
self.pre_infer_class = BagelPreInfer
self.transformer_infer_class = BagelTransformerInfer
self.post_infer_class = BagelPostInfer
def _apply_weights(self, weight_dict=None):
if weight_dict is not None:
self.original_weight_dict = weight_dict
del weight_dict
gc.collect()
# Load weights into containers
self.pre_weight.load(self.original_weight_dict)
self.transformer_weights.load(self.original_weight_dict)
self.post_weight.load(self.original_weight_dict)
del self.original_weight_dict
torch.cuda.empty_cache()
gc.collect()
def _init_weights(self):
self.pre_weight = self.pre_weight_class(self.config)
self.transformer_weights = self.transformer_weight_class(self.config, self.llm_config)
self.post_weight = self.post_weight_class(self.config)
weight_dict = safetensors.torch.load_file(os.path.join(self.config["model_path"], "ema.safetensors"), device=AI_DEVICE)
self._apply_weights(weight_dict)
def _init_infer(self):
self.transformer_infer = self.transformer_infer_class(self.config, self.llm_config)
self.pre_infer = self.pre_infer_class(self.config, self.llm_config)
self.post_infer = self.post_infer_class(self.config, self.llm_config)
def _init_modules(self):
tokenizer = Qwen2Tokenizer.from_pretrained(self.model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
self.tokenizer = tokenizer
self.new_token_ids = new_token_ids
if self.config.visual_gen:
self.latent_patch_size = self.config.latent_patch_size
self.timestep_shift = self.config.timestep_shift
self.latent_downsample = self.config.vae_config["downsample"] * self.config.latent_patch_size
self.latent_channel = self.config.vae_config["z_channels"]
self.patch_latent_dim = self.latent_patch_size**2 * self.latent_channel
self.max_latent_size = self.config["max_latent_size_update"]
self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)
self.frequency_embedding_size = 256
def init_gen_context(self):
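        # Fresh generation context: zero cached tokens, rope position 0, and an
        # empty per-layer KV cache.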
gen_context = {
"kv_lens": [0],
"ropes": [0],
"past_key_values": NaiveCache(self.llm_config.num_hidden_layers),
}
return gen_context
def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids):
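        # Tokenize each prompt, wrap it with BOS/EOS, and pack all prompts into flat
        # id/position/index tensors laid out after the existing KV-cache entries.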
packed_text_ids = list()
packed_text_position_ids = list()
text_token_lens = list()
packed_text_indexes = list()
packed_key_value_indexes = list()
curr = 0
newlens, new_rope = list(), list()
for prompt, curr_kvlen, curr_position_id in zip(prompts, curr_kvlens, curr_rope):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen
text_ids = tokenizer.encode(prompt)
text_ids = [new_token_ids["bos_token_id"]] + text_ids + [new_token_ids["eos_token_id"]]
text_token_lens.append(len(text_ids))
packed_text_ids.extend(text_ids)
packed_text_position_ids.extend(range(curr_position_id, curr_position_id + len(text_ids)))
packed_text_indexes.extend(range(curr, curr + len(text_ids)))
newlens.append(curr_kvlen + len(text_ids))
new_rope.append(curr_position_id + len(text_ids))
curr += len(text_ids)
generation_input = {
"text_token_lens": torch.tensor(text_token_lens, dtype=torch.int),
"packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
"packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long),
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
"packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
}
return generation_input, newlens, new_rope
def forward_inference(
self,
packed_query_sequence: torch.Tensor,
query_lens: torch.Tensor,
packed_query_position_ids: torch.Tensor,
packed_query_indexes: torch.Tensor,
past_key_values: Optional[NaiveCache] = None,
key_values_lens: Optional[torch.Tensor] = None,
packed_key_value_indexes: Optional[torch.Tensor] = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
):
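        # Pre-infer produces position embeddings for the packed queries; the transformer
        # stage attends over the packed sequence together with the KV cache, and
        # post-infer applies the final output transform for the given mode.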
packed_query_position_embeddings = self.pre_infer.infer(self.pre_weight, packed_query_sequence, packed_query_position_ids)
extra_inputs = {}
if self.use_moe:
extra_inputs.update(mode=mode)
if mode == "gen":
assert packed_vae_token_indexes is not None
assert packed_text_indexes is not None
extra_inputs.update(
packed_vae_token_indexes=packed_vae_token_indexes,
packed_text_indexes=packed_text_indexes,
)
packed_query_sequence, past_key_values = self.transformer_infer.infer(
self.transformer_weights.blocks,
packed_query_sequence=packed_query_sequence,
query_lens=query_lens,
packed_query_position_embeddings=packed_query_position_embeddings,
packed_query_indexes=packed_query_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=update_past_key_values,
is_causal=is_causal,
**extra_inputs,
)
packed_query_sequence = self.post_infer.infer(
self.post_weight,
packed_query_sequence,
packed_text_indexes,
packed_vae_token_indexes,
mode,
)
return packed_query_sequence, past_key_values
    @torch.no_grad()
def forward_cache_update_text(
self,
past_key_values: NaiveCache,
packed_text_ids: torch.IntTensor,
packed_text_position_ids: torch.LongTensor,
text_token_lens: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_key_value_indexes: torch.LongTensor,
key_values_lens: torch.IntTensor,
):
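        # Embed the prompt tokens and run a causal prefill pass that writes the new
        # keys/values into past_key_values.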
packed_text_embedding = self.pre_infer.embed_tokens(self.pre_weight, packed_text_ids)
extra_inputs = {}
if self.use_moe:
extra_inputs = {"mode": "und"}
packed_query_sequence, past_key_values = self.forward_inference(
packed_query_sequence=packed_text_embedding,
query_lens=text_token_lens,
packed_query_position_ids=packed_text_position_ids,
packed_query_indexes=packed_text_indexes,
past_key_values=past_key_values,
packed_key_value_indexes=packed_key_value_indexes,
key_values_lens=key_values_lens,
update_past_key_values=True,
is_causal=True,
**extra_inputs,
)
return past_key_values
@torch.no_grad()
def update_context_text(self, text, gen_context):
        # Used for interleaved data; currently only a single sample per inference is supported.
past_key_values = gen_context["past_key_values"]
kv_lens = gen_context["kv_lens"]
ropes = gen_context["ropes"]
generation_input, kv_lens, ropes = self.prepare_prompts(
curr_kvlens=kv_lens,
curr_rope=ropes,
prompts=[text],
tokenizer=self.tokenizer,
new_token_ids=self.new_token_ids,
)
past_key_values = self.forward_cache_update_text(past_key_values, **generation_input)
gen_context["kv_lens"] = kv_lens
gen_context["ropes"] = ropes
gen_context["past_key_values"] = past_key_values
return gen_context
    def gen_text(self, gen_context, do_sample=False, temperature=0.3, max_length=1000):
        raise NotImplementedError("Text generation is not implemented yet.")
@torch.no_grad()
def prepare_inputs(self, input_info, scheduler):
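        # Build the conditional context plus the text-CFG and image-CFG contexts
        # (the text-CFG context is snapshotted before the prompt is added), then
        # pre-pack the VAE latent generation inputs for each context on the scheduler.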
gen_context = self.transformer_infer.gen_context
cfg_text_context = self.transformer_infer.cfg_text_context
cfg_img_context = self.transformer_infer.cfg_img_context
input_lists = [input_info.prompt]
output_list = []
with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
if self.think:
if self.understanding_output:
system_prompt = VLM_THINK_SYSTEM_PROMPT
else:
system_prompt = GEN_THINK_SYSTEM_PROMPT
gen_context = self.update_context_text(system_prompt, gen_context)
cfg_img_context = self.update_context_text(system_prompt, cfg_img_context)
for input_term in input_lists:
                if isinstance(input_term, str):  # input_lists currently only holds the text prompt
cfg_text_context = deepcopy(gen_context)
gen_context = self.update_context_text(input_term, gen_context)
cfg_img_context = self.update_context_text(input_term, cfg_img_context)
elif isinstance(input_term, Image.Image):
                    raise NotImplementedError("Image inputs are not supported yet.")
else:
raise ValueError(f"Unsupported input type: {type(input_term)}")
            max_think_token_n = self.max_think_token_n
if self.understanding_output:
                raise NotImplementedError("Understanding output is not supported yet.")
gen_text = self.gen_text(gen_context, do_sample=self.do_sample, temperature=self.text_temperature, max_length=max_think_token_n)
output_list.append(gen_text)
else:
if self.think:
gen_text = self.gen_text(gen_context, do_sample=self.do_sample, temperature=self.text_temperature, max_length=max_think_token_n)
gen_context = self.update_context_text(gen_text, gen_context)
output_list.append(gen_text)
else:
gen_text = None
kv_lens = gen_context["kv_lens"]
ropes = gen_context["ropes"]
generation_input = scheduler.prepare_vae_latent(
curr_kvlens=kv_lens,
curr_rope=ropes,
image_sizes=[(1024, 1024)],
new_token_ids=self.new_token_ids,
)
# text cfg
cfg_text_past_key_values = cfg_text_context["past_key_values"]
kv_lens_cfg = cfg_text_context["kv_lens"]
ropes_cfg = cfg_text_context["ropes"]
generation_input_cfg_text = scheduler.prepare_vae_latent_cfg(
curr_kvlens=kv_lens_cfg,
curr_rope=ropes_cfg,
image_sizes=[(1024, 1024)],
)
# img cfg
cfg_img_past_key_values = cfg_img_context["past_key_values"]
kv_lens_cfg = cfg_img_context["kv_lens"]
ropes_cfg = cfg_img_context["ropes"]
generation_input_cfg_img = scheduler.prepare_vae_latent_cfg(
curr_kvlens=kv_lens_cfg,
curr_rope=ropes_cfg,
image_sizes=[(1024, 1024)],
)
scheduler.generation_input = generation_input
scheduler.generation_input_cfg_text = generation_input_cfg_text
scheduler.generation_input_cfg_image = generation_input_cfg_img
scheduler.latents = generation_input["packed_init_noises"]
num_timesteps = scheduler.infer_steps
if self.enable_taylorseer:
model_pred_cache_dic, model_pred_current = cache_init(self, num_timesteps)
model_pred_text_cache_dic, model_pred_text_current = cache_init(self, num_timesteps)
model_pred_img_cache_dic, model_pred_img_current = cache_init(self, num_timesteps)
else:
model_pred_cache_dic, model_pred_current = None, None
model_pred_text_cache_dic, model_pred_text_current = None, None
model_pred_img_cache_dic, model_pred_img_current = None, None
bagel_inputs = BagelInputs(
image_shapes=input_info.image_shapes,
gen_context=gen_context,
cfg_text_precontext=cfg_text_context,
cfg_img_precontext=cfg_img_context,
model_pred_cache_dic=model_pred_cache_dic,
model_pred_current=model_pred_current,
model_pred_text_cache_dic=model_pred_text_cache_dic,
model_pred_text_current=model_pred_text_current,
model_pred_img_cache_dic=model_pred_img_cache_dic,
model_pred_img_current=model_pred_img_current,
generation_input=generation_input,
generation_input_cfg_text=generation_input_cfg_text,
generation_input_cfg_img=generation_input_cfg_img,
cfg_text_past_key_values=cfg_text_past_key_values,
cfg_img_past_key_values=cfg_img_past_key_values,
)
return bagel_inputs, scheduler
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
@torch.no_grad()
def time_embedder(self, weights, t):
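        # Sinusoidal timestep features followed by a two-layer MLP (Linear -> SiLU -> Linear).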
t = t.to(AI_DEVICE)
t = self.timestep_embedding(t, self.frequency_embedding_size)
t = t.to(torch.bfloat16)
t = weights.mlp_0.apply(t)
t = F.silu(t)
t = weights.mlp_2.apply(t)
return t
@torch.no_grad()
def vae2llm(self, x):
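        # Project patchified VAE latents into the LLM hidden space.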
x = self.pre_infer.vae2llm(self.pre_weight, x)
return x
@torch.no_grad()
def llm2vae(self, x):
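        # Project LLM hidden states back into the patchified VAE latent space.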
x = self.post_infer.llm2vae(self.post_weight, x)
return x
    @torch.no_grad()
def infer(self, inputs):
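        # One denoising step: scatter text embeddings and noised latents (plus timestep
        # and position embeddings) into a packed sequence, run the conditional pass and,
        # when the scales exceed 1.0, the text-CFG and image-CFG passes, then combine
        # the velocity predictions with CFG renormalization.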
t = self.scheduler.timesteps[self.scheduler.step_index]
x_t = self.scheduler.latents.to(torch.bfloat16).to(AI_DEVICE)
timestep = torch.tensor([t] * x_t.shape[0])
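        # CFG is only applied inside the configured timestep interval; outside it
        # both guidance scales fall back to 1.0.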
if t > self.inference_hyper["cfg_interval"][0] and t <= self.inference_hyper["cfg_interval"][1]:
cfg_text_scale = self.inference_hyper["cfg_text_scale"]
cfg_img_scale = self.inference_hyper["cfg_img_scale"]
else:
cfg_text_scale = 1.0
cfg_img_scale = 1.0
packed_text_ids = inputs.generation_input["packed_text_ids"]
packed_seqlens = inputs.generation_input["packed_seqlens"]
packed_text_indexes = inputs.generation_input["packed_text_indexes"]
packed_text_embedding = self.pre_infer.embed_tokens(self.pre_weight, packed_text_ids)
packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
packed_sequence[packed_text_indexes] = packed_text_embedding
assert timestep.unique().shape[0] == 1
packed_pos_embed = self.latent_pos_embed(inputs.generation_input["packed_vae_position_ids"]).to(AI_DEVICE).to(torch.bfloat16)
packed_timestep_embeds = self.time_embedder(self.pre_weight, timestep)
packed_pos_embed = packed_pos_embed.to(AI_DEVICE)
x_t = x_t.to(AI_DEVICE)
x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
if x_t.dtype != packed_sequence.dtype:
x_t = x_t.to(packed_sequence.dtype)
packed_sequence[inputs.generation_input["packed_vae_token_indexes"]] = x_t
extra_inputs = {}
if self.use_moe:
extra_inputs = {"mode": "gen", "packed_vae_token_indexes": inputs.generation_input["packed_vae_token_indexes"], "packed_text_indexes": packed_text_indexes}
if self.enable_taylorseer:
self.scheduler.cache_dic = inputs.model_pred_cache_dic
self.scheduler.current = inputs.model_pred_current
output = self.forward_inference(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=inputs.generation_input["packed_position_ids"],
packed_query_indexes=inputs.generation_input["packed_indexes"],
past_key_values=inputs.gen_context["past_key_values"],
key_values_lens=inputs.generation_input["key_values_lens"],
packed_key_value_indexes=inputs.generation_input["packed_key_value_indexes"],
update_past_key_values=False,
is_causal=False,
**extra_inputs,
)
v_t = self.llm2vae(output[0])
v_t = v_t[inputs.generation_input["packed_vae_token_indexes"]]
if cfg_text_scale > 1.0:
if self.enable_taylorseer:
                self.scheduler.cache_dic = inputs.model_pred_text_cache_dic
                self.scheduler.current = inputs.model_pred_text_current
cfg_text_output = self.forward_inference(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=inputs.generation_input_cfg_text["cfg_packed_position_ids"],
packed_query_indexes=inputs.generation_input_cfg_text["cfg_packed_query_indexes"],
past_key_values=inputs.cfg_text_past_key_values,
key_values_lens=inputs.generation_input_cfg_text["cfg_key_values_lens"],
packed_key_value_indexes=inputs.generation_input_cfg_text["cfg_packed_key_value_indexes"],
update_past_key_values=False,
is_causal=False,
**extra_inputs,
)
cfg_text_v_t = self.llm2vae(cfg_text_output[0])
cfg_text_v_t = cfg_text_v_t[inputs.generation_input["packed_vae_token_indexes"]]
if cfg_img_scale > 1.0:
if self.enable_taylorseer:
                self.scheduler.cache_dic = inputs.model_pred_img_cache_dic
                self.scheduler.current = inputs.model_pred_img_current
cfg_img_output = self.forward_inference(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=inputs.generation_input_cfg_img["cfg_packed_position_ids"],
packed_query_indexes=inputs.generation_input_cfg_img["cfg_packed_query_indexes"],
past_key_values=inputs.cfg_img_past_key_values,
key_values_lens=inputs.generation_input_cfg_img["cfg_key_values_lens"],
packed_key_value_indexes=inputs.generation_input_cfg_img["cfg_packed_key_value_indexes"],
update_past_key_values=False,
is_causal=False,
**extra_inputs,
)
cfg_img_v_t = self.llm2vae(cfg_img_output[0])
cfg_img_v_t = cfg_img_v_t[inputs.generation_input["packed_vae_token_indexes"]]
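        # Combine conditional and CFG predictions. With "text_channel" renorm the
        # text-guided velocity is rescaled per channel before image guidance is applied;
        # "global" / "channel" renorm rescales the fully guided velocity instead.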
if cfg_text_scale > 1.0:
if self.inference_hyper["cfg_renorm_type"] == "text_channel":
v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True)
scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=self.inference_hyper["cfg_renorm_min"], max=1.0)
v_t_text = v_t_text_ * scale
if cfg_img_scale > 1.0:
v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t)
else:
v_t = v_t_text
else:
v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
if cfg_img_scale > 1.0:
v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t)
else:
v_t_ = v_t_text_
                # NOTE: the norm is computed over all dimensions, so this currently only supports batch_size = 1 with NaViT-style packing.
if self.inference_hyper["cfg_renorm_type"] == "global":
norm_v_t = torch.norm(v_t, dtype=torch.float32)
norm_v_t_ = torch.norm(v_t_, dtype=torch.float32)
elif self.inference_hyper["cfg_renorm_type"] == "channel":
norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True)
else:
                    raise NotImplementedError(f"{self.inference_hyper['cfg_renorm_type']} is not supported")
scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=self.inference_hyper["cfg_renorm_min"], max=1.0)
v_t = v_t_ * scale
else:
# No CFG
pass
self.scheduler.noise_pred = v_t
return v_t