import torch
import torch.nn as nn
from loguru import logger
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) that produces per-position ``cos``/``sin`` tables.

    The module is configured either through ``config`` or, for backward
    compatibility, through the individual keyword arguments when ``config`` is None.
    When ``config`` is given, it is read with subscripting and must provide the
    ``"max_position_embeddings"`` and ``"rope_scaling"`` keys.

    The inverse frequencies and the attention scaling factor come from
    ``ROPE_INIT_FUNCTIONS[rope_type]`` (transformers).

    Args:
        dim: Rotary dimension; used only on the legacy (``config is None``) path.
        max_position_embeddings: Initial cached sequence length on the legacy path.
        base: RoPE base; used only on the legacy path.
        device: Device passed to the rope init function for ``inv_freq``.
        scaling_factor: Forwarded as ``"factor"`` on the legacy path.
        rope_type: Key into ``ROPE_INIT_FUNCTIONS`` on the legacy path.
        config: Model config mapping. When given, the other arguments except
            ``device`` are ignored.
    """

    # loguru's logger has no ``warning_once`` (that helper belongs to the
    # transformers logger). Calling it raised AttributeError, so the "warn only
    # once" behaviour for the legacy constructor path is tracked with this flag.
    _warned_legacy_args = False

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config=None,
    ):
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            if not Qwen2RotaryEmbedding._warned_legacy_args:
                logger.warning("`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46")
                Qwen2RotaryEmbedding._warned_legacy_args = True
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            if config["rope_scaling"] is not None:
                self.rope_type = config["rope_scaling"].get("rope_type", config["rope_scaling"].get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config["max_position_embeddings"]
            self.original_max_seq_len = config["max_position_embeddings"]

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Plain attribute (not a buffer): Module.to()/cuda() will NOT move it.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len, **self.rope_kwargs)
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            # ``original_inv_freq`` is not a registered buffer, so it stays on its
            # construction device when the module is moved. Bring it to the current
            # device before re-registering it as ``inv_freq``.
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Compute RoPE ``cos``/``sin`` tables for ``position_ids``.

        Args:
            x: Reference tensor. Only its ``device`` (used by the dynamic update
                and for choosing the autocast device type) and ``dtype`` (for the
                output) are read.
            position_ids: 2-D tensor indexed as ``(batch, seq)``.

        Returns:
            ``(cos, sin)``, each of shape ``(batch, seq, 2 * len(inv_freq))``,
            multiplied by ``attention_scaling`` and cast to ``x.dtype``.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class BagelPreInfer:
    """Pre-inference helpers for Bagel: token embedding, RoPE tables and VAE->LLM projection."""

    def __init__(self, config, llm_config):
        """Store ``config`` and build the rotary embedding from ``llm_config``."""
        self.config = config
        self.rotary_emb = Qwen2RotaryEmbedding(config=llm_config)

    def set_scheduler(self, scheduler):
        """Attach the scheduler object (stored as ``self.scheduler``)."""
        self.scheduler = scheduler

    def embed_tokens(self, weights, packed_text_ids):
        """Move ``packed_text_ids`` to ``AI_DEVICE`` and embed them with ``weights.embed_tokens``."""
        packed_text_ids = packed_text_ids.to(AI_DEVICE)
        embeds = weights.embed_tokens.apply(packed_text_ids)
        return embeds

    def infer(self, weights, packed_sequence, packed_position_ids):
        """Build the ``(cos, sin)`` position embeddings shared by the decoder layers.

        Args:
            weights: Unused here; presumably kept so all pre-infer methods share a
                calling convention. TODO confirm.
            packed_sequence: Passed as ``x`` to the rotary embedding, which reads
                only its device and dtype.
            packed_position_ids: 1-D positions. A batch dim is added here because
                the rotary embedding expects a ``(batch, seq)`` tensor.

        Returns:
            Tuple ``(cos, sin)`` with the batch dim removed, placed on ``AI_DEVICE``.
        """
        # create position embeddings to be shared across the decoder layers
        cos, sin = self.rotary_emb(packed_sequence, packed_position_ids.unsqueeze(0))
        cos = cos.squeeze(0).to(AI_DEVICE)
        sin = sin.squeeze(0).to(AI_DEVICE)
        packed_position_embeddings = (cos, sin)
        return packed_position_embeddings

    def vae2llm(self, weights, x):
        """Move ``x`` to ``AI_DEVICE``, cast it to bfloat16 and apply ``weights.vae2llm``."""
        x = x.to(AI_DEVICE).to(torch.bfloat16)
        x = weights.vae2llm.apply(x)
        return x