# SPDX-License-Identifier: Apache-2.0
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/logging_utils/formatter.py
import logging
class NewLineFormatter(logging.Formatter):
"""Adds logging prefix to newlines to align multi-line messages."""
def __init__(self, fmt, datefmt=None, style="%"):
logging.Formatter.__init__(self, fmt, datefmt, style)
def format(self, record):
msg = logging.Formatter.format(self, record)
if record.message != "":
parts = msg.split(record.message)
msg = msg.replace("\n", "\r\n" + parts[0])
return msg
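

# Example (illustrative sketch): wiring NewLineFormatter into a stream
# handler so the second line of a multi-line message stays aligned under
# the log prefix. The format string and logger name are assumptions made
# for demonstration only.
def _example_newline_formatter() -> None:
    import sys
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        NewLineFormatter("%(levelname)s %(asctime)s %(message)s"))
    demo_logger = logging.getLogger("newline_formatter_demo")
    demo_logger.addHandler(handler)
    demo_logger.setLevel(logging.INFO)
    # "second line" is re-prefixed with "INFO <asctime> " by the formatter.
    demo_logger.info("first line\nsecond line")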
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from fastvideo.v1.configs.models import DiTConfig
from fastvideo.v1.platforms import _Backend
# TODO
class BaseDiT(nn.Module, ABC):
_fsdp_shard_conditions: list = []
_param_names_mapping: dict
hidden_size: int
num_attention_heads: int
num_channels_latents: int
# always supports torch_sdpa
_supported_attention_backends: Tuple[
_Backend, ...] = DiTConfig()._supported_attention_backends
def __init_subclass__(cls) -> None:
required_class_attrs = [
"_fsdp_shard_conditions", "_param_names_mapping"
]
super().__init_subclass__()
for attr in required_class_attrs:
if not hasattr(cls, attr):
raise AttributeError(
f"Subclasses of BaseDiT must define '{attr}' class variable"
)
def __init__(self, config: DiTConfig, **kwargs) -> None:
super().__init__()
self.config = config
if not self.supported_attention_backends:
raise ValueError(
f"Subclass {self.__class__.__name__} must define _supported_attention_backends"
)
@abstractmethod
def forward(self,
hidden_states: torch.Tensor,
encoder_hidden_states: Union[torch.Tensor, List[torch.Tensor]],
timestep: torch.LongTensor,
encoder_hidden_states_image: Optional[Union[
torch.Tensor, List[torch.Tensor]]] = None,
guidance=None,
**kwargs) -> torch.Tensor:
pass
def __post_init__(self) -> None:
required_attrs = [
"hidden_size", "num_attention_heads", "num_channels_latents"
]
for attr in required_attrs:
if not hasattr(self, attr):
raise AttributeError(
f"Subclasses of BaseDiT must define '{attr}' instance variable"
)
@property
def supported_attention_backends(self) -> Tuple[_Backend, ...]:
return self._supported_attention_backends
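

# Example (illustrative sketch): the minimal surface a BaseDiT subclass must
# provide. __init_subclass__ checks the class attributes at definition time;
# __post_init__, called manually at the end of __init__, checks the instance
# attributes. All sizes below are toy placeholders, not a real model.
class _ToyDiT(BaseDiT):
    _fsdp_shard_conditions: list = []  # required class attribute
    _param_names_mapping: dict = {}  # required class attribute

    def __init__(self, config: DiTConfig, **kwargs) -> None:
        super().__init__(config=config)
        self.hidden_size = 8
        self.num_attention_heads = 2
        self.num_channels_latents = 4
        self.__post_init__()  # validates the required instance attributes

    def forward(self,
                hidden_states: torch.Tensor,
                encoder_hidden_states: Union[torch.Tensor,
                                             List[torch.Tensor]],
                timestep: torch.LongTensor,
                encoder_hidden_states_image: Optional[Union[
                    torch.Tensor, List[torch.Tensor]]] = None,
                guidance=None,
                **kwargs) -> torch.Tensor:
        return hidden_states  # identity placeholder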
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from fastvideo.v1.attention import DistributedAttention, LocalAttention
from fastvideo.v1.configs.models.dits import HunyuanVideoConfig
from fastvideo.v1.distributed.parallel_state import (
get_sequence_model_parallel_world_size)
from fastvideo.v1.layers.layernorm import (LayerNormScaleShift, ScaleResidual,
ScaleResidualLayerNormScaleShift)
from fastvideo.v1.layers.linear import ReplicatedLinear
# TODO(will-PY-refactor): RMSNorm ....
from fastvideo.v1.layers.mlp import MLP
from fastvideo.v1.layers.rotary_embedding import (_apply_rotary_emb,
get_rotary_pos_embed)
from fastvideo.v1.layers.visual_embedding import (ModulateProjection,
PatchEmbed, TimestepEmbedder,
unpatchify)
from fastvideo.v1.models.dits.base import BaseDiT
from fastvideo.v1.platforms import _Backend
class HunyuanRMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x) -> torch.Tensor:
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
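

# Example (illustrative sketch): a normalized row has unit root-mean-square,
# since _norm divides by the RMS of the last dimension and the affine weight
# defaults to ones. Toy values only.
def _example_rms_norm() -> None:
    norm = HunyuanRMSNorm(dim=4, eps=0.0)
    x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    y = norm(x)
    rms = y.pow(2).mean(-1).sqrt()
    assert torch.allclose(rms, torch.ones_like(rms))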
class MMDoubleStreamBlock(nn.Module):
"""
A multimodal DiT block with separate modulation for text and image/video,
using distributed attention and linear layers.
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
mlp_ratio: float,
dtype: Optional[torch.dtype] = None,
supported_attention_backends: Optional[Tuple[_Backend, ...]] = None,
prefix: str = "",
):
super().__init__()
self.deterministic = False
self.num_attention_heads = num_attention_heads
head_dim = hidden_size // num_attention_heads
mlp_hidden_dim = int(hidden_size * mlp_ratio)
# Image modulation components
self.img_mod = ModulateProjection(
hidden_size,
factor=6,
act_layer="silu",
dtype=dtype,
prefix=f"{prefix}.img_mod",
)
# Fused operations for image stream
self.img_attn_norm = LayerNormScaleShift(hidden_size,
norm_type="layer",
elementwise_affine=False,
dtype=dtype)
self.img_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift(
hidden_size,
norm_type="layer",
elementwise_affine=False,
dtype=dtype)
self.img_mlp_residual = ScaleResidual()
# Image attention components
self.img_attn_qkv = ReplicatedLinear(hidden_size,
hidden_size * 3,
bias=True,
params_dtype=dtype,
prefix=f"{prefix}.img_attn_qkv")
self.img_attn_q_norm = HunyuanRMSNorm(head_dim, eps=1e-6, dtype=dtype)
self.img_attn_k_norm = HunyuanRMSNorm(head_dim, eps=1e-6, dtype=dtype)
self.img_attn_proj = ReplicatedLinear(hidden_size,
hidden_size,
bias=True,
params_dtype=dtype,
prefix=f"{prefix}.img_attn_proj")
self.img_mlp = MLP(hidden_size,
mlp_hidden_dim,
bias=True,
dtype=dtype,
prefix=f"{prefix}.img_mlp")
# Text modulation components
self.txt_mod = ModulateProjection(
hidden_size,
factor=6,
act_layer="silu",
dtype=dtype,
prefix=f"{prefix}.txt_mod",
)
# Fused operations for text stream
self.txt_attn_norm = LayerNormScaleShift(hidden_size,
norm_type="layer",
elementwise_affine=False,
dtype=dtype)
self.txt_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift(
hidden_size,
norm_type="layer",
elementwise_affine=False,
dtype=dtype)
self.txt_mlp_residual = ScaleResidual()
# Text attention components
self.txt_attn_qkv = ReplicatedLinear(hidden_size,
hidden_size * 3,
bias=True,
params_dtype=dtype)
# QK norm layers for text
self.txt_attn_q_norm = HunyuanRMSNorm(head_dim, eps=1e-6, dtype=dtype)
self.txt_attn_k_norm = HunyuanRMSNorm(head_dim, eps=1e-6, dtype=dtype)
self.txt_attn_proj = ReplicatedLinear(hidden_size,
hidden_size,
bias=True,
params_dtype=dtype)
self.txt_mlp = MLP(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype)
# Distributed attention
self.attn = DistributedAttention(
num_heads=num_attention_heads,
head_size=head_dim,
causal=False,
supported_attention_backends=supported_attention_backends,
prefix=f"{prefix}.attn")
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
freqs_cis: tuple,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Process modulation vectors
img_mod_outputs = self.img_mod(vec)
(
img_attn_shift,
img_attn_scale,
img_attn_gate,
img_mlp_shift,
img_mlp_scale,
img_mlp_gate,
) = torch.chunk(img_mod_outputs, 6, dim=-1)
txt_mod_outputs = self.txt_mod(vec)
(
txt_attn_shift,
txt_attn_scale,
txt_attn_gate,
txt_mlp_shift,
txt_mlp_scale,
txt_mlp_gate,
) = torch.chunk(txt_mod_outputs, 6, dim=-1)
# Prepare image for attention using fused operation
img_attn_input = self.img_attn_norm(img, img_attn_shift, img_attn_scale)
# Get QKV for image
img_qkv, _ = self.img_attn_qkv(img_attn_input)
batch_size, image_seq_len = img_qkv.shape[0], img_qkv.shape[1]
# Split QKV
img_qkv = img_qkv.view(batch_size, image_seq_len, 3,
self.num_attention_heads, -1)
img_q, img_k, img_v = img_qkv[:, :, 0], img_qkv[:, :, 1], img_qkv[:, :,
2]
# Apply QK-Norm if needed
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
# Apply rotary embeddings
cos, sin = freqs_cis
img_q, img_k = _apply_rotary_emb(
img_q, cos, sin,
is_neox_style=False), _apply_rotary_emb(img_k,
cos,
sin,
is_neox_style=False)
# Prepare text for attention using fused operation
txt_attn_input = self.txt_attn_norm(txt, txt_attn_shift, txt_attn_scale)
# Get QKV for text
txt_qkv, _ = self.txt_attn_qkv(txt_attn_input)
batch_size, text_seq_len = txt_qkv.shape[0], txt_qkv.shape[1]
# Split QKV
txt_qkv = txt_qkv.view(batch_size, text_seq_len, 3,
self.num_attention_heads, -1)
txt_q, txt_k, txt_v = txt_qkv[:, :, 0], txt_qkv[:, :, 1], txt_qkv[:, :,
2]
# Apply QK-Norm if needed
txt_q = self.txt_attn_q_norm(txt_q).to(txt_q.dtype)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_k.dtype)
# Run distributed attention
img_attn, txt_attn = self.attn(img_q, img_k, img_v, txt_q, txt_k, txt_v)
img_attn_out, _ = self.img_attn_proj(
img_attn.view(batch_size, image_seq_len, -1))
# Use fused operation for residual connection, normalization, and modulation
img_mlp_input, img_residual = self.img_attn_residual_mlp_norm(
img, img_attn_out, img_attn_gate, img_mlp_shift, img_mlp_scale)
# Process image MLP
img_mlp_out = self.img_mlp(img_mlp_input)
img = self.img_mlp_residual(img_residual, img_mlp_out, img_mlp_gate)
# Process text attention output
txt_attn_out, _ = self.txt_attn_proj(
txt_attn.reshape(batch_size, text_seq_len, -1))
# Use fused operation for residual connection, normalization, and modulation
txt_mlp_input, txt_residual = self.txt_attn_residual_mlp_norm(
txt, txt_attn_out, txt_attn_gate, txt_mlp_shift, txt_mlp_scale)
# Process text MLP
txt_mlp_out = self.txt_mlp(txt_mlp_input)
txt = self.txt_mlp_residual(txt_residual, txt_mlp_out, txt_mlp_gate)
return img, txt
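

# Example (illustrative sketch): the AdaLN-style modulation used above.
# ModulateProjection with factor=6 emits one [B, 6 * hidden] vector that is
# chunked into (shift, scale, gate) for the attention branch and another
# (shift, scale, gate) for the MLP branch. The random tensor stands in for
# the projection output.
def _example_six_way_modulation() -> None:
    hidden = 8
    mod_out = torch.randn(2, 6 * hidden)  # stand-in for img_mod(vec)
    (attn_shift, attn_scale, attn_gate, mlp_shift, mlp_scale,
     mlp_gate) = torch.chunk(mod_out, 6, dim=-1)
    x = torch.randn(2, 5, hidden)
    # A scale-shift applies (1 + scale) * x + shift per token; the gates
    # later weight the residual branches.
    modulated = x * (1 + attn_scale.unsqueeze(1)) + attn_shift.unsqueeze(1)
    assert modulated.shape == x.shape and attn_gate.shape == (2, hidden)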
class MMSingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers using distributed attention
and tensor parallelism.
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
mlp_ratio: float = 4.0,
dtype: Optional[torch.dtype] = None,
supported_attention_backends: Optional[Tuple[_Backend, ...]] = None,
prefix: str = "",
):
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
head_dim = hidden_size // num_attention_heads
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_hidden_dim = mlp_hidden_dim
# Combined QKV and MLP input projection
self.linear1 = ReplicatedLinear(hidden_size,
hidden_size * 3 + mlp_hidden_dim,
bias=True,
params_dtype=dtype,
prefix=f"{prefix}.linear1")
# Combined projection and MLP output
self.linear2 = ReplicatedLinear(hidden_size + mlp_hidden_dim,
hidden_size,
bias=True,
params_dtype=dtype,
prefix=f"{prefix}.linear2")
# QK norm layers
self.q_norm = HunyuanRMSNorm(head_dim, eps=1e-6, dtype=dtype)
self.k_norm = HunyuanRMSNorm(head_dim, eps=1e-6, dtype=dtype)
# Fused operations with better naming
self.input_norm_scale_shift = LayerNormScaleShift(
hidden_size,
norm_type="layer",
eps=1e-6,
elementwise_affine=False,
dtype=dtype)
self.output_residual = ScaleResidual()
# Activation function
self.mlp_act = nn.GELU(approximate="tanh")
# Modulation
self.modulation = ModulateProjection(hidden_size,
factor=3,
act_layer="silu",
dtype=dtype,
prefix=f"{prefix}.modulation")
# Distributed attention
self.attn = DistributedAttention(
num_heads=num_attention_heads,
head_size=head_dim,
causal=False,
supported_attention_backends=supported_attention_backends,
prefix=f"{prefix}.attn")
def forward(
self,
x: torch.Tensor,
vec: torch.Tensor,
txt_len: int,
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
# Process modulation
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
# Apply pre-norm and modulation using fused operation
x_mod = self.input_norm_scale_shift(x, mod_shift, mod_scale)
# Get combined projections
linear1_out, _ = self.linear1(x_mod)
# Split into QKV and MLP parts
qkv, mlp = torch.split(linear1_out,
[3 * self.hidden_size, self.mlp_hidden_dim],
dim=-1)
# Process QKV
batch_size, seq_len = qkv.shape[0], qkv.shape[1]
qkv = qkv.view(batch_size, seq_len, 3, self.num_attention_heads, -1)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
# Apply QK-Norm
q = self.q_norm(q).to(v.dtype)
k = self.k_norm(k).to(v.dtype)
# Split into image and text parts
img_q, txt_q = q[:, :-txt_len], q[:, -txt_len:]
img_k, txt_k = k[:, :-txt_len], k[:, -txt_len:]
img_v, txt_v = v[:, :-txt_len], v[:, -txt_len:]
# Apply rotary embeddings to image parts
cos, sin = freqs_cis
img_q, img_k = _apply_rotary_emb(
img_q, cos, sin,
is_neox_style=False), _apply_rotary_emb(img_k,
cos,
sin,
is_neox_style=False)
# Run distributed attention
img_attn_output, txt_attn_output = self.attn(img_q, img_k, img_v, txt_q,
txt_k, txt_v)
attn_output = torch.cat((img_attn_output, txt_attn_output),
dim=1).view(batch_size, seq_len, -1)
# Process MLP activation
mlp_output = self.mlp_act(mlp)
# Combine attention and MLP outputs
combined = torch.cat((attn_output, mlp_output), dim=-1)
# Final projection
output, _ = self.linear2(combined)
# Apply residual connection with gating using fused operation
return self.output_residual(x, output, mod_gate)
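

# Example (illustrative sketch): the single-stream block fuses the attention
# QKV projection and the MLP input projection into one matmul (linear1) and
# splits the result afterwards. A shape-only walkthrough with toy sizes:
def _example_fused_qkv_mlp_split() -> None:
    hidden_size, mlp_hidden_dim, heads = 8, 32, 2
    linear1_out = torch.randn(2, 5, 3 * hidden_size + mlp_hidden_dim)
    qkv, mlp = torch.split(linear1_out, [3 * hidden_size, mlp_hidden_dim],
                           dim=-1)
    qkv = qkv.view(2, 5, 3, heads, hidden_size // heads)
    q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
    assert q.shape == (2, 5, heads, hidden_size // heads)
    assert mlp.shape == (2, 5, mlp_hidden_dim)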
class HunyuanVideoTransformer3DModel(BaseDiT):
"""
HunyuanVideo Transformer backbone adapted for distributed training.
This implementation uses distributed attention and linear layers for efficient
parallel processing across multiple GPUs.
Based on the architecture from:
- Flux.1: https://github.com/black-forest-labs/flux
- MMDiT: http://arxiv.org/abs/2403.03206
"""
# PY: we make the input args the same as HF config
# shard single stream, double stream blocks, and refiner_blocks
_fsdp_shard_conditions = HunyuanVideoConfig()._fsdp_shard_conditions
_supported_attention_backends = HunyuanVideoConfig(
)._supported_attention_backends
_param_names_mapping = HunyuanVideoConfig()._param_names_mapping
def __init__(self, config: HunyuanVideoConfig):
super().__init__(config=config)
self.patch_size = [
config.patch_size_t, config.patch_size, config.patch_size
]
self.in_channels = config.in_channels
self.num_channels_latents = config.num_channels_latents
self.out_channels = config.in_channels if config.out_channels is None else config.out_channels
self.unpatchify_channels = self.out_channels
self.guidance_embeds = config.guidance_embeds
self.rope_dim_list = list(config.rope_axes_dim)
self.rope_theta = config.rope_theta
self.text_states_dim = config.text_embed_dim
self.text_states_dim_2 = config.pooled_projection_dim
# TODO(will): hack?
self.dtype = config.dtype
pe_dim = config.hidden_size // config.num_attention_heads
if sum(config.rope_axes_dim) != pe_dim:
raise ValueError(
f"Got {config.rope_axes_dim} but expected positional dim {pe_dim}"
)
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.num_channels_latents = config.num_channels_latents
# Image projection
self.img_in = PatchEmbed(self.patch_size,
self.in_channels,
self.hidden_size,
dtype=config.dtype,
prefix=f"{config.prefix}.img_in")
self.txt_in = SingleTokenRefiner(self.text_states_dim,
config.hidden_size,
config.num_attention_heads,
depth=config.num_refiner_layers,
dtype=config.dtype,
prefix=f"{config.prefix}.txt_in")
# Time modulation
self.time_in = TimestepEmbedder(self.hidden_size,
act_layer="silu",
dtype=config.dtype,
prefix=f"{config.prefix}.time_in")
# Text modulation
self.vector_in = MLP(self.text_states_dim_2,
self.hidden_size,
self.hidden_size,
act_type="silu",
dtype=config.dtype,
prefix=f"{config.prefix}.vector_in")
# Guidance modulation
self.guidance_in = (TimestepEmbedder(
self.hidden_size,
act_layer="silu",
dtype=config.dtype,
prefix=f"{config.prefix}.guidance_in")
if self.guidance_embeds else None)
# Double blocks
self.double_blocks = nn.ModuleList([
MMDoubleStreamBlock(
config.hidden_size,
config.num_attention_heads,
mlp_ratio=config.mlp_ratio,
dtype=config.dtype,
supported_attention_backends=self._supported_attention_backends,
prefix=f"{config.prefix}.double_blocks.{i}")
for i in range(config.num_layers)
])
# Single blocks
self.single_blocks = nn.ModuleList([
MMSingleStreamBlock(
config.hidden_size,
config.num_attention_heads,
mlp_ratio=config.mlp_ratio,
dtype=config.dtype,
supported_attention_backends=self._supported_attention_backends,
prefix=f"{config.prefix}.single_blocks.{i+config.num_layers}")
for i in range(config.num_single_layers)
])
self.final_layer = FinalLayer(config.hidden_size,
self.patch_size,
self.out_channels,
dtype=config.dtype,
prefix=f"{config.prefix}.final_layer")
self.__post_init__()
    # TODO: change the input to the FORWARD_BATCH dict
    # TODO: change the output to a dict
def forward(self,
hidden_states: torch.Tensor,
encoder_hidden_states: Union[torch.Tensor, List[torch.Tensor]],
timestep: torch.LongTensor,
encoder_hidden_states_image: Optional[Union[
torch.Tensor, List[torch.Tensor]]] = None,
guidance=None,
**kwargs):
"""
Forward pass of the HunyuanDiT model.
Args:
hidden_states: Input image/video latents [B, C, T, H, W]
encoder_hidden_states: Text embeddings [B, L, D]
timestep: Diffusion timestep
guidance: Guidance scale for CFG
Returns:
Tuple of (output)
"""
if guidance is None:
guidance = torch.tensor([6016.0],
device=hidden_states.device,
dtype=hidden_states.dtype)
img = x = hidden_states
t = timestep
# Split text embeddings - first token is global, rest are per-token
if isinstance(encoder_hidden_states, torch.Tensor):
txt = encoder_hidden_states[:, 1:]
text_states_2 = encoder_hidden_states[:, 0, :self.text_states_dim_2]
else:
txt = encoder_hidden_states[0]
text_states_2 = encoder_hidden_states[1]
# Get spatial dimensions
_, _, ot, oh, ow = x.shape # codespell:ignore
tt, th, tw = (
ot // self.patch_size[0], # codespell:ignore
oh // self.patch_size[1],
ow // self.patch_size[2],
)
# Get rotary embeddings
freqs_cos, freqs_sin = get_rotary_pos_embed(
(tt * get_sequence_model_parallel_world_size(), th, tw),
self.hidden_size, self.num_attention_heads, self.rope_dim_list,
self.rope_theta)
freqs_cos = freqs_cos.to(x.device)
freqs_sin = freqs_sin.to(x.device)
# Prepare modulation vectors
vec = self.time_in(t)
# Add text modulation
vec = vec + self.vector_in(text_states_2)
# Add guidance modulation if needed
if self.guidance_in and guidance is not None:
vec = vec + self.guidance_in(guidance)
# Embed image and text
img = self.img_in(img)
txt = self.txt_in(txt, t)
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
# Process through double stream blocks
        for block in self.double_blocks:
            img, txt = block(img, txt, vec, freqs_cis)
# Merge txt and img to pass through single stream blocks
x = torch.cat((img, txt), 1)
        # Process through single stream blocks
        for block in self.single_blocks:
            x = block(x, vec, txt_seq_len, freqs_cis)
# Extract image features
img = x[:, :img_seq_len, ...]
# Final layer processing
img = self.final_layer(img, vec)
# Unpatchify to get original shape
img = unpatchify(img, tt, th, tw, self.patch_size, self.out_channels)
return img
class SingleTokenRefiner(nn.Module):
"""
A token refiner that processes text embeddings with attention to improve
their representation for cross-attention with image features.
"""
def __init__(
self,
in_channels,
hidden_size,
num_attention_heads,
depth=2,
qkv_bias=True,
dtype=None,
prefix: str = "",
) -> None:
super().__init__()
# Input projection
self.input_embedder = ReplicatedLinear(
in_channels,
hidden_size,
bias=True,
params_dtype=dtype,
prefix=f"{prefix}.input_embedder")
# Timestep embedding
self.t_embedder = TimestepEmbedder(hidden_size,
act_layer="silu",
dtype=dtype,
prefix=f"{prefix}.t_embedder")
# Context embedding
self.c_embedder = MLP(in_channels,
hidden_size,
hidden_size,
act_type="silu",
dtype=dtype,
prefix=f"{prefix}.c_embedder")
# Refiner blocks
self.refiner_blocks = nn.ModuleList([
IndividualTokenRefinerBlock(
hidden_size,
num_attention_heads,
qkv_bias=qkv_bias,
dtype=dtype,
prefix=f"{prefix}.refiner_blocks.{i}",
) for i in range(depth)
])
def forward(self, x, t):
# Get timestep embeddings
timestep_aware_representations = self.t_embedder(t)
# Get context-aware representations
context_aware_representations = torch.mean(x, dim=1)
context_aware_representations = self.c_embedder(
context_aware_representations)
c = timestep_aware_representations + context_aware_representations
# Project input
x, _ = self.input_embedder(x)
# Process through refiner blocks
for block in self.refiner_blocks:
x = block(x, c)
return x
class IndividualTokenRefinerBlock(nn.Module):
"""
A transformer block for refining individual tokens with self-attention.
"""
def __init__(
self,
hidden_size,
num_attention_heads,
mlp_ratio=4.0,
qkv_bias=True,
dtype=None,
prefix: str = "",
) -> None:
super().__init__()
self.num_attention_heads = num_attention_heads
mlp_hidden_dim = int(hidden_size * mlp_ratio)
# Normalization and attention
self.norm1 = nn.LayerNorm(hidden_size,
eps=1e-6,
elementwise_affine=True,
dtype=dtype)
self.self_attn_qkv = ReplicatedLinear(hidden_size,
hidden_size * 3,
bias=qkv_bias,
params_dtype=dtype,
prefix=f"{prefix}.self_attn_qkv")
self.self_attn_proj = ReplicatedLinear(
hidden_size,
hidden_size,
bias=qkv_bias,
params_dtype=dtype,
prefix=f"{prefix}.self_attn_proj")
# MLP
self.norm2 = nn.LayerNorm(hidden_size,
eps=1e-6,
elementwise_affine=True,
dtype=dtype)
self.mlp = MLP(hidden_size,
mlp_hidden_dim,
bias=True,
act_type="silu",
dtype=dtype,
prefix=f"{prefix}.mlp")
# Modulation
self.adaLN_modulation = ModulateProjection(
hidden_size,
factor=2,
act_layer="silu",
dtype=dtype,
prefix=f"{prefix}.adaLN_modulation")
# Scaled dot product attention
self.attn = LocalAttention(
num_heads=num_attention_heads,
head_size=hidden_size // num_attention_heads,
# TODO: remove hardcode; remove STA
supported_attention_backends=(_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA),
)
def forward(self, x, c):
# Get modulation parameters
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=-1)
# Self-attention
norm_x = self.norm1(x)
qkv, _ = self.self_attn_qkv(norm_x)
batch_size, seq_len = qkv.shape[0], qkv.shape[1]
qkv = qkv.view(batch_size, seq_len, 3, self.num_attention_heads, -1)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
# Run scaled dot product attention
attn_output = self.attn(q, k, v) # [B, L, H, D]
attn_output = attn_output.reshape(batch_size, seq_len,
-1) # [B, L, H*D]
# Project and apply residual connection with gating
attn_out, _ = self.self_attn_proj(attn_output)
x = x + attn_out * gate_msa.unsqueeze(1)
# MLP
mlp_out = self.mlp(self.norm2(x))
x = x + mlp_out * gate_mlp.unsqueeze(1)
return x
class FinalLayer(nn.Module):
"""
The final layer of DiT that projects features to pixel space.
"""
def __init__(self,
hidden_size,
patch_size,
out_channels,
dtype=None,
prefix: str = "") -> None:
super().__init__()
# Normalization
self.norm_final = nn.LayerNorm(hidden_size,
eps=1e-6,
elementwise_affine=False,
dtype=dtype)
output_dim = patch_size[0] * patch_size[1] * patch_size[2] * out_channels
self.linear = ReplicatedLinear(hidden_size,
output_dim,
bias=True,
params_dtype=dtype,
prefix=f"{prefix}.linear")
# Modulation
self.adaLN_modulation = ModulateProjection(
hidden_size,
factor=2,
act_layer="silu",
dtype=dtype,
prefix=f"{prefix}.adaLN_modulation")
def forward(self, x, c):
        # NOTE: HF checkpoints emit the modulation as (scale, shift),
        # reversed from the (shift, scale) order used elsewhere in this file.
        scale, shift = self.adaLN_modulation(c).chunk(2, dim=-1)
x = self.norm_final(x) * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x, _ = self.linear(x)
return x
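

# Example (illustrative sketch): FinalLayer chunks its modulation as
# (scale, shift), the reverse of the (shift, scale) order used by the
# stream blocks above, so checkpoint weights must be read with the matching
# convention. Toy shapes only.
def _example_final_layer_modulation() -> None:
    hidden = 8
    mod = torch.randn(2, 2 * hidden)  # stand-in for adaLN_modulation(c)
    scale, shift = mod.chunk(2, dim=-1)  # note: scale comes first here
    x = torch.randn(2, 5, hidden)
    out = x * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1)
    assert out.shape == x.shape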
# SPDX-License-Identifier: Apache-2.0
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from fastvideo.v1.attention import DistributedAttention, LocalAttention
from fastvideo.v1.configs.models.dits import WanVideoConfig
from fastvideo.v1.distributed.parallel_state import (
get_sequence_model_parallel_world_size)
from fastvideo.v1.layers.layernorm import (LayerNormScaleShift, RMSNorm,
ScaleResidual,
ScaleResidualLayerNormScaleShift)
from fastvideo.v1.layers.linear import ReplicatedLinear
# from torch.nn import RMSNorm
# TODO: RMSNorm ....
from fastvideo.v1.layers.mlp import MLP
from fastvideo.v1.layers.rotary_embedding import (_apply_rotary_emb,
get_rotary_pos_embed)
from fastvideo.v1.layers.visual_embedding import (ModulateProjection,
PatchEmbed, TimestepEmbedder)
from fastvideo.v1.models.dits.base import BaseDiT
from fastvideo.v1.platforms import _Backend
class WanImageEmbedding(torch.nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.norm1 = nn.LayerNorm(in_features)
self.ff = MLP(in_features, in_features, out_features, act_type="gelu")
self.norm2 = nn.LayerNorm(out_features)
def forward(self,
encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
dtype = encoder_hidden_states_image.dtype
hidden_states = self.norm1(encoder_hidden_states_image)
hidden_states = self.ff(hidden_states)
hidden_states = self.norm2(hidden_states).to(dtype)
return hidden_states
class WanTimeTextImageEmbedding(nn.Module):
def __init__(
self,
dim: int,
time_freq_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
):
super().__init__()
self.time_embedder = TimestepEmbedder(
dim, frequency_embedding_size=time_freq_dim, act_layer="silu")
self.time_modulation = ModulateProjection(dim,
factor=6,
act_layer="silu")
self.text_embedder = MLP(text_embed_dim,
dim,
dim,
bias=True,
act_type="gelu_pytorch_tanh")
self.image_embedder = None
if image_embed_dim is not None:
self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
def forward(
self,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
):
temb = self.time_embedder(timestep)
timestep_proj = self.time_modulation(temb)
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
if encoder_hidden_states_image is not None:
assert self.image_embedder is not None
encoder_hidden_states_image = self.image_embedder(
encoder_hidden_states_image)
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
class WanSelfAttention(nn.Module):
def __init__(self,
dim: int,
num_heads: int,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6,
parallel_attention=False) -> None:
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
self.parallel_attention = parallel_attention
# layers
self.to_q = ReplicatedLinear(dim, dim)
self.to_k = ReplicatedLinear(dim, dim)
self.to_v = ReplicatedLinear(dim, dim)
self.to_out = ReplicatedLinear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
# Scaled dot product attention
self.attn = LocalAttention(
num_heads=num_heads,
head_size=self.head_dim,
dropout_rate=0,
softmax_scale=None,
causal=False,
supported_attention_backends=(_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA))
def forward(self, x: torch.Tensor, context: torch.Tensor,
context_lens: int):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
seq_lens(Tensor): Shape [B]
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
pass
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context, context_lens):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
context_lens(Tensor): Shape [B]
"""
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.to_q(x)[0]).view(b, -1, n, d)
k = self.norm_k(self.to_k(context)[0]).view(b, -1, n, d)
v = self.to_v(context)[0].view(b, -1, n, d)
# compute attention
x = self.attn(q, k, v)
# output
x = x.flatten(2)
x, _ = self.to_out(x)
return x
class WanI2VCrossAttention(WanSelfAttention):
def __init__(
self,
dim: int,
num_heads: int,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6,
supported_attention_backends: Optional[Tuple[_Backend, ...]] = None
) -> None:
        # NOTE: the base __init__ does not accept supported_attention_backends;
        # passing it positionally would land in `parallel_attention`.
        super().__init__(dim, num_heads, window_size, qk_norm, eps)
self.add_k_proj = ReplicatedLinear(dim, dim)
self.add_v_proj = ReplicatedLinear(dim, dim)
self.norm_added_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_added_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, x, context, context_lens):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
context_lens(Tensor): Shape [B]
"""
context_img = context[:, :257]
context = context[:, 257:]
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.to_q(x)[0]).view(b, -1, n, d)
k = self.norm_k(self.to_k(context)[0]).view(b, -1, n, d)
v = self.to_v(context)[0].view(b, -1, n, d)
k_img = self.norm_added_k(self.add_k_proj(context_img)[0]).view(
b, -1, n, d)
v_img = self.add_v_proj(context_img)[0].view(b, -1, n, d)
img_x = self.attn(q, k_img, v_img)
# compute attention
x = self.attn(q, k, v)
# output
x = x.flatten(2)
img_x = img_x.flatten(2)
x = x + img_x
x, _ = self.to_out(x)
return x
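

# Example (illustrative sketch): WanI2VCrossAttention assumes the context
# concatenates 257 image-conditioning tokens (e.g. a CLIP class token plus a
# 16x16 patch grid) in front of the text tokens, so it splits at index 257
# and attends to each part separately. Toy sizes only.
def _example_i2v_context_split() -> None:
    batch, text_len, dim = 2, 512, 16
    context = torch.randn(batch, 257 + text_len, dim)
    context_img, context_txt = context[:, :257], context[:, 257:]
    assert context_img.shape == (batch, 257, dim)
    assert context_txt.shape == (batch, text_len, dim)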
class WanTransformerBlock(nn.Module):
def __init__(self,
dim: int,
ffn_dim: int,
num_heads: int,
qk_norm: str = "rms_norm_across_heads",
cross_attn_norm: bool = False,
eps: float = 1e-6,
added_kv_proj_dim: Optional[int] = None,
supported_attention_backends: Optional[Tuple[_Backend,
...]] = None,
prefix: str = ""):
super().__init__()
# 1. Self-attention
self.norm1 = nn.LayerNorm(dim, eps, elementwise_affine=False)
self.to_q = ReplicatedLinear(dim, dim, bias=True)
self.to_k = ReplicatedLinear(dim, dim, bias=True)
self.to_v = ReplicatedLinear(dim, dim, bias=True)
self.to_out = ReplicatedLinear(dim, dim, bias=True)
self.attn1 = DistributedAttention(
num_heads=num_heads,
head_size=dim // num_heads,
causal=False,
supported_attention_backends=supported_attention_backends,
prefix=f"{prefix}.attn1")
self.hidden_dim = dim
self.num_attention_heads = num_heads
dim_head = dim // num_heads
if qk_norm == "rms_norm":
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)
elif qk_norm == "rms_norm_across_heads":
# LTX applies qk norm across all heads
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
        else:
            raise ValueError(f"QK norm type '{qk_norm}' is not supported")
assert cross_attn_norm is True
self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift(
dim,
norm_type="layer",
eps=eps,
elementwise_affine=True,
dtype=torch.float32)
# 2. Cross-attention
if added_kv_proj_dim is not None:
# I2V
self.attn2 = WanI2VCrossAttention(dim,
num_heads,
qk_norm=qk_norm,
eps=eps)
else:
# T2V
self.attn2 = WanT2VCrossAttention(dim,
num_heads,
qk_norm=qk_norm,
eps=eps)
self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift(
dim,
norm_type="layer",
eps=eps,
elementwise_affine=False,
dtype=torch.float32)
# 3. Feed-forward
self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh")
self.mlp_residual = ScaleResidual()
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
if hidden_states.dim() == 4:
hidden_states = hidden_states.squeeze(1)
bs, seq_length, _ = hidden_states.shape
orig_dtype = hidden_states.dtype
assert orig_dtype != torch.float32
e = self.scale_shift_table + temb.float()
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk(
6, dim=1)
assert shift_msa.dtype == torch.float32
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) *
(1 + scale_msa) + shift_msa).to(orig_dtype)
query, _ = self.to_q(norm_hidden_states)
key, _ = self.to_k(norm_hidden_states)
value, _ = self.to_v(norm_hidden_states)
if self.norm_q is not None:
query = self.norm_q.forward_native(query)
if self.norm_k is not None:
key = self.norm_k.forward_native(key)
query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
# Apply rotary embeddings
cos, sin = freqs_cis
query, key = _apply_rotary_emb(query, cos, sin,
is_neox_style=False), _apply_rotary_emb(
key, cos, sin, is_neox_style=False)
attn_output, _ = self.attn1(query, key, value)
attn_output = attn_output.flatten(2)
attn_output, _ = self.to_out(attn_output)
attn_output = attn_output.squeeze(1)
null_shift = null_scale = torch.tensor([0], device=hidden_states.device)
norm_hidden_states, hidden_states = self.self_attn_residual_norm(
hidden_states, attn_output, gate_msa, null_shift, null_scale)
norm_hidden_states, hidden_states = norm_hidden_states.to(
orig_dtype), hidden_states.to(orig_dtype)
# 2. Cross-attention
attn_output = self.attn2(norm_hidden_states,
context=encoder_hidden_states,
context_lens=None)
norm_hidden_states, hidden_states = self.cross_attn_residual_norm(
hidden_states, attn_output, 1, c_shift_msa, c_scale_msa)
norm_hidden_states, hidden_states = norm_hidden_states.to(
orig_dtype), hidden_states.to(orig_dtype)
# 3. Feed-forward
ff_output = self.ffn(norm_hidden_states)
hidden_states = self.mlp_residual(hidden_states, ff_output, c_gate_msa)
hidden_states = hidden_states.to(orig_dtype)
return hidden_states
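

# Example (illustrative sketch): each Wan block keeps a learnable
# (1, 6, dim) scale_shift_table that is added to the per-sample projected
# timestep embedding (B, 6, dim) by broadcasting, then chunked into the six
# modulation terms consumed above. Toy sizes only.
def _example_wan_modulation_table() -> None:
    dim, batch = 8, 2
    table = torch.randn(1, 6, dim) / dim**0.5
    temb = torch.randn(batch, 6, dim)
    e = table + temb  # broadcasts the table over the batch dimension
    shift_msa, scale_msa, gate_msa, c_shift, c_scale, c_gate = e.chunk(
        6, dim=1)
    assert shift_msa.shape == (batch, 1, dim)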
class WanTransformer3DModel(BaseDiT):
_fsdp_shard_conditions = WanVideoConfig()._fsdp_shard_conditions
_supported_attention_backends = WanVideoConfig(
)._supported_attention_backends
_param_names_mapping = WanVideoConfig()._param_names_mapping
def __init__(self, config: WanVideoConfig) -> None:
super().__init__(config=config)
inner_dim = config.num_attention_heads * config.attention_head_dim
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.in_channels = config.in_channels
self.out_channels = config.out_channels
self.num_channels_latents = config.num_channels_latents
self.patch_size = config.patch_size
self.text_len = config.text_len
# 1. Patch & position embedding
self.patch_embedding = PatchEmbed(in_chans=config.in_channels,
embed_dim=inner_dim,
patch_size=config.patch_size,
flatten=False)
# 2. Condition embeddings
self.condition_embedder = WanTimeTextImageEmbedding(
dim=inner_dim,
time_freq_dim=config.freq_dim,
text_embed_dim=config.text_dim,
image_embed_dim=config.image_dim,
)
# 3. Transformer blocks
self.blocks = nn.ModuleList([
WanTransformerBlock(inner_dim,
config.ffn_dim,
config.num_attention_heads,
config.qk_norm,
config.cross_attn_norm,
config.eps,
config.added_kv_proj_dim,
self._supported_attention_backends,
prefix=f"{config.prefix}.blocks.{i}")
for i in range(config.num_layers)
])
# 4. Output norm & projection
self.norm_out = LayerNormScaleShift(inner_dim,
norm_type="layer",
eps=config.eps,
elementwise_affine=False,
dtype=torch.float32)
self.proj_out = nn.Linear(
inner_dim, config.out_channels * math.prod(config.patch_size))
self.scale_shift_table = nn.Parameter(
torch.randn(1, 2, inner_dim) / inner_dim**0.5)
self.gradient_checkpointing = False
self.__post_init__()
def forward(self,
hidden_states: torch.Tensor,
encoder_hidden_states: Union[torch.Tensor, List[torch.Tensor]],
timestep: torch.LongTensor,
encoder_hidden_states_image: Optional[Union[
torch.Tensor, List[torch.Tensor]]] = None,
guidance=None,
**kwargs) -> torch.Tensor:
orig_dtype = hidden_states.dtype
if not isinstance(encoder_hidden_states, torch.Tensor):
encoder_hidden_states = encoder_hidden_states[0]
if isinstance(encoder_hidden_states_image,
list) and len(encoder_hidden_states_image) > 0:
encoder_hidden_states_image = encoder_hidden_states_image[0]
else:
encoder_hidden_states_image = None
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
# Get rotary embeddings
d = self.hidden_size // self.num_attention_heads
rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]
freqs_cos, freqs_sin = get_rotary_pos_embed(
(post_patch_num_frames * get_sequence_model_parallel_world_size(),
post_patch_height, post_patch_width),
self.hidden_size,
self.num_attention_heads,
rope_dim_list,
dtype=torch.float64,
rope_theta=10000)
freqs_cos = freqs_cos.to(hidden_states.device)
freqs_sin = freqs_sin.to(hidden_states.device)
freqs_cis = (freqs_cos.float(),
freqs_sin.float()) if freqs_cos is not None else None
hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image)
timestep_proj = timestep_proj.unflatten(1, (6, -1))
if encoder_hidden_states_image is not None:
encoder_hidden_states = torch.concat(
[encoder_hidden_states_image, encoder_hidden_states], dim=1)
assert encoder_hidden_states.dtype == orig_dtype
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.blocks:
hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, timestep_proj,
freqs_cis)
else:
for block in self.blocks:
hidden_states = block(hidden_states, encoder_hidden_states,
timestep_proj, freqs_cis)
# 5. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2,
dim=1)
hidden_states = self.norm_out(hidden_states.float(), shift, scale)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames,
post_patch_height,
post_patch_width, p_t, p_h, p_w,
-1)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
return output
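

# Example (illustrative sketch): the reshape / permute / flatten sequence at
# the end of forward() is an inline unpatchify that rebuilds [B, C, T, H, W]
# from per-patch channels. Toy sizes make the shapes explicit:
def _example_wan_unpatchify() -> None:
    batch, out_ch = 2, 4
    p_t, p_h, p_w = 1, 2, 2
    f, h, w = 3, 5, 6  # post-patch frames / height / width
    x = torch.randn(batch, f * h * w, out_ch * p_t * p_h * p_w)
    x = x.reshape(batch, f, h, w, p_t, p_h, p_w, -1)
    x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
    out = x.flatten(6, 7).flatten(4, 5).flatten(2, 3)
    assert out.shape == (batch, out_ch, f * p_t, h * p_h, w * p_w)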
from abc import ABC, abstractmethod
from typing import Optional, Tuple
import torch
from torch import nn
from fastvideo.v1.configs.models.encoders import (BaseEncoderOutput,
ImageEncoderConfig,
TextEncoderConfig)
from fastvideo.v1.platforms import _Backend
class TextEncoder(nn.Module, ABC):
_supported_attention_backends: Tuple[
_Backend, ...] = TextEncoderConfig()._supported_attention_backends
def __init__(self, config: TextEncoderConfig) -> None:
super().__init__()
self.config = config
if not self.supported_attention_backends:
raise ValueError(
f"Subclass {self.__class__.__name__} must define _supported_attention_backends"
)
@abstractmethod
def forward(self,
input_ids: Optional[torch.Tensor],
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
**kwargs) -> BaseEncoderOutput:
pass
@property
def supported_attention_backends(self) -> Tuple[_Backend, ...]:
return self._supported_attention_backends
class ImageEncoder(nn.Module, ABC):
_supported_attention_backends: Tuple[
_Backend, ...] = ImageEncoderConfig()._supported_attention_backends
def __init__(self, config: ImageEncoderConfig) -> None:
super().__init__()
self.config = config
if not self.supported_attention_backends:
raise ValueError(
f"Subclass {self.__class__.__name__} must define _supported_attention_backends"
)
@abstractmethod
def forward(self, pixel_values: torch.Tensor,
**kwargs) -> BaseEncoderOutput:
pass
@property
def supported_attention_backends(self) -> Tuple[_Backend, ...]:
return self._supported_attention_backends
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/clip.py
# Adapted from transformers: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Iterable, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
# from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from fastvideo.v1.attention import LocalAttention
from fastvideo.v1.configs.models.encoders import (BaseEncoderOutput,
CLIPTextConfig,
CLIPVisionConfig)
from fastvideo.v1.configs.quantization import QuantizationConfig
from fastvideo.v1.distributed import (divide,
get_tensor_model_parallel_world_size)
from fastvideo.v1.layers.activation import get_act_fn
from fastvideo.v1.layers.linear import (ColumnParallelLinear, QKVParallelLinear,
RowParallelLinear)
from fastvideo.v1.logger import init_logger
from fastvideo.v1.models.encoders.base import ImageEncoder, TextEncoder
from fastvideo.v1.models.encoders.vision import resolve_visual_encoder_outputs
# TODO: support quantization
# from vllm.model_executor.layers.quantization import QuantizationConfig
from fastvideo.v1.models.loader.weight_utils import default_weight_loader
logger = init_logger(__name__)
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):
def __init__(self, config: CLIPVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
assert self.image_size % self.patch_size == 0
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size)**2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim)
self.register_buffer("position_ids",
torch.arange(self.num_positions).expand((1, -1)),
persistent=False)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(
dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
class CLIPTextEmbeddings(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(config.max_position_embeddings,
embed_dim)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer(
"position_ids",
torch.arange(config.max_position_embeddings).expand((1, -1)),
persistent=False)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
if input_ids is not None:
seq_length = input_ids.shape[-1]
elif inputs_embeds is not None:
seq_length = inputs_embeds.shape[-2]
else:
raise ValueError(
"Either input_ids or inputs_embeds must be provided.")
max_position_embedding = self.position_embedding.weight.shape[0]
        if seq_length > max_position_embedding:
            raise ValueError(
                f"Sequence length must be at most max_position_embeddings (got sequence length: "
                f"{seq_length} and max_position_embeddings: {max_position_embedding})"
            )
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
class CLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: Union[CLIPVisionConfig, CLIPTextConfig],
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = LocalAttention(
self.num_heads_per_partition,
self.head_dim,
self.num_heads_per_partition,
softmax_scale=self.scale,
causal=True,
supported_attention_backends=config._supported_attention_backends)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
        # reshape [B, L, H*D] -> [B, L, H, D] for the attention backend
query_states = query_states.reshape(query_states.shape[0],
query_states.shape[1],
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.reshape(key_states.shape[0],
key_states.shape[1],
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.reshape(value_states.shape[0],
value_states.shape[1],
self.num_heads_per_partition,
self.head_dim)
attn_output = self.attn(query_states, key_states, value_states)
attn_output = attn_output.reshape(
attn_output.shape[0], attn_output.shape[1],
self.num_heads_per_partition * self.head_dim)
attn_output, _ = self.out_proj(attn_output)
return attn_output, None
class CLIPMLP(nn.Module):
def __init__(
self,
config: Union[CLIPVisionConfig, CLIPTextConfig],
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class CLIPEncoderLayer(nn.Module):
def __init__(
self,
config: Union[CLIPTextConfig, CLIPVisionConfig],
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.self_attn = CLIPAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class CLIPEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self
attention layers. Each layer is a [`CLIPEncoderLayer`].
Args:
config: CLIPConfig
"""
def __init__(
self,
config: Union[CLIPVisionConfig, CLIPTextConfig],
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
CLIPEncoderLayer(config=config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])
def forward(
self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
) -> Union[torch.Tensor, list[torch.Tensor]]:
hidden_states_pool = [inputs_embeds]
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)
if return_all_hidden_states:
hidden_states_pool.append(hidden_states)
# If we have multiple feature sample layers, we return all hidden
# states in order and grab the ones we need by index.
if return_all_hidden_states:
return hidden_states_pool
return [hidden_states]
class CLIPTextTransformer(nn.Module):
def __init__(self,
config: CLIPTextConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
prefix: str = ""):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = CLIPTextEmbeddings(config)
self.encoder = CLIPEncoder(
config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=prefix)
self.final_layer_norm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
# For `pooled_output` computation
self.eos_token_id = config.eos_token_id
def forward(
self,
input_ids: Optional[torch.Tensor],
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseEncoderOutput:
r"""
Returns:
"""
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
if input_ids is None:
raise ValueError("You have to specify input_ids")
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids,
position_ids=position_ids)
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
# causal_attention_mask = _create_4d_causal_attention_mask(
# input_shape, hidden_states.dtype, device=hidden_states.device
# )
# # expand attention_mask
# if attention_mask is not None and not self._use_flash_attention_2:
# raise NotImplementedError("attention_mask is not supported for CLIPTextTransformer")
# # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
# attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
# attention_mask=attention_mask,
# causal_attention_mask=causal_attention_mask,
# output_attentions=output_attentions,
return_all_hidden_states=output_hidden_states,
# return_dict=return_dict,
)
last_hidden_state = encoder_outputs[-1]
last_hidden_state = self.final_layer_norm(last_hidden_state)
if self.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
            # A CLIP model with such an `eos_token_id` in its config can't work correctly with extra new tokens added.
# ------------------------------------------------------------
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0],
device=last_hidden_state.device),
input_ids.to(dtype=torch.int, device=last_hidden_state.device).
argmax(dim=-1),
]
else:
            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0],
device=last_hidden_state.device),
# We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
# Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
(input_ids.to(dtype=torch.int, device=last_hidden_state.device
) == self.eos_token_id).int().argmax(dim=-1),
]
return BaseEncoderOutput(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs,
# attentions=encoder_outputs.attentions,
)
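

# Example (illustrative sketch): the pooled output above indexes each
# sequence at its first EOS position; argmax over the boolean match returns
# the first True index. The EOS id below is the standard CLIP value and is
# an assumption for this demo.
def _example_eos_pooling() -> None:
    eos_token_id = 49407  # assumed CLIP EOS id
    input_ids = torch.tensor([[10, 11, eos_token_id, 0],
                              [12, eos_token_id, eos_token_id, 0]])
    positions = (input_ids == eos_token_id).int().argmax(dim=-1)
    assert positions.tolist() == [2, 1]
    last_hidden_state = torch.randn(2, 4, 8)
    pooled = last_hidden_state[torch.arange(2), positions]
    assert pooled.shape == (2, 8)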
class CLIPTextModel(TextEncoder):
def __init__(
self,
config: CLIPTextConfig,
) -> None:
super().__init__(config)
self.text_model = CLIPTextTransformer(config=config,
quant_config=config.quant_config,
prefix=config.prefix)
def forward(
self,
input_ids: Optional[torch.Tensor],
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> BaseEncoderOutput:
outputs: BaseEncoderOutput = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_hidden_states=output_hidden_states,
)
return outputs
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
# Define mapping for stacked parameters
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
# Handle q_proj, k_proj, v_proj -> qkv_proj mapping
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name in name:
# Replace the weight name with the parameter name
model_param_name = name.replace(weight_name, param_name)
if model_param_name in params_dict:
param = params_dict[model_param_name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(model_param_name)
break
else:
# Use default weight loader for all other parameters
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
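

# Example (illustrative sketch): how the stacked-parameter mapping rewrites
# checkpoint names. A HF checkpoint stores separate q_proj / k_proj / v_proj
# weights, while this model holds a fused qkv_proj, so each incoming name is
# rewritten before loading into the matching shard. The layer path below is
# a hypothetical example.
def _example_stacked_param_mapping() -> None:
    name = "text_model.encoder.layers.0.self_attn.q_proj.weight"
    mapping = [("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"),
               ("qkv_proj", "v_proj", "v")]
    for param_name, weight_name, shard_id in mapping:
        if weight_name in name:
            fused = name.replace(weight_name, param_name)
            assert fused == ("text_model.encoder.layers.0."
                             "self_attn.qkv_proj.weight")
            break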
class CLIPVisionTransformer(nn.Module):
def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = CLIPVisionEmbeddings(config)
# NOTE: This typo of "layrnorm" is not fixed on purpose to match
# the original transformers code and name of the model weights.
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = CLIPEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
)
num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers.")
# If possible, skip post_layernorm to conserve memory
if require_post_norm is None:
require_post_norm = len(self.encoder.layers) == num_hidden_layers
if require_post_norm:
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
else:
self.post_layernorm = None
def forward(
self,
pixel_values: torch.Tensor,
feature_sample_layers: Optional[list[int]] = None,
) -> torch.Tensor:
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
return_all_hidden_states = feature_sample_layers is not None
# Produces either the last layer output or all of the hidden states,
# depending on if we have feature_sample_layers or not
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
return_all_hidden_states=return_all_hidden_states)
if not return_all_hidden_states:
encoder_outputs = encoder_outputs[0]
# Handle post-norm (if applicable) and stacks feature layers if needed
encoder_outputs = resolve_visual_encoder_outputs(
encoder_outputs, feature_sample_layers, self.post_layernorm,
self.config.num_hidden_layers)
return encoder_outputs
class CLIPVisionModel(ImageEncoder):
config_class = CLIPVisionConfig
main_input_name = "pixel_values"
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
def __init__(self, config: CLIPVisionConfig) -> None:
super().__init__(config)
self.vision_model = CLIPVisionTransformer(
config=config,
quant_config=config.quant_config,
num_hidden_layers_override=config.num_hidden_layers_override,
require_post_norm=config.require_post_norm,
prefix=f"{config.prefix}.vision_model")
def forward(
self,
pixel_values: torch.Tensor,
feature_sample_layers: Optional[list[int]] = None,
**kwargs,
) -> BaseEncoderOutput:
last_hidden_state = self.vision_model(pixel_values,
feature_sample_layers)
return BaseEncoderOutput(last_hidden_state=last_hidden_state)
@property
def device(self):
return next(self.parameters()).device
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights:
if name.startswith("visual_projection"):
continue
# post_layernorm is not needed in CLIPVisionModel
if (name.startswith("vision_model.post_layernorm")
and self.vision_model.post_layernorm is None):
continue
# omit layers when num_hidden_layers_override is set
if name.startswith("vision_model.encoder.layers"):
layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/llama.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple
import torch
from torch import nn
# from vllm.model_executor.layers.quantization import QuantizationConfig
from fastvideo.v1.attention import LocalAttention
# from ..utils import (extract_layer_index)
from fastvideo.v1.configs.models.encoders import BaseEncoderOutput, LlamaConfig
from fastvideo.v1.configs.quantization import QuantizationConfig
from fastvideo.v1.distributed import get_tensor_model_parallel_world_size
from fastvideo.v1.layers.activation import SiluAndMul
from fastvideo.v1.layers.layernorm import RMSNorm
from fastvideo.v1.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, RowParallelLinear)
from fastvideo.v1.layers.rotary_embedding import get_rope
from fastvideo.v1.layers.vocab_parallel_embedding import VocabParallelEmbedding
from fastvideo.v1.models.encoders.base import TextEncoder
from fastvideo.v1.models.loader.weight_utils import (default_weight_loader,
maybe_remap_kv_scale_name)
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
# output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
x, _ = self.gate_up_proj(x)
x = self.act_fn(x)
x, _ = self.down_proj(x)
return x
class LlamaAttention(nn.Module):
def __init__(self,
config: LlamaConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
bias_o_proj: bool = False,
prefix: str = "") -> None:
super().__init__()
# layer_idx = extract_layer_index(prefix)
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is at least the TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
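        # Illustrative numbers (assumed, not from a specific model): with
        # total_num_kv_heads=8 and tp_size=4, each rank owns 2 KV heads; with
        # total_num_kv_heads=2 and tp_size=4, KV heads are replicated and each
        # rank still ends up with num_kv_heads=1.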
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
# Phi models introduced a partial_rotary_factor parameter in the config
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias_o_proj,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
is_neox_style = True
is_gguf = quant_config and hasattr(
quant_config, "get_name") and quant_config.get_name() == "gguf"
if is_gguf and config.model_type == "llama":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position_embeddings,
base=int(rope_theta),
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
)
self.attn = LocalAttention(
self.num_heads,
self.head_dim,
self.num_kv_heads,
softmax_scale=self.scaling,
causal=True,
supported_attention_backends=config._supported_attention_backends)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
        # TODO(attn): clean up once the attention abstraction/backend lands.
        # Reshape q, k, v to (batch_size, seq_len, num_heads, head_dim).
batch_size = q.shape[0]
seq_len = q.shape[1]
q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)
v = v.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)
attn_output = self.attn(q, k, v)
attn_output = attn_output.reshape(batch_size, seq_len,
self.num_heads * self.head_dim)
output, _ = self.o_proj(attn_output)
return output
class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
bias_o_proj = attention_bias
# support internlm/internlm3-8b with qkv_bias
if hasattr(config, 'qkv_bias'):
attention_bias = config.qkv_bias
self.self_attn = LlamaAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
bias_o_proj=bias_o_proj,
prefix=f"{prefix}.self_attn",
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class LlamaModel(TextEncoder):
def __init__(
self,
config: LlamaConfig,
):
super().__init__(config)
self.config = config
self.quant_config = self.config.quant_config
if config.lora_config is not None:
max_loras = 1
lora_vocab_size = 1
if hasattr(config.lora_config, "max_loras"):
max_loras = config.lora_config.max_loras
if hasattr(config.lora_config, "lora_extra_vocab_size"):
lora_vocab_size = config.lora_config.lora_extra_vocab_size
lora_vocab = lora_vocab_size * max_loras
else:
lora_vocab = 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=config.quant_config,
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config=config,
quant_config=config.quant_config,
prefix=f"{config.prefix}.layers.{i}")
for i in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> BaseEncoderOutput:
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
if position_ids is None:
position_ids = torch.arange(
0, hidden_states.shape[1],
device=hidden_states.device).unsqueeze(0)
        all_hidden_states: Optional[Tuple[Any, ...]] = (
            () if output_hidden_states else None)
for layer in self.layers:
if all_hidden_states is not None:
                # TODO
                if residual is None:
                    all_hidden_states += (hidden_states, )
                else:
                    all_hidden_states += (hidden_states + residual, )
hidden_states, residual = layer(position_ids, hidden_states,
residual)
hidden_states, _ = self.norm(hidden_states, residual)
# add hidden states from the last decoder layer
if all_hidden_states is not None:
all_hidden_states += (hidden_states, )
# TODO(will): maybe unify the output format with other models and use
# our own class
output = BaseEncoderOutput(
last_hidden_state=hidden_states,
# past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
# attentions=all_self_attns,
)
return output
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
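        # For example (illustrative names): "layers.0.mlp.gate_proj.weight"
        # loads into "layers.0.mlp.gate_up_proj.weight" as shard 0, while
        # "layers.0.mlp.up_proj.weight" loads into the same parameter as
        # shard 1.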
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# if (self.quant_config is not None and
# (scale_name := self.quant_config.get_cache_scale(name))):
# # Loading kv cache quantization scales
# param = params_dict[scale_name]
# weight_loader = getattr(param, "weight_loader",
# default_weight_loader)
# loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
# loaded_weight[0])
# weight_loader(param, loaded_weight)
# loaded_params.add(scale_name)
# continue
if "scale" in name:
# Remapping the name of FP8 kv-scale.
kv_scale_name: Optional[str] = maybe_remap_kv_scale_name(
name, params_dict)
if kv_scale_name is None:
continue
else:
name = kv_scale_name
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# SPDX-License-Identifier: Apache-2.0
# Adapted from transformers: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/t5/modeling_t5.py
# Derived from T5 implementation posted on HuggingFace; license below:
#
# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch T5 & UMT5 model."""
import math
from dataclasses import dataclass
from typing import Iterable, Optional, Set, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from fastvideo.v1.configs.models.encoders import BaseEncoderOutput, T5Config
from fastvideo.v1.configs.quantization import QuantizationConfig
from fastvideo.v1.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from fastvideo.v1.layers.activation import get_act_fn
from fastvideo.v1.layers.layernorm import RMSNorm
from fastvideo.v1.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, RowParallelLinear)
from fastvideo.v1.layers.vocab_parallel_embedding import VocabParallelEmbedding
from fastvideo.v1.models.encoders.base import TextEncoder
from fastvideo.v1.models.loader.weight_utils import default_weight_loader
class AttentionType:
"""
Attention type.
Use string to be compatible with `torch.compile`.
"""
# Decoder attention between previous layer Q/K/V
DECODER = "decoder"
# Encoder attention between previous layer Q/K/V for encoder-decoder
ENCODER = "encoder"
# Encoder attention between previous layer Q/K/V
ENCODER_ONLY = "encoder_only"
# Attention between dec. Q and enc. K/V for encoder-decoder
ENCODER_DECODER = "encoder_decoder"
@dataclass
class AttentionMetadata:
attn_bias: torch.Tensor
class T5DenseActDense(nn.Module):
def __init__(self,
config: T5Config,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
        self.wi = MergedColumnParallelLinear(config.d_model, [config.d_ff],
                                             bias=False,
                                             quant_config=quant_config)
self.wo = RowParallelLinear(config.d_ff,
config.d_model,
bias=False,
quant_config=quant_config)
self.act = get_act_fn(config.dense_act_fn)
def forward(self, hidden_states) -> torch.Tensor:
hidden_states, _ = self.wi(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.wo(hidden_states)
return hidden_states
class T5DenseGatedActDense(nn.Module):
def __init__(self,
config: T5Config,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.wi_0 = MergedColumnParallelLinear(config.d_model, [config.d_ff],
bias=False,
quant_config=quant_config)
self.wi_1 = MergedColumnParallelLinear(config.d_model, [config.d_ff],
bias=False,
quant_config=quant_config)
# Should not run in fp16 unless mixed-precision is used,
# see https://github.com/huggingface/transformers/issues/20287.
self.wo = RowParallelLinear(config.d_ff,
config.d_model,
bias=False,
quant_config=quant_config)
self.act = get_act_fn(config.dense_act_fn)
def forward(self, hidden_states) -> torch.Tensor:
hidden_gelu = self.act(self.wi_0(hidden_states)[0])
hidden_linear, _ = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states, _ = self.wo(hidden_states)
return hidden_states
class T5LayerFF(nn.Module):
def __init__(self,
config: T5Config,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
if config.is_gated_act:
self.DenseReluDense = T5DenseGatedActDense(
config, quant_config=quant_config)
else:
self.DenseReluDense = T5DenseActDense(config,
quant_config=quant_config)
self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
def forward(self, hidden_states) -> torch.Tensor:
forwarded_states = self.layer_norm.forward_native(hidden_states)
forwarded_states = self.DenseReluDense(forwarded_states)
hidden_states = hidden_states + forwarded_states
return hidden_states
# T5 has attn_bias and does not use softmax scaling
class T5MultiHeadAttention(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, q, k, v, attn_bias=None):
b, _, n, c = q.shape
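        # Shape notes: q, k and v are (batch, seq_len, n_heads, head_dim);
        # attn_bias broadcasts against (batch, n_heads, q_len, k_len); the
        # result is flattened back to (batch, seq_len, n_heads * head_dim).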
attn = torch.einsum('binc,bjnc->bnij', q, k)
if attn_bias is not None:
attn += attn_bias
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
x = torch.einsum('bnij,bjnc->binc', attn, v)
x = x.reshape(b, -1, n * c)
return x
class T5Attention(nn.Module):
def __init__(self,
config: T5Config,
attn_type: str,
has_relative_attention_bias=False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.attn_type = attn_type
# Cross-attention has no relative pos encoding anyway
self.is_decoder = attn_type == AttentionType.DECODER
self.has_relative_attention_bias = has_relative_attention_bias
self.relative_attention_num_buckets = \
config.relative_attention_num_buckets
self.relative_attention_max_distance = \
config.relative_attention_max_distance
self.d_model = config.d_model
self.key_value_proj_dim = config.d_kv
self.total_num_heads = self.total_num_kv_heads = config.num_heads
# Partition heads across multiple tensor parallel GPUs.
tp_world_size = get_tensor_model_parallel_world_size()
assert config.num_heads % tp_world_size == 0
self.n_heads = config.num_heads // tp_world_size
self.inner_dim = self.n_heads * self.key_value_proj_dim
        # No GQA in T5: n_kv_heads == n_heads.
self.qkv_proj = QKVParallelLinear(
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.attn = T5MultiHeadAttention()
if self.has_relative_attention_bias:
self.relative_attention_bias = \
VocabParallelEmbedding(self.relative_attention_num_buckets,
self.total_num_heads,
org_num_embeddings=self.relative_attention_num_buckets,
padding_size=self.relative_attention_num_buckets,
quant_config=quant_config)
self.o = RowParallelLinear(
self.d_model,
self.d_model,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
@staticmethod
def _relative_position_bucket(relative_position,
bidirectional=True,
num_buckets=32,
max_distance=128) -> torch.Tensor:
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention.
The relative position is defined as memory_position - query_position,
i.e. the distance in tokens from the attending position to the
attended-to position. If bidirectional=False, then positive relative
positions are invalid. We use smaller buckets for small absolute
relative_position and larger buckets for larger absolute
relative_positions. All relative positions >=max_distance map to the
same bucket. All relative positions <=-max_distance map to the same
bucket. This should allow for more graceful generalization to longer
        sequences than the model has been trained on.
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32
values in the range [0, num_buckets)
"""# noqa: E501
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(
torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position,
torch.zeros_like(relative_position))
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins
# in positions up to max_distance
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact) /
math.log(max_distance / max_exact) *
(num_buckets - max_exact)).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large,
torch.full_like(relative_position_if_large, num_buckets - 1))
relative_buckets += torch.where(is_small, relative_position,
relative_position_if_large)
return relative_buckets
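    # Illustrative bucket values (bidirectional=True, num_buckets=32,
    # max_distance=128): a relative position of 0 maps to bucket 0, -1 maps
    # to bucket 1, and +1 maps to bucket 17, since positive offsets occupy
    # the upper half of the bucket range.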
def compute_bias(self,
query_length,
key_length,
device=None) -> torch.Tensor:
"""Compute binned relative position bias"""
if device is None:
device = self.relative_attention_bias.weight.device
context_position = torch.arange(query_length,
dtype=torch.long,
device=device)[:, None]
memory_position = torch.arange(key_length,
dtype=torch.long,
device=device)[None, :]
# max_seq_len, nh
relative_position = memory_position - context_position
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(
relative_position_bucket
) # shape (query_length, key_length, num_heads)
x = values.permute([2, 0, 1]).unsqueeze(
0) # shape (1, num_heads, query_length, key_length)
return x
def forward(
self,
hidden_states: torch.Tensor, # (num_tokens, d_model)
attention_mask: torch.Tensor,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor:
bs, seq_len, _ = hidden_states.shape
num_seqs = bs
n, c = self.n_heads, self.d_model // self.total_num_heads
qkv, _ = self.qkv_proj(hidden_states)
# Projection of 'own' hidden state (self-attention). No GQA here.
q, k, v = qkv.split(self.inner_dim, dim=-1)
q = q.reshape(bs, seq_len, n, c)
k = k.reshape(bs, seq_len, n, c)
v = v.reshape(bs, seq_len, n, c)
assert attn_metadata is not None
attn_bias = attn_metadata.attn_bias
# Not compatible with CP here (as all encoder-decoder models),
# as it assumes homogeneous batch (prefills or decodes).
if self.has_relative_attention_bias:
# Self-attention. Compute T5 relative positional encoding.
            # The bias term is computed on the longest sequence in the batch.
            # Biases for shorter sequences are slices of the longest.
assert self.attn_type == AttentionType.ENCODER
attn_bias = self.compute_bias(seq_len,
seq_len).repeat(num_seqs, 1, 1, 1)
attn_metadata.attn_bias = attn_bias
else:
# Encoder/Decoder Self-Attention Layer, attn bias already cached.
assert attn_bias is not None
if attention_mask is not None:
            if attention_mask.ndim == 2:
                attention_mask = attention_mask.view(bs, 1, 1, -1)
            else:
                attention_mask = attention_mask.unsqueeze(1)
attn_bias.masked_fill_(attention_mask == 0,
torch.finfo(q.dtype).min)
if get_tensor_model_parallel_world_size() > 1:
rank = get_tensor_model_parallel_rank()
attn_bias = attn_bias[:, rank * self.n_heads:(rank + 1) *
self.n_heads, :, :]
attn_output = self.attn(q, k, v, attn_bias)
output, _ = self.o(attn_output)
return output
class T5LayerSelfAttention(nn.Module):
def __init__(
self,
config,
has_relative_attention_bias=False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.SelfAttention = T5Attention(
config,
AttentionType.DECODER
if "decoder" in prefix else AttentionType.ENCODER,
has_relative_attention_bias=has_relative_attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.SelfAttention")
self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor:
normed_hidden_states = self.layer_norm.forward_native(hidden_states)
attention_output = self.SelfAttention(
hidden_states=normed_hidden_states,
attention_mask=attention_mask,
attn_metadata=attn_metadata,
)
hidden_states = hidden_states + attention_output
return hidden_states
class T5LayerCrossAttention(nn.Module):
def __init__(self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.EncDecAttention = T5Attention(config,
AttentionType.ENCODER_DECODER,
has_relative_attention_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.EncDecAttention")
self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor:
normed_hidden_states = self.layer_norm.forward_native(hidden_states)
        attention_output = self.EncDecAttention(
            hidden_states=normed_hidden_states,
            attention_mask=None,
            attn_metadata=attn_metadata,
        )
hidden_states = hidden_states + attention_output
return hidden_states
class T5Block(nn.Module):
def __init__(self,
config: T5Config,
is_decoder: bool,
has_relative_attention_bias=False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.is_decoder = is_decoder
self.layer = nn.ModuleList()
self.layer.append(
T5LayerSelfAttention(
config,
has_relative_attention_bias=has_relative_attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.self_attn"))
if self.is_decoder:
self.layer.append(
T5LayerCrossAttention(config,
quant_config=quant_config,
prefix=f"{prefix}.cross_attn"))
self.layer.append(T5LayerFF(config, quant_config=quant_config))
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor:
hidden_states = self.layer[0](hidden_states=hidden_states,
attention_mask=attention_mask,
attn_metadata=attn_metadata)
if self.is_decoder:
hidden_states = self.layer[1](hidden_states=hidden_states,
attn_metadata=attn_metadata)
# Apply Feed Forward layer
hidden_states = self.layer[2](hidden_states)
else:
hidden_states = self.layer[1](hidden_states)
return hidden_states
class T5Stack(nn.Module):
def __init__(self,
config: T5Config,
is_decoder: bool,
n_layers: int,
embed_tokens=None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
is_umt5: bool = False):
super().__init__()
self.embed_tokens = embed_tokens
self.is_umt5 = is_umt5
if is_umt5:
self.block = nn.ModuleList([
T5Block(config,
is_decoder=is_decoder,
has_relative_attention_bias=True,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{i}") for i in range(n_layers)
])
else:
# Only the first block has relative positional encoding.
self.block = nn.ModuleList([
T5Block(config,
is_decoder=is_decoder,
has_relative_attention_bias=i == 0,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{i}") for i in range(n_layers)
])
self.final_layer_norm = RMSNorm(config.d_model,
eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
        for block in self.block:
hidden_states = block(
hidden_states=hidden_states,
attention_mask=attention_mask,
attn_metadata=attn_metadata,
)
hidden_states = self.final_layer_norm.forward_native(hidden_states)
return hidden_states
class T5EncoderModel(TextEncoder):
def __init__(self, config: T5Config, prefix: str = ""):
super().__init__(config)
quant_config = None
self.shared = VocabParallelEmbedding(
config.vocab_size,
config.d_model,
org_num_embeddings=config.vocab_size)
self.encoder = T5Stack(config,
False,
config.num_layers,
self.shared,
quant_config=quant_config,
prefix=f"{prefix}.encoder",
is_umt5=False)
def get_input_embeddings(self):
return self.shared
def forward(
self,
input_ids: Optional[torch.Tensor],
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> BaseEncoderOutput:
attn_metadata = AttentionMetadata(None)
hidden_states = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
attn_metadata=attn_metadata,
)
return BaseEncoderOutput(last_hidden_state=hidden_states)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q", "q"),
(".qkv_proj", ".k", "k"),
(".qkv_proj", ".v", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
loaded = False
if "decoder" in name or "lm_head" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded = True
break
if not loaded:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class UMT5EncoderModel(TextEncoder):
def __init__(self, config: T5Config, prefix: str = ""):
super().__init__(config)
quant_config = None
self.shared = VocabParallelEmbedding(
config.vocab_size,
config.d_model,
org_num_embeddings=config.vocab_size)
self.encoder = T5Stack(config,
False,
config.num_layers,
self.shared,
quant_config=quant_config,
prefix=f"{prefix}.encoder",
is_umt5=True)
def get_input_embeddings(self):
return self.shared
def forward(
self,
input_ids: Optional[torch.Tensor],
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> BaseEncoderOutput:
attn_metadata = AttentionMetadata(None)
hidden_states = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
attn_metadata=attn_metadata,
)
return BaseEncoderOutput(
last_hidden_state=hidden_states,
attention_mask=attention_mask,
)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q", "q"),
(".qkv_proj", ".k", "k"),
(".qkv_proj", ".v", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
loaded = False
if "decoder" in name or "lm_head" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded = True
break
if not loaded:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/vision.py
from abc import ABC, abstractmethod
from typing import Generic, Optional, TypeVar, Union
import torch
from transformers import PretrainedConfig
from fastvideo.v1.logger import init_logger
logger = init_logger(__name__)
_C = TypeVar("_C", bound=PretrainedConfig)
class VisionEncoderInfo(ABC, Generic[_C]):
def __init__(self, vision_config: _C) -> None:
super().__init__()
self.vision_config = vision_config
@abstractmethod
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
raise NotImplementedError
@abstractmethod
def get_max_image_tokens(self) -> int:
raise NotImplementedError
@abstractmethod
def get_image_size(self) -> int:
raise NotImplementedError
@abstractmethod
def get_patch_size(self) -> int:
raise NotImplementedError
@abstractmethod
def get_patch_grid_length(self) -> int:
raise NotImplementedError
def resolve_visual_encoder_outputs(
encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
feature_sample_layers: Optional[list[int]],
post_layer_norm: Optional[torch.nn.LayerNorm],
max_possible_layers: int,
) -> torch.Tensor:
"""Given the outputs a visual encoder module that may correspond to the
output of the last layer, or a list of hidden states to be stacked,
handle post normalization and resolve it into a single output tensor.
Args:
encoder_outputs: Output of encoder's last layer or all hidden states.
feature_sample_layers: Optional layer indices to grab from the encoder
outputs; if provided, encoder outputs must be a list.
post_layer_norm: Post norm to apply to the output of the encoder.
max_possible_layers: Total layers in the fully loaded visual encoder.
"""
if feature_sample_layers is None:
if post_layer_norm is not None:
return post_layer_norm(encoder_outputs)
return encoder_outputs
# Get the hidden states corresponding to the layer indices.
# Negative values are relative to the full visual encoder,
# so offset them depending on how many layers were loaded.
# NOTE: this assumes that encoder_outputs is a list containing
# the inputs to the visual encoder, followed by the hidden states
# of each layer.
num_loaded_layers = len(encoder_outputs) - 1
offset = max_possible_layers - num_loaded_layers
hs_pool = [
encoder_outputs[layer_idx]
if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
for layer_idx in feature_sample_layers
]
# Apply post-norm on the final hidden state if we are using it
uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
if post_layer_norm is not None and uses_last_layer:
        hs_pool[-1] = post_layer_norm(encoder_outputs[-1])
return torch.cat(hs_pool, dim=-1)
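# Illustrative usage (hypothetical values): with all 24 of 24 encoder layers
# loaded and feature_sample_layers=[-2, -1], hs_pool selects the last two
# hidden states, post-norm (if any) is applied to the final one, and the two
# are concatenated along the feature dimension.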
# SPDX-License-Identifier: Apache-2.0
# Adapted from SGLang: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/hf_transformers_utils.py
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for Huggingface Transformers."""
import contextlib
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union
from huggingface_hub import snapshot_download
from transformers import AutoConfig, PretrainedConfig
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
# ChatGLMConfig.model_type: ChatGLMConfig,
# DbrxConfig.model_type: DbrxConfig,
# ExaoneConfig.model_type: ExaoneConfig,
# Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
}
for name, cls in _CONFIG_REGISTRY.items():
with contextlib.suppress(ValueError):
AutoConfig.register(name, cls)
def download_from_hf(model_path: str):
if os.path.exists(model_path):
return model_path
return snapshot_download(model_path,
allow_patterns=["*.json", "*.bin", "*.model"])
def get_hf_config(
model: str,
trust_remote_code: bool,
revision: Optional[str] = None,
model_override_args: Optional[dict] = None,
**kwargs,
):
is_gguf = check_gguf_file(model)
if is_gguf:
raise NotImplementedError("GGUF models are not supported.")
config = AutoConfig.from_pretrained(model,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs)
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model, revision=revision)
# NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
config._name_or_path = model
if model_override_args:
config.update(model_override_args)
    # Special architecture mapping check for GGUF models (currently
    # unreachable while GGUF support raises NotImplementedError above)
if is_gguf:
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
raise RuntimeError(
f"Can't get gguf config for {config.model_type}.")
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
config.update({"architectures": [model_type]})
return config
def get_diffusers_config(
model: str,
fastvideo_args: Optional[dict] = None,
) -> Dict[str, Any]:
"""Gets a configuration for the given diffusers model.
Args:
model: The model name or path.
fastvideo_args: Optional inference arguments to override in the config.
Returns:
The loaded configuration.
"""
config_name = "config.json"
if "scheduler" in model:
config_name = "scheduler_config.json"
# Check if the model path exists
if os.path.exists(model):
config_file = os.path.join(model, config_name)
if os.path.exists(config_file):
try:
# Load the config directly from the file
with open(config_file) as f:
config_dict: Dict[str, Any] = json.load(f)
# TODO(will): apply any overrides from inference args
return config_dict
except Exception as e:
raise RuntimeError(
f"Failed to load diffusers config from {config_file}: {e}"
) from e
raise RuntimeError(f"Config file not found at {config_file}")
else:
raise RuntimeError(f"Diffusers config file not found at {model}")
# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
# have a preference for which value gets used.
CONTEXT_LENGTH_KEYS = [
"max_sequence_length",
"seq_length",
"max_seq_len",
"model_max_length",
"max_position_embeddings",
]
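# A minimal lookup sketch following the preference order above (hypothetical
# helper, not part of this module):
#
# def get_context_length(config) -> Optional[int]:
#     for key in CONTEXT_LENGTH_KEYS:
#         value = getattr(config, key, None)
#         if value is not None:
#             return int(value)
#     return None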
def attach_additional_stop_token_ids(tokenizer):
# Special handling for stop token <|eom_id|> generated by llama 3 tool use.
if "<|eom_id|>" in tokenizer.get_added_vocab():
tokenizer.additional_stop_token_ids = set(
[tokenizer.get_added_vocab()["<|eom_id|>"]])
else:
tokenizer.additional_stop_token_ids = None
def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
"""Check if the file is a GGUF model."""
model = Path(model)
if not model.is_file():
return False
elif model.suffix == ".gguf":
return True
with open(model, "rb") as f:
header = f.read(4)
return header == b"GGUF"
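# Example behavior (illustrative): a path ending in ".gguf" returns True
# without reading the file; any other existing file is sniffed for the
# 4-byte b"GGUF" magic header; directories and missing paths return False.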
# SPDX-License-Identifier: Apache-2.0
import dataclasses
import glob
import json
import os
import time
from abc import ABC, abstractmethod
from typing import Any, Generator, Iterable, List, Optional, Tuple, cast
import torch
import torch.nn as nn
from safetensors.torch import load_file as safetensors_load_file
from transformers import AutoImageProcessor, AutoTokenizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.logger import init_logger
from fastvideo.v1.models.hf_transformer_utils import get_diffusers_config
from fastvideo.v1.models.loader.fsdp_load import load_fsdp_model
from fastvideo.v1.models.loader.utils import set_default_torch_dtype
from fastvideo.v1.models.loader.weight_utils import (
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
pt_weights_iterator, safetensors_weights_iterator)
from fastvideo.v1.models.registry import ModelRegistry
from fastvideo.v1.utils import PRECISION_TO_TYPE
logger = init_logger(__name__)
class ComponentLoader(ABC):
"""Base class for loading a specific type of model component."""
def __init__(self, device=None) -> None:
self.device = device
@abstractmethod
def load(self, model_path: str, architecture: str,
fastvideo_args: FastVideoArgs):
"""
Load the component based on the model path, architecture, and inference args.
Args:
model_path: Path to the component model
architecture: Architecture of the component model
fastvideo_args: Inference arguments
Returns:
The loaded component
"""
raise NotImplementedError
@classmethod
def for_module_type(cls, module_type: str,
transformers_or_diffusers: str) -> 'ComponentLoader':
"""
Factory method to create a component loader for a specific module type.
Args:
module_type: Type of module (e.g., "vae", "text_encoder", "transformer", "scheduler")
transformers_or_diffusers: Whether the module is from transformers or diffusers
Returns:
A component loader for the specified module type
"""
# Map of module types to their loader classes and expected library
module_loaders = {
"scheduler": (SchedulerLoader, "diffusers"),
"transformer": (TransformerLoader, "diffusers"),
"vae": (VAELoader, "diffusers"),
"text_encoder": (TextEncoderLoader, "transformers"),
"text_encoder_2": (TextEncoderLoader, "transformers"),
"tokenizer": (TokenizerLoader, "transformers"),
"tokenizer_2": (TokenizerLoader, "transformers"),
"image_processor": (ImageProcessorLoader, "transformers"),
"image_encoder": (ImageEncoderLoader, "transformers"),
}
if module_type in module_loaders:
loader_cls, expected_library = module_loaders[module_type]
# Assert that the library matches what's expected for this module type
assert transformers_or_diffusers == expected_library, f"{module_type} must be loaded from {expected_library}, got {transformers_or_diffusers}"
return loader_cls()
# For unknown module types, use a generic loader
logger.warning(
"No specific loader found for module type: %s. Using generic loader.",
module_type)
return GenericComponentLoader(transformers_or_diffusers)
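    # Illustrative usage (hypothetical arguments):
    #   loader = ComponentLoader.for_module_type("vae", "diffusers")
    #   vae = loader.load("/ckpts/model/vae", "AutoencoderKL", fastvideo_args)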
class TextEncoderLoader(ComponentLoader):
"""Loader for text encoders."""
@dataclasses.dataclass
class Source:
"""A source for weights."""
model_or_path: str
"""The model ID or path."""
prefix: str = ""
"""A prefix to prepend to all weights."""
fall_back_to_pt: bool = True
"""Whether .pt weights can be used."""
allow_patterns_overrides: Optional[list[str]] = None
"""If defined, weights will load exclusively using these patterns."""
counter_before_loading_weights: float = 0.0
counter_after_loading_weights: float = 0.0
def _prepare_weights(
self,
model_name_or_path: str,
fall_back_to_pt: bool,
allow_patterns_overrides: Optional[list[str]],
) -> Tuple[str, List[str], bool]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
# model_name_or_path = (self._maybe_download_from_modelscope(
# model_name_or_path, revision) or model_name_or_path)
is_local = os.path.isdir(model_name_or_path)
assert is_local, "Model path must be a local directory"
use_safetensors = False
index_file = SAFE_WEIGHTS_INDEX_NAME
allow_patterns = ["*.safetensors", "*.bin"]
if fall_back_to_pt:
allow_patterns += ["*.pt"]
if allow_patterns_overrides is not None:
allow_patterns = allow_patterns_overrides
hf_folder = model_name_or_path
hf_weights_files: List[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if len(hf_weights_files) > 0:
if pattern == "*.safetensors":
use_safetensors = True
break
if use_safetensors:
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file)
else:
hf_weights_files = filter_files_not_needed_for_inference(
hf_weights_files)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
return hf_folder, hf_weights_files, use_safetensors
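    # For example (illustrative folder contents): a directory holding
    # "model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"
    # and "model.safetensors.index.json" matches "*.safetensors" first, so
    # use_safetensors is True and duplicate shards are filtered via the index.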
def _get_weights_iterator(
self, source: "Source"
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path, source.fall_back_to_pt,
source.allow_patterns_overrides)
if use_safetensors:
weights_iterator = safetensors_weights_iterator(hf_weights_files)
else:
weights_iterator = pt_weights_iterator(hf_weights_files)
if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter()
# Apply the prefix.
return ((source.prefix + name, tensor)
for (name, tensor) in weights_iterator)
def _get_all_weights(
self,
model_config: Any,
model: nn.Module,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
primary_weights = TextEncoderLoader.Source(
model_config.model,
prefix="",
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
None),
)
yield from self._get_weights_iterator(primary_weights)
secondary_weights = cast(
Iterable[TextEncoderLoader.Source],
getattr(model, "secondary_weights", ()),
)
for source in secondary_weights:
yield from self._get_weights_iterator(source)
def load(self, model_path: str, architecture: str,
fastvideo_args: FastVideoArgs):
"""Load the text encoders based on the model path, architecture, and inference args."""
# model_config: PretrainedConfig = get_hf_config(
# model=model_path,
# trust_remote_code=fastvideo_args.trust_remote_code,
# revision=fastvideo_args.revision,
# model_override_args=None,
# )
with open(os.path.join(model_path, "config.json")) as f:
model_config = json.load(f)
model_config.pop("_name_or_path", None)
model_config.pop("transformers_version", None)
model_config.pop("model_type", None)
model_config.pop("tokenizer_class", None)
model_config.pop("torch_dtype", None)
logger.info("HF Model config: %s", model_config)
        # TODO(Wei): Better way to handle this?
try:
encoder_config = fastvideo_args.text_encoder_configs[0]
encoder_config.update_model_arch(model_config)
encoder_precision = fastvideo_args.text_encoder_precisions[0]
except Exception:
encoder_config = fastvideo_args.text_encoder_configs[1]
encoder_config.update_model_arch(model_config)
encoder_precision = fastvideo_args.text_encoder_precisions[1]
target_device = torch.device(fastvideo_args.device_str)
# TODO(will): add support for other dtypes
return self.load_model(model_path, encoder_config, target_device,
encoder_precision)
def load_model(self,
model_path: str,
model_config,
target_device: torch.device,
dtype: str = "fp16"):
with set_default_torch_dtype(PRECISION_TO_TYPE[dtype]):
with target_device:
architectures = getattr(model_config, "architectures", [])
model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
model = model_cls(model_config)
weights_to_load = {name for name, _ in model.named_parameters()}
model_config.model = model_path
loaded_weights = model.load_weights(
self._get_all_weights(model_config, model))
self.counter_after_loading_weights = time.perf_counter()
logger.info(
"Loading weights took %.2f seconds",
self.counter_after_loading_weights -
self.counter_before_loading_weights)
        # We currently only enable the strict check for non-quantized
        # models that track their loaded weights.
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError("Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
# TODO(will): add support for training/finetune
return model.eval()
class ImageEncoderLoader(TextEncoderLoader):
def load(self, model_path: str, architecture: str,
fastvideo_args: FastVideoArgs):
"""Load the text encoders based on the model path, architecture, and inference args."""
# model_config: PretrainedConfig = get_hf_config(
# model=model_path,
# trust_remote_code=fastvideo_args.trust_remote_code,
# revision=fastvideo_args.revision,
# model_override_args=None,
# )
with open(os.path.join(model_path, "config.json")) as f:
model_config = json.load(f)
model_config.pop("_name_or_path", None)
model_config.pop("transformers_version", None)
model_config.pop("torch_dtype", None)
model_config.pop("model_type", None)
logger.info("HF Model config: %s", model_config)
encoder_config = fastvideo_args.image_encoder_config
encoder_config.update_model_arch(model_config)
target_device = torch.device(fastvideo_args.device_str)
# TODO(will): add support for other dtypes
return self.load_model(model_path, encoder_config, target_device,
fastvideo_args.image_encoder_precision)
class ImageProcessorLoader(ComponentLoader):
"""Loader for image processor."""
def load(self, model_path: str, architecture: str,
fastvideo_args: FastVideoArgs):
"""Load the image processor based on the model path, architecture, and inference args."""
logger.info("Loading image processor from %s", model_path)
        image_processor = AutoImageProcessor.from_pretrained(model_path)
logger.info("Loaded image processor: %s",
image_processor.__class__.__name__)
return image_processor
class TokenizerLoader(ComponentLoader):
"""Loader for tokenizers."""
def load(self, model_path: str, architecture: str,
fastvideo_args: FastVideoArgs):
"""Load the tokenizer based on the model path, architecture, and inference args."""
logger.info("Loading tokenizer from %s", model_path)
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,  # "<path to model>/tokenizer"
            # in v0, this was the same string as encoder_name "ClipTextModel"
            # TODO(will): pass these tokenizer kwargs from inference args?
            # Maybe some other method of config?
            padding_side='right',
        )
logger.info("Loaded tokenizer: %s", tokenizer.__class__.__name__)
return tokenizer
class VAELoader(ComponentLoader):
"""Loader for VAE."""
def load(self, model_path: str, architecture: str,
fastvideo_args: FastVideoArgs):
"""Load the VAE based on the model path, architecture, and inference args."""
# TODO(will): move this to a constants file
config = get_diffusers_config(model=model_path)
class_name = config.pop("_class_name")
assert class_name is not None, "Model config does not contain a _class_name attribute. Only diffusers format is supported."
config.pop("_diffusers_version")
vae_config = fastvideo_args.vae_config
vae_config.update_model_arch(config)
vae_cls, _ = ModelRegistry.resolve_model_cls(class_name)
vae = vae_cls(vae_config).to(fastvideo_args.device)
# Find all safetensors files
safetensors_list = glob.glob(
os.path.join(str(model_path), "*.safetensors"))
# TODO(PY)
assert len(
safetensors_list
) == 1, f"Found {len(safetensors_list)} safetensors files in {model_path}"
loaded = safetensors_load_file(safetensors_list[0])
vae.load_state_dict(
loaded, strict=False) # We might only load encoder or decoder
dtype = PRECISION_TO_TYPE[fastvideo_args.vae_precision]
vae = vae.eval().to(dtype)
return vae
class TransformerLoader(ComponentLoader):
"""Loader for transformer."""
def load(self, model_path: str, architecture: str,
fastvideo_args: FastVideoArgs):
"""Load the transformer based on the model path, architecture, and inference args."""
config = get_diffusers_config(model=model_path)
cls_name = config.pop("_class_name")
if cls_name is None:
raise ValueError(
"Model config does not contain a _class_name attribute. "
"Only diffusers format is supported.")
config.pop("_diffusers_version")
# Config from Diffusers supersedes fastvideo's model config
dit_config = fastvideo_args.dit_config
dit_config.update_model_arch(config)
model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)
# Find all safetensors files
safetensors_list = glob.glob(
os.path.join(str(model_path), "*.safetensors"))
if not safetensors_list:
raise ValueError(f"No safetensors files found in {model_path}")
logger.info("Loading model from %s safetensors files in %s",
len(safetensors_list), model_path)
# initialize_sequence_parallel_group(fastvideo_args.sp_size)
default_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
# Load the model using FSDP loader
logger.info("Loading model from %s", cls_name)
model = load_fsdp_model(model_cls=model_cls,
init_params={"config": dit_config},
weight_dir_list=safetensors_list,
device=fastvideo_args.device,
cpu_offload=fastvideo_args.use_cpu_offload,
default_dtype=default_dtype)
total_params = sum(p.numel() for p in model.parameters())
logger.info("Loaded model with %.2fB parameters", total_params / 1e9)
        # If loading left parameters in mixed dtypes, unify them to the
        # default dtype.
        dtypes = {param.dtype for param in model.parameters()}
        if len(dtypes) > 1:
            model = model.to(default_dtype)
model = model.eval()
return model
class SchedulerLoader(ComponentLoader):
"""Loader for scheduler."""
def load(self, model_path: str, architecture: str,
fastvideo_args: FastVideoArgs):
"""Load the scheduler based on the model path, architecture, and inference args."""
config = get_diffusers_config(model=model_path)
class_name = config.pop("_class_name")
assert class_name is not None, "Model config does not contain a _class_name attribute. Only diffusers format is supported."
config.pop("_diffusers_version")
scheduler_cls, _ = ModelRegistry.resolve_model_cls(class_name)
scheduler = scheduler_cls(**config)
if fastvideo_args.flow_shift is not None:
scheduler.set_shift(fastvideo_args.flow_shift)
return scheduler
class GenericComponentLoader(ComponentLoader):
"""Generic loader for components that don't have a specific loader."""
def __init__(self, library="transformers") -> None:
super().__init__()
self.library = library
def load(self, model_path: str, architecture: str,
fastvideo_args: FastVideoArgs):
"""Load a generic component based on the model path, architecture, and inference args."""
logger.warning("Using generic loader for %s with library %s",
model_path, self.library)
if self.library == "transformers":
from transformers import AutoModel
model = AutoModel.from_pretrained(
model_path,
trust_remote_code=fastvideo_args.trust_remote_code,
revision=fastvideo_args.revision,
)
logger.info("Loaded generic transformers model: %s",
model.__class__.__name__)
return model
elif self.library == "diffusers":
logger.warning(
"Generic loading for diffusers components is not fully implemented"
)
model_config = get_diffusers_config(model=model_path)
logger.info("Diffusers Model config: %s", model_config)
# This is a placeholder - in a real implementation, you'd need to handle this properly
return None
else:
raise ValueError(f"Unsupported library: {self.library}")
class PipelineComponentLoader:
"""
Utility class for loading pipeline components.
This replaces the chain of if-else statements in load_pipeline_module.
"""
@staticmethod
def load_module(module_name: str, component_model_path: str,
transformers_or_diffusers: str, architecture: str,
fastvideo_args: FastVideoArgs):
"""
Load a pipeline module.
Args:
module_name: Name of the module (e.g., "vae", "text_encoder", "transformer", "scheduler")
component_model_path: Path to the component model
transformers_or_diffusers: Whether the module is from transformers or diffusers
architecture: Architecture of the component model
fastvideo_args: Inference arguments
Returns:
The loaded module
"""
logger.info(
"Loading %s using %s from %s",
module_name,
transformers_or_diffusers,
component_model_path,
)
# Get the appropriate loader for this module type
loader = ComponentLoader.for_module_type(module_name,
transformers_or_diffusers)
# Load the module
return loader.load(component_model_path, architecture, fastvideo_args)
# SPDX-License-Identifier: Apache-2.0
# Adapted from torchtune
# Copyright 2024 The TorchTune Authors.
# Copyright 2025 The FastVideo Authors.
import contextlib
import re
from collections import defaultdict
from itertools import chain
from typing import (Any, Callable, DefaultDict, Dict, Generator, Hashable, List,
Optional, Tuple, Type)
import torch
from torch import nn
from torch.distributed import DeviceMesh, init_device_mesh
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed._tensor import distribute_tensor
from torch.nn.modules.module import _IncompatibleKeys
from fastvideo.v1.distributed.parallel_state import (
get_sequence_model_parallel_world_size)
from fastvideo.v1.models.loader.weight_utils import safetensors_weights_iterator
# TODO(PY): move this to utils elsewhere
@contextlib.contextmanager
def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
"""
Context manager to set torch's default dtype.
Args:
dtype (torch.dtype): The desired default dtype inside the context manager.
Returns:
ContextManager: context manager for setting default dtype.
Example:
>>> with set_default_dtype(torch.bfloat16):
        >>> x = torch.tensor([1.0, 2.0, 3.0])
>>> x.dtype
torch.bfloat16
"""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
try:
yield
finally:
torch.set_default_dtype(old_dtype)
def get_param_names_mapping(
mapping_dict: Dict[str, str]) -> Callable[[str], tuple[str, Any, Any]]:
"""
Creates a mapping function that transforms parameter names using regex patterns.
Args:
        mapping_dict (Dict[str, str]): Dictionary mapping regex patterns to replacement patterns
    Returns:
        Callable[[str], tuple[str, Any, Any]]: A function that maps a source parameter name to
            (target_name, merge_index, total_splitted_params)
"""
def mapping_fn(name: str) -> tuple[str, Any, Any]:
# Try to match and transform the name using the regex patterns in mapping_dict
for pattern, replacement in mapping_dict.items():
match = re.match(pattern, name)
if match:
merge_index = None
total_splitted_params = None
if isinstance(replacement, tuple):
merge_index = replacement[1]
total_splitted_params = replacement[2]
replacement = replacement[0]
name = re.sub(pattern, replacement, name)
return name, merge_index, total_splitted_params
# If no pattern matches, return the original name
return name, None, None
return mapping_fn
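# A minimal, hedged demonstration of `get_param_names_mapping`. The regex
# patterns below are illustrative only and do not correspond to any real
# model's `_param_names_mapping`.
def _example_param_names_mapping() -> None:
    mapping_fn = get_param_names_mapping({
        # tuple replacement: (new_name, merge_index, total_splitted_params)
        r"^transformer_blocks\.(\d+)\.attn\.to_q\.(.*)$":
        (r"blocks.\1.attn.qkv_proj.\2", 0, 3),
        # plain string replacement: no merging metadata
        r"^proj_out\.(.*)$": r"final_proj.\1",
    })
    name, merge_index, total = mapping_fn(
        "transformer_blocks.0.attn.to_q.weight")
    assert name == "blocks.0.attn.qkv_proj.weight"
    assert (merge_index, total) == (0, 3)
    name, merge_index, total = mapping_fn("proj_out.weight")
    assert (name, merge_index, total) == ("final_proj.weight", None, None)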
# TODO(PY): add compile option
def load_fsdp_model(
model_cls: Type[nn.Module],
init_params: Dict[str, Any],
weight_dir_list: List[str],
device: torch.device,
cpu_offload: bool = False,
default_dtype: Optional[torch.dtype] = torch.bfloat16,
) -> torch.nn.Module:
with set_default_dtype(default_dtype), torch.device("meta"):
model = model_cls(**init_params)
device_mesh = init_device_mesh(
"cuda",
mesh_shape=(get_sequence_model_parallel_world_size(), ),
mesh_dim_names=("dp", ),
)
shard_model(model,
cpu_offload=cpu_offload,
reshard_after_forward=True,
dp_mesh=device_mesh["dp"])
weight_iterator = safetensors_weights_iterator(weight_dir_list)
param_names_mapping_fn = get_param_names_mapping(model._param_names_mapping)
load_fsdp_model_from_full_model_state_dict(
model,
weight_iterator,
device,
strict=True,
cpu_offload=cpu_offload,
param_names_mapping=param_names_mapping_fn,
)
for n, p in chain(model.named_parameters(), model.named_buffers()):
if p.is_meta:
raise RuntimeError(
f"Unexpected param or buffer {n} on meta device.")
for p in model.parameters():
p.requires_grad = False
return model
def shard_model(
model,
*,
cpu_offload: bool,
reshard_after_forward: bool = True,
dp_mesh: Optional[DeviceMesh] = None,
) -> None:
"""
Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.
    This method iterates over the model's named modules from the bottom up and shards them
    based on whether they meet any of the criteria in the model's `_fsdp_shard_conditions`.
Args:
        model (nn.Module): Model to shard with FSDP. Its class must define
            `_fsdp_shard_conditions`: a list of callables, each taking the module name
            (relative to the root) and the module itself and returning True if FSDP
            should shard the module. If any condition returns True for a given module,
            it will be sharded by FSDP.
cpu_offload (bool): If set to True, FSDP will offload parameters, gradients, and optimizer
states to CPU.
reshard_after_forward (bool): Whether to reshard parameters and buffers after
the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.
dp_mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding under multiple parallelism.
Default to None.
Raises:
ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
"""
fsdp_kwargs = {
"reshard_after_forward": reshard_after_forward,
"mesh": dp_mesh
}
if cpu_offload:
fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
# Shard the model with FSDP, iterating in reverse to start with
# lowest-level modules first
num_layers_sharded = 0
for n, m in reversed(list(model.named_modules())):
if any([
shard_condition(n, m)
for shard_condition in model._fsdp_shard_conditions
]):
fully_shard(m, **fsdp_kwargs)
num_layers_sharded += 1
if num_layers_sharded == 0:
raise ValueError(
"No layer modules were sharded. Please check if shard conditions are working as expected."
)
# Finally shard the entire model to account for any stragglers
fully_shard(model, **fsdp_kwargs)
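# A hedged sketch of the kind of condition a model's `_fsdp_shard_conditions`
# list is expected to contain for `shard_model`. Each condition receives
# (module_name, module) and returns True if that submodule should become its
# own FSDP unit. The module naming scheme below is hypothetical.
def _example_shard_condition(name: str, module: nn.Module) -> bool:
    # Shard each transformer block individually, e.g. "blocks.0", "blocks.1", ...
    return re.fullmatch(r"blocks\.\d+", name) is not None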
# TODO(PY): device mesh for cfg parallel
def load_fsdp_model_from_full_model_state_dict(
model: torch.nn.Module,
full_sd_iterator: Generator[Tuple[str, torch.Tensor], None, None],
device: torch.device,
strict: bool = False,
cpu_offload: bool = False,
param_names_mapping: Optional[Callable[[str], tuple[str, Any, Any]]] = None,
) -> _IncompatibleKeys:
"""
    Converts a full state dict into a sharded state dict
    and loads it into an FSDP model.
Args:
model (FSDPModule): Model to generate fully qualified names for cpu_state_dict
full_sd_iterator (Generator): an iterator yielding (param_name, tensor) pairs
device (torch.device): device used to move full state dict tensors
strict (bool): flag to check if to load the model in strict mode
cpu_offload (bool): flag to check if offload to CPU is enabled
        param_names_mapping (Optional[Callable[[str], tuple[str, Any, Any]]]): a function that maps
            a source param name to (target_name, merge_index, total_splitted_params)
Returns:
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
Raises:
        NotImplementedError: If FSDP is used with more than one sharding dimension.
"""
meta_sharded_sd = model.state_dict()
sharded_sd = {}
to_merge_params: DefaultDict[Hashable, Dict[Any, Any]] = defaultdict(dict)
for source_param_name, full_tensor in full_sd_iterator:
assert param_names_mapping is not None
target_param_name, merge_index, num_params_to_merge = param_names_mapping(
source_param_name)
if merge_index is not None:
to_merge_params[target_param_name][merge_index] = full_tensor
if len(to_merge_params[target_param_name]) == num_params_to_merge:
                # cat at dim=0 according to the merge_index order
sorted_tensors = [
to_merge_params[target_param_name][i]
for i in range(num_params_to_merge)
]
full_tensor = torch.cat(sorted_tensors, dim=0)
del to_merge_params[target_param_name]
else:
continue
sharded_meta_param = meta_sharded_sd.get(target_param_name)
if sharded_meta_param is None:
raise ValueError(
f"Parameter {source_param_name}-->{target_param_name} not found in meta sharded state dict"
)
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
if not hasattr(sharded_meta_param, "device_mesh"):
# In cases where parts of the model aren't sharded, some parameters will be plain tensors
sharded_tensor = full_tensor
else:
sharded_tensor = distribute_tensor(
full_tensor,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
)
if cpu_offload:
sharded_tensor = sharded_tensor.cpu()
sharded_sd[target_param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
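# A standalone, hedged illustration of the merge logic above: when the name
# mapping reports (merge_index, num_params_to_merge), split tensors are
# buffered and concatenated at dim=0 in merge_index order once all pieces have
# arrived. Shapes and arrival order here are arbitrary.
def _example_split_param_merge() -> None:
    pieces: Dict[int, torch.Tensor] = {}
    arrivals = [(2, torch.full((4, 8), 2.0)), (0, torch.full((4, 8), 0.0)),
                (1, torch.full((4, 8), 1.0))]
    for merge_index, tensor in arrivals:
        pieces[merge_index] = tensor
        if len(pieces) == 3:
            merged = torch.cat([pieces[i] for i in range(3)], dim=0)
            assert merged.shape == (12, 8)
            # Rows follow merge_index order regardless of arrival order.
            assert merged[0, 0] == 0.0 and merged[4, 0] == 1.0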
# SPDX-License-Identifier: Apache-2.0
"""Utilities for selecting and loading models."""
import contextlib
import torch
from fastvideo.v1.logger import init_logger
logger = init_logger(__name__)
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/model_loader/weight_utils.py
"""Utilities for downloading and initializing model weights."""
import fnmatch
import hashlib
import json
import os
import tempfile
import time
from collections import defaultdict
from pathlib import Path
from typing import Generator, List, Optional, Tuple, Union
import filelock
import huggingface_hub.constants
import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import safe_open
from tqdm.auto import tqdm
from fastvideo.v1.logger import init_logger
logger = init_logger(__name__)
# use system-level temp directory for file locks, so that multiple users
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
temp_dir = tempfile.gettempdir()
def enable_hf_transfer() -> None:
"""automatically activates hf_transfer
"""
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
try:
# enable hf hub transfer if available
import hf_transfer # type: ignore # noqa
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
except ImportError:
pass
enable_hf_transfer()
class DisabledTqdm(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True)
def get_lock(model_name_or_path: Union[str, Path],
cache_dir: Optional[str] = None):
lock_dir = cache_dir or temp_dir
model_name_or_path = str(model_name_or_path)
    os.makedirs(lock_dir, exist_ok=True)
model_name = model_name_or_path.replace("/", "-")
hash_name = hashlib.sha256(model_name.encode()).hexdigest()
# add hash to avoid conflict with old users' lock files
lock_file_name = hash_name + model_name + ".lock"
# mode 0o666 is required for the filelock to be shared across users
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
return lock
def _shared_pointers(tensors):
ptrs = defaultdict(list)
for k, v in tensors.items():
ptrs[v.data_ptr()].append(k)
failing = []
for _, names in ptrs.items():
if len(names) > 1:
failing.append(names)
return failing
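# A small, hedged demonstration of `_shared_pointers`: entries whose tensors
# alias the same storage (e.g. tied weights) are reported together. The names
# below are illustrative.
def _example_shared_pointers() -> None:
    tied = torch.zeros(4)
    tensors = {
        "embed.weight": tied,
        "lm_head.weight": tied,  # tied to embed.weight
        "other.weight": torch.zeros(4),
    }
    assert _shared_pointers(tensors) == [["embed.weight", "lm_head.weight"]]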
def download_weights_from_hf(
model_name_or_path: str,
cache_dir: Optional[str],
allow_patterns: List[str],
revision: Optional[str] = None,
ignore_patterns: Optional[Union[str, List[str]]] = None,
) -> str:
"""Download model weights from Hugging Face Hub.
Args:
model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model
weights. If None, will use HF defaults.
allow_patterns (List[str]): The allowed patterns for the
weight files. Files matched by any of the patterns will be
downloaded.
revision (Optional[str]): The revision of the model.
ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
filter out the weight files. Files matched by any of the patterns
will be ignored.
Returns:
str: The path to the downloaded model weights.
"""
local_only = huggingface_hub.constants.HF_HUB_OFFLINE
if not local_only:
        # Before we download, we look at what is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
# depending on what is available we download different things
for pattern in allow_patterns:
matching = fnmatch.filter(file_list, pattern)
if len(matching) > 0:
allow_patterns = [pattern]
break
logger.info("Using model weights format %s", allow_patterns)
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
start_time = time.perf_counter()
hf_folder: str = snapshot_download(
model_name_or_path,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
cache_dir=cache_dir,
tqdm_class=DisabledTqdm,
revision=revision,
local_files_only=local_only,
)
time_taken = time.perf_counter() - start_time
if time_taken > 0.5:
logger.info("Time spent downloading weights for %s: %.6f seconds",
model_name_or_path, time_taken)
return hf_folder
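# A hedged sketch of how `download_weights_from_hf` might be called; the repo
# id is a placeholder. The first pattern that matches any remote file wins, so
# safetensors checkpoints are preferred over pickled ones here.
def _example_download_weights() -> str:
    return download_weights_from_hf(
        model_name_or_path="some-org/some-model",  # hypothetical repo id
        cache_dir=None,  # fall back to the HF default cache
        allow_patterns=["*.safetensors", "*.bin"],
    )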
def download_safetensors_index_file_from_hf(
model_name_or_path: str,
index_file: str,
cache_dir: Optional[str],
revision: Optional[str] = None,
) -> None:
"""Download hf safetensors index file from Hugging Face Hub.
Args:
model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model
weights. If None, will use HF defaults.
revision (Optional[str]): The revision of the model.
"""
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
try:
# Download the safetensors index file.
hf_hub_download(
repo_id=model_name_or_path,
filename=index_file,
cache_dir=cache_dir,
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
)
# If file not found on remote or locally, we should not fail since
# only some models will have index_file.
except huggingface_hub.utils.EntryNotFoundError:
logger.info("No %s found in remote.", index_file)
except huggingface_hub.utils.LocalEntryNotFoundError:
logger.info("No %s found in local cache.", index_file)
# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the index_file to
# look up which safetensors files should be used.
def filter_duplicate_safetensors_files(hf_weights_files: List[str],
hf_folder: str,
index_file: str) -> List[str]:
# model.safetensors.index.json is a mapping from keys in the
# torch state_dict to safetensors file holding that weight.
index_file_name = os.path.join(hf_folder, index_file)
if not os.path.isfile(index_file_name):
return hf_weights_files
# Iterate through the weight_map (weight_name: safetensors files)
# to identify weights that we should use.
with open(index_file_name) as f:
weight_map = json.load(f)["weight_map"]
weight_files_in_index = set()
for weight_name in weight_map:
weight_files_in_index.add(
os.path.join(hf_folder, weight_map[weight_name]))
# Filter out any fields that are not found in the index file.
hf_weights_files = [
f for f in hf_weights_files if f in weight_files_in_index
]
return hf_weights_files
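# A hedged, self-contained demonstration of `filter_duplicate_safetensors_files`
# using a synthetic index file: only files referenced by the weight_map
# survive, so a consolidated duplicate checkpoint is dropped.
def _example_filter_duplicates() -> None:
    with tempfile.TemporaryDirectory() as folder:
        index_file = "model.safetensors.index.json"
        index = {
            "weight_map": {
                "w1": "model-00001.safetensors",
                "w2": "model-00002.safetensors",
            }
        }
        with open(os.path.join(folder, index_file), "w") as f:
            json.dump(index, f)
        files = [
            os.path.join(folder, name)
            for name in ("model-00001.safetensors", "model-00002.safetensors",
                         "consolidated.safetensors")
        ]
        kept = filter_duplicate_safetensors_files(files, folder, index_file)
        assert sorted(kept) == sorted(files[:2])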
def filter_files_not_needed_for_inference(
hf_weights_files: List[str]) -> List[str]:
"""
Exclude files that are not needed for inference.
See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
"""
blacklist = [
"training_args.bin",
"optimizer.bin",
"optimizer.pt",
"scheduler.pt",
"scaler.pt",
]
hf_weights_files = [
f for f in hf_weights_files
if not any(f.endswith(x) for x in blacklist)
]
return hf_weights_files
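# A tiny, hedged example: training-only artifacts are filtered out while model
# weights are kept.
def _example_filter_training_files() -> None:
    files = ["model.safetensors", "optimizer.pt", "scheduler.pt"]
    assert filter_files_not_needed_for_inference(files) == ["model.safetensors"]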
# Explicitly use pure text format, with a newline at the end.
# This makes it impossible to see the animation in the progress bar,
# but avoids interfering with ray or multiprocessing, which wrap
# each line of output with some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
def safetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors checkpoint shards",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
yield name, param
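# A hedged round-trip demonstration of `safetensors_weights_iterator`: write a
# tiny checkpoint with `safetensors` and stream it back.
def _example_safetensors_iterator() -> None:
    from safetensors.torch import save_file
    with tempfile.TemporaryDirectory() as folder:
        path = os.path.join(folder, "model.safetensors")
        save_file({"w": torch.ones(2, 2)}, path)
        weights = dict(safetensors_weights_iterator([path]))
        assert torch.equal(weights["w"], torch.ones(2, 2))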
def pt_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files."""
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
for bin_file in tqdm(
hf_weights_files,
desc="Loading pt checkpoint shards",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location="cpu", weights_only=True)
yield from state.items()
del state
def default_weight_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
try:
if param.numel() == 1 and loaded_weight.numel() == 1:
# Sometimes scalar values aren't considered tensors with shapes
# so if both param and loaded_weight are a scalar,
# "broadcast" instead of copy
param.data.fill_(loaded_weight.item())
else:
assert param.size() == loaded_weight.size(), (
f"Attempted to load weight ({loaded_weight.size()}) "
f"into parameter ({param.size()})")
param.data.copy_(loaded_weight)
except Exception:
# NOTE: This exception is added for the purpose of setting breakpoint to
# debug weight loading issues.
raise
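# A hedged demonstration of `default_weight_loader`: scalar values are
# "broadcast" via fill_, shape-matched tensors are copied.
def _example_default_weight_loader() -> None:
    scalar = torch.nn.Parameter(torch.zeros(()))
    default_weight_loader(scalar, torch.tensor(0.5))
    assert scalar.item() == 0.5
    weight = torch.nn.Parameter(torch.zeros(2, 2))
    default_weight_loader(weight, torch.ones(2, 2))
    assert torch.equal(weight.data, torch.ones(2, 2))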
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
"""Remap the name of FP8 k/v_scale parameters.
This function handles the remapping of FP8 k/v_scale parameter names.
It detects if the given name ends with a suffix and attempts to remap
it to the expected name format in the model. If the remapped name is not
found in the params_dict, a warning is printed and None is returned.
Args:
name (str): The original loaded checkpoint parameter name.
params_dict (dict): Dictionary containing the model's named parameters.
Returns:
str: The remapped parameter name if successful, or the original name
if no remapping is needed.
None: If the remapped name is not found in params_dict.
"""
if name.endswith(".kv_scale"):
logger.warning_once(
"DEPRECATED. Found kv_scale in the checkpoint. "
"This format is deprecated in favor of separate k_scale and "
"v_scale tensors and will be removed in a future release. "
"Functionally, we will remap kv_scale to k_scale and duplicate "
"k_scale to v_scale")
# NOTE: we remap the deprecated kv_scale to k_scale
remapped_name = name.replace(".kv_scale", ".attn.k_scale")
if remapped_name not in params_dict:
logger.warning_once(
f"Found kv_scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_name}). kv_scale is "
"not loaded.")
return None
return remapped_name
possible_scale_names = [".k_scale", ".v_scale"]
modelopt_scale_names = [
".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
]
for scale_name in possible_scale_names:
if name.endswith(scale_name):
if any(mo_scale_name in name
for mo_scale_name in modelopt_scale_names):
remapped_name = name.replace(
f".self_attn.{scale_name[1]}_proj{scale_name}",
f".self_attn.attn{scale_name}")
else:
remapped_name = name.replace(scale_name, f".attn{scale_name}")
if remapped_name not in params_dict:
logger.warning_once(
f"Found {scale_name} in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_name}). {scale_name} is "
"not loaded.")
return None
return remapped_name
# If there were no matches, return the untouched param name
return name
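# A hedged demonstration of `maybe_remap_kv_scale_name`; the layer names are
# illustrative only.
def _example_remap_kv_scale() -> None:
    params_dict = {"layers.0.attn.k_scale": torch.zeros(())}
    # Deprecated ".kv_scale" names are remapped onto ".attn.k_scale".
    assert maybe_remap_kv_scale_name("layers.0.kv_scale",
                                     params_dict) == "layers.0.attn.k_scale"
    # Names without a scale suffix pass through untouched.
    assert maybe_remap_kv_scale_name("layers.0.mlp.weight",
                                     params_dict) == "layers.0.mlp.weight"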
# SPDX-License-Identifier: Apache-2.0
# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/parameter.py
from fractions import Fraction
from typing import Any, Callable, Tuple, Union
import torch
from torch.nn import Parameter
from fastvideo.v1.distributed import get_tensor_model_parallel_rank
from fastvideo.v1.logger import init_logger
from fastvideo.v1.models.utils import _make_synced_weight_loader
logger = init_logger(__name__)
class BasevLLMParameter(Parameter):
"""
Base parameter for vLLM linear layers. Extends the torch.nn.parameter
by taking in a linear weight loader. Will copy the loaded weight
into the parameter when the provided weight loader is called.
"""
def __new__(cls, data: torch.Tensor, **kwargs):
return super().__new__(cls, data=data, requires_grad=False)
def __init__(self, data: torch.Tensor, weight_loader: Callable):
"""
Initialize the BasevLLMParameter
:param data: torch tensor with the parameter data
:param weight_loader: weight loader callable
:returns: a torch.nn.parameter
"""
# During weight loading, we often do something like:
# narrowed_tensor = param.data.narrow(0, offset, len)
# narrowed_tensor.copy_(real_weight)
# expecting narrowed_tensor and param.data to share the same storage.
# However, on TPUs, narrowed_tensor will lazily propagate to the base
# tensor, which is param.data, leading to the redundant memory usage.
# This sometimes causes OOM errors during model loading. To avoid this,
# we sync the param tensor after its weight loader is called.
from fastvideo.v1.platforms import current_platform
if current_platform.is_tpu():
weight_loader = _make_synced_weight_loader(weight_loader)
self._weight_loader = weight_loader
@property
def weight_loader(self):
return self._weight_loader
def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
cond1 = self.data.ndim == 1 and self.data.numel() == 1
cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
return (cond1 and cond2)
def _assert_and_load(self, loaded_weight: torch.Tensor) -> None:
assert (self.data.shape == loaded_weight.shape
or self._is_1d_and_scalar(loaded_weight))
self.data.copy_(loaded_weight)
def load_column_parallel_weight(self, loaded_weight: torch.Tensor) -> None:
self._assert_and_load(loaded_weight)
def load_row_parallel_weight(self, loaded_weight: torch.Tensor) -> None:
self._assert_and_load(loaded_weight)
def load_merged_column_weight(self, loaded_weight: torch.Tensor,
**kwargs) -> None:
self._assert_and_load(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None:
self._assert_and_load(loaded_weight)
class _ColumnvLLMParameter(BasevLLMParameter):
"""
Private class defining weight loading functionality
(load_merged_column_weight, load_qkv_weight)
for parameters being loaded into linear layers with column
parallelism. This includes QKV and MLP layers which are
not already fused on disk. Requires an output dimension
to be defined. Called within the weight loader of
each of the column parallel linear layers.
"""
def __init__(self, output_dim: int, **kwargs):
self._output_dim = output_dim
super().__init__(**kwargs)
@property
def output_dim(self):
return self._output_dim
def load_column_parallel_weight(self, loaded_weight: torch.Tensor) -> None:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(self.output_dim,
tp_rank * shard_size, shard_size)
assert self.data.shape == loaded_weight.shape
self.data.copy_(loaded_weight)
def load_merged_column_weight(self, loaded_weight: torch.Tensor,
**kwargs) -> None:
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
if shard_offset is None or shard_size is None:
raise ValueError("shard_offset and shard_size must be provided")
if isinstance(
self,
(PackedColumnParameter,
PackedvLLMParameter)) and self.packed_dim == self.output_dim:
shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
shard_offset=shard_offset, shard_size=shard_size)
param_data = self.data
tp_rank = get_tensor_model_parallel_rank()
param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim,
tp_rank * shard_size, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None:
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
shard_id = kwargs.get("shard_id")
num_heads = kwargs.get("num_heads")
assert shard_offset is not None
assert shard_size is not None
assert shard_id is not None
assert num_heads is not None
if isinstance(
self,
(PackedColumnParameter,
PackedvLLMParameter)) and self.output_dim == self.packed_dim:
shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
shard_offset=shard_offset, shard_size=shard_size)
param_data = self.data
tp_rank = get_tensor_model_parallel_rank()
shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim,
shard_id * shard_size, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class RowvLLMParameter(BasevLLMParameter):
"""
Parameter class defining weight_loading functionality
(load_row_parallel_weight) for parameters being loaded
into linear layers with row parallel functionality.
Requires an input_dim to be defined.
"""
def __init__(self, input_dim: int, **kwargs):
self._input_dim = input_dim
super().__init__(**kwargs)
@property
def input_dim(self):
return self._input_dim
def load_row_parallel_weight(self, loaded_weight: torch.Tensor) -> None:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(self.input_dim,
tp_rank * shard_size, shard_size)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert self.data.shape == loaded_weight.shape
self.data.copy_(loaded_weight)
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
"""
Parameter class for linear layer weights. Uses both column and
row parallelism.
"""
pass
class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
"""
Parameter class for weight scales loaded for weights with
grouped quantization. Uses both column and row parallelism.
"""
pass
class ChannelQuantScaleParameter(_ColumnvLLMParameter):
"""
Parameter class for weight scales loaded for weights with
channel-wise quantization. Equivalent to _ColumnvLLMParameter.
"""
pass
class PerTensorScaleParameter(BasevLLMParameter):
"""
Parameter class for scales where the number of scales is
equivalent to the number of logical matrices in fused linear
layers (e.g. for QKV, there are 3 scales loaded from disk).
This is relevant to weights with per-tensor quantization.
Adds functionality to map the scalers to a shard during
weight loading.
Note: additional parameter manipulation may be handled
for each quantization config specifically, within
process_weights_after_loading
"""
def __init__(self, **kwargs):
self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
super().__init__(**kwargs)
def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id
# if not int, assume shard_id for qkv
# map to int and return
assert isinstance(shard_id, str)
assert shard_id in self.qkv_idxs
return self.qkv_idxs[shard_id]
# For row parallel layers, no sharding needed
# load weight into parameter as is
def load_row_parallel_weight(self, *args, **kwargs) -> None:
super().load_row_parallel_weight(*args, **kwargs)
def load_merged_column_weight(self, *args, **kwargs) -> None:
self._load_into_shard_id(*args, **kwargs)
def load_qkv_weight(self, *args, **kwargs) -> None:
self._load_into_shard_id(*args, **kwargs)
def load_column_parallel_weight(self, *args, **kwargs) -> None:
super().load_row_parallel_weight(*args, **kwargs)
def _load_into_shard_id(self, loaded_weight: torch.Tensor,
shard_id: Union[str, int], **kwargs):
"""
Slice the parameter data based on the shard id for
loading.
"""
param_data = self.data
shard_id = self._shard_id_as_int(shard_id)
# AutoFP8 scales do not have a shape
# compressed-tensors scales do have a shape
if len(loaded_weight.shape) != 0:
assert loaded_weight.shape[0] == 1
loaded_weight = loaded_weight[0]
param_data = param_data[shard_id]
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class PackedColumnParameter(_ColumnvLLMParameter):
"""
Parameter for model parameters which are packed on disk
and support column parallelism only. See PackedvLLMParameter
for more details on the packed properties.
"""
def __init__(self, packed_factor: Union[int, Fraction], packed_dim: int,
**kwargs):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
super().__init__(**kwargs)
@property
def packed_dim(self):
return self._packed_dim
@property
def packed_factor(self):
return self._packed_factor
def adjust_shard_indexes_for_packing(self, shard_size,
shard_offset) -> Tuple[Any, Any]:
return _adjust_shard_indexes_for_packing(
shard_size=shard_size,
shard_offset=shard_offset,
packed_factor=self.packed_factor)
class PackedvLLMParameter(ModelWeightParameter):
"""
Parameter for model weights which are packed on disk.
Example: GPTQ Marlin weights are int4 or int8, packed into int32.
Extends the ModelWeightParameter to take in the
packed factor, the packed dimension, and optionally, marlin
tile size for marlin kernels. Adjusts the shard_size and
shard_offset for fused linear layers model weight loading
by accounting for packing and optionally, marlin tile size.
"""
def __init__(self, packed_factor: Union[int, Fraction], packed_dim: int,
**kwargs):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
super().__init__(**kwargs)
@property
def packed_dim(self):
return self._packed_dim
@property
def packed_factor(self):
return self._packed_factor
def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
return _adjust_shard_indexes_for_packing(
shard_size=shard_size,
shard_offset=shard_offset,
packed_factor=self.packed_factor)
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
"""
Parameter class for weight scales loaded for weights with
block-wise quantization. Uses both column and row parallelism.
"""
pass
def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
output_dim: int, **kwargs) -> BasevLLMParameter:
"""
Permute a parameter's layout to the specified input and output dimensions,
useful for forcing the parameter into a known layout, for example, if I need
a packed (quantized) weight matrix to be in the layout
{input_dim = 0, output_dim = 1, packed_dim = 0}
then I can call:
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
to ensure x is in the correct layout (permuting it to the correct layout if
required, asserting if it cannot get it to the correct layout)
"""
curr_input_dim = getattr(param, "input_dim", None)
curr_output_dim = getattr(param, "output_dim", None)
if curr_input_dim is None or curr_output_dim is None:
assert param.data.dim() == 2,\
"permute_param_layout_ only supports 2D parameters when either "\
"input_dim or output_dim is not set"
# if one of the dimensions is not set, set it to the opposite of the other
# we can only do this since we asserted the parameter is 2D above
if curr_input_dim is None:
assert curr_output_dim is not None,\
"either input or output dim must be set"
curr_input_dim = (curr_output_dim + 1) % 2
if curr_output_dim is None:
assert curr_input_dim is not None,\
"either input or output dim must be set"
curr_output_dim = (curr_input_dim + 1) % 2
# create permutation from the current layout to the layout with
# self.input_dim at input_dim and self.output_dim at output_dim preserving
# other dimensions
perm = [
i for i in range(param.data.dim())
if i not in [curr_input_dim, curr_output_dim]
]
perm.insert(input_dim, curr_input_dim)
perm.insert(output_dim, curr_output_dim)
if "packed_dim" in kwargs:
assert hasattr(param, "packed_dim") and\
param.packed_dim == perm[kwargs["packed_dim"]],\
"permute_param_layout_ currently doesn't support repacking"
param.data = param.data.permute(*perm)
if hasattr(param, "_input_dim"):
param._input_dim = input_dim
if hasattr(param, "_output_dim"):
param._output_dim = output_dim
if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
param._packed_dim = kwargs["packed_dim"]
return param
def _adjust_shard_indexes_for_packing(shard_size, shard_offset,
packed_factor) -> Tuple[Any, Any]:
shard_size = shard_size // packed_factor
shard_offset = shard_offset // packed_factor
return shard_size, shard_offset
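# A hedged numeric example of `_adjust_shard_indexes_for_packing`: int4 weights
# packed into int32 give packed_factor == 8, so shard sizes and offsets shrink
# by 8x when expressed in packed elements.
def _example_packing_adjustment() -> None:
    assert _adjust_shard_indexes_for_packing(
        shard_size=4096, shard_offset=8192, packed_factor=8) == (512, 1024)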
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/registry.py
import importlib
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
from typing import (AbstractSet, Callable, Dict, List, NoReturn, Optional,
Tuple, Type, TypeVar, Union, cast)
import cloudpickle
from torch import nn
from fastvideo.v1.logger import logger
# huggingface class name: (component_name, fastvideo module name, fastvideo class name)
_TEXT_TO_VIDEO_DIT_MODELS = {
"HunyuanVideoTransformer3DModel":
("dits", "hunyuanvideo", "HunyuanVideoTransformer3DModel"),
"WanTransformer3DModel": ("dits", "wanvideo", "WanTransformer3DModel"),
}
_IMAGE_TO_VIDEO_DIT_MODELS = {
# "HunyuanVideoTransformer3DModel": ("dits", "hunyuanvideo", "HunyuanVideoDiT"),
"WanTransformer3DModel": ("dits", "wanvideo", "WanTransformer3DModel"),
}
_TEXT_ENCODER_MODELS = {
"CLIPTextModel": ("encoders", "clip", "CLIPTextModel"),
"LlamaModel": ("encoders", "llama", "LlamaModel"),
"UMT5EncoderModel": ("encoders", "t5", "UMT5EncoderModel"),
}
_IMAGE_ENCODER_MODELS: dict[str, tuple] = {
# "HunyuanVideoTransformer3DModel": ("image_encoder", "hunyuanvideo", "HunyuanVideoImageEncoder"),
"CLIPVisionModelWithProjection": ("encoders", "clip", "CLIPVisionModel"),
}
_VAE_MODELS = {
"AutoencoderKLHunyuanVideo":
("vaes", "hunyuanvae", "AutoencoderKLHunyuanVideo"),
"AutoencoderKLWan": ("vaes", "wanvae", "AutoencoderKLWan"),
}
_SCHEDULERS = {
"FlowMatchEulerDiscreteScheduler":
("schedulers", "scheduling_flow_match_euler_discrete",
"FlowMatchDiscreteScheduler"),
"UniPCMultistepScheduler":
("schedulers", "scheduling_unipc_multistep", "UniPCMultistepScheduler"),
}
_FAST_VIDEO_MODELS = {
**_TEXT_TO_VIDEO_DIT_MODELS,
**_IMAGE_TO_VIDEO_DIT_MODELS,
**_TEXT_ENCODER_MODELS,
**_IMAGE_ENCODER_MODELS,
**_VAE_MODELS,
**_SCHEDULERS,
}
_SUBPROCESS_COMMAND = [
sys.executable, "-m", "fastvideo.v1.models.dits.registry"
]
_T = TypeVar("_T")
@dataclass(frozen=True)
class _ModelInfo:
architecture: str
@staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
return _ModelInfo(architecture=model.__name__, )
class _BaseRegisteredModel(ABC):
@abstractmethod
def inspect_model_cls(self) -> _ModelInfo:
raise NotImplementedError
@abstractmethod
def load_model_cls(self) -> Type[nn.Module]:
raise NotImplementedError
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
"""
Represents a model that has already been imported in the main process.
"""
interfaces: _ModelInfo
model_cls: Type[nn.Module]
@staticmethod
def from_model_cls(model_cls: Type[nn.Module]):
return _RegisteredModel(
interfaces=_ModelInfo.from_model_cls(model_cls),
model_cls=model_cls,
)
def inspect_model_cls(self) -> _ModelInfo:
return self.interfaces
def load_model_cls(self) -> Type[nn.Module]:
return self.model_cls
def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
# NOTE: We use a temporary directory instead of a temporary file to avoid
# issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
with tempfile.TemporaryDirectory() as tempdir:
output_filepath = os.path.join(tempdir, "registry_output.tmp")
# `cloudpickle` allows pickling lambda functions directly
input_bytes = cloudpickle.dumps((fn, output_filepath))
# cannot use `sys.executable __file__` here because the script
# contains relative imports
returned = subprocess.run(_SUBPROCESS_COMMAND,
input=input_bytes,
capture_output=True)
# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(f"Error raised in subprocess:\n"
f"{returned.stderr.decode()}") from e
with open(output_filepath, "rb") as f:
return cast(_T, pickle.load(f))
@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
"""
Represents a model that has not been imported in the main process.
"""
module_name: str
component_name: str
class_name: str
# Performed in another process to avoid initializing CUDA
def inspect_model_cls(self) -> _ModelInfo:
return _run_in_subprocess(
lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
def load_model_cls(self) -> Type[nn.Module]:
mod = importlib.import_module(self.module_name)
return cast(Type[nn.Module], getattr(mod, self.class_name))
@lru_cache(maxsize=128)
def _try_load_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
from fastvideo.v1.platforms import current_platform
current_platform.verify_model_arch(model_arch)
try:
return model.load_model_cls()
except Exception:
logger.exception("Error in loading model architecture '%s'", model_arch)
return None
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
try:
return model.inspect_model_cls()
except Exception:
logger.exception("Error in inspecting model architecture '%s'",
model_arch)
return None
@dataclass
class _ModelRegistry:
# Keyed by model_arch
models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)
def get_supported_archs(self) -> AbstractSet[str]:
return self.models.keys()
def register_model(
self,
model_arch: str,
model_cls: Union[Type[nn.Module], str],
) -> None:
"""
        Register an external model to be used in FastVideo.
:code:`model_cls` can be either:
- A :class:`torch.nn.Module` class directly referencing the model.
- A string in the format :code:`<module>:<class>` which can be used to
lazily import the model. This is useful to avoid initializing CUDA
when importing the model and thus the related error
:code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
"""
if model_arch in self.models:
logger.warning(
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch, model_cls)
if isinstance(model_cls, str):
split_str = model_cls.split(":")
if len(split_str) != 2:
msg = "Expected a string in the format `<module>:<class>`"
raise ValueError(msg)
model = _LazyRegisteredModel(*split_str)
else:
model = _RegisteredModel.from_model_cls(model_cls)
self.models[model_arch] = model
def _raise_for_unsupported(self, architectures: List[str]) -> NoReturn:
all_supported_archs = self.get_supported_archs()
if any(arch in all_supported_archs for arch in architectures):
raise ValueError(
f"Model architectures {architectures} failed "
"to be inspected. Please check the logs for more details.")
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {all_supported_archs}")
def _try_load_model_cls(self, model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch not in self.models:
return None
return _try_load_model_cls(model_arch, self.models[model_arch])
def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
if model_arch not in self.models:
return None
return _try_inspect_model_cls(model_arch, self.models[model_arch])
def _normalize_archs(
self,
architectures: Union[str, List[str]],
) -> List[str]:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
normalized_arch = []
for model in architectures:
if model not in self.models:
model = "TransformersModel"
normalized_arch.append(model)
return normalized_arch
def inspect_model_cls(
self,
architectures: Union[str, List[str]],
) -> Tuple[_ModelInfo, str]:
architectures = self._normalize_archs(architectures)
for arch in architectures:
model_info = self._try_inspect_model_cls(arch)
if model_info is not None:
return (model_info, arch)
return self._raise_for_unsupported(architectures)
def resolve_model_cls(
self,
architectures: Union[str, List[str]],
) -> Tuple[Type[nn.Module], str]:
architectures = self._normalize_archs(architectures)
for arch in architectures:
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
return self._raise_for_unsupported(architectures)
ModelRegistry = _ModelRegistry({
model_arch:
_LazyRegisteredModel(
module_name=f"fastvideo.v1.models.{component_name}.{mod_relname}",
component_name=component_name,
class_name=cls_name,
)
for model_arch, (component_name, mod_relname,
cls_name) in _FAST_VIDEO_MODELS.items()
})
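# A hedged sketch of registry usage: resolving a registered architecture lazily
# imports the implementing class; unknown architectures raise a ValueError
# listing the supported ones.
def _example_resolve_model_cls() -> None:
    model_cls, arch = ModelRegistry.resolve_model_cls("WanTransformer3DModel")
    assert arch == "WanTransformer3DModel"
    assert issubclass(model_cls, nn.Module)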
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Union
import torch
from diffusers.utils import BaseOutput
class BaseScheduler(ABC):
timesteps: torch.Tensor
order: int
def __init__(self, *args, **kwargs) -> None:
# Check if subclass has defined all required properties
required_attributes = ['timesteps', 'order']
for attr in required_attributes:
if not hasattr(self, attr):
raise AttributeError(
f"Subclasses of BaseScheduler must define '{attr}' property"
)
@abstractmethod
def set_shift(self, shift: float) -> None:
pass
@abstractmethod
def set_timesteps(self, *args, **kwargs) -> None:
pass
@abstractmethod
def scale_model_input(self,
sample: torch.Tensor,
timestep: Optional[int] = None) -> torch.Tensor:
pass
@abstractmethod
def step(
self,
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[BaseOutput, Tuple]:
pass
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput, logging
from fastvideo.v1.models.schedulers.base import BaseScheduler
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin, BaseScheduler):
"""
Euler scheduler.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
shift (`float`, defaults to 1.0):
The shift value for the timestep schedule.
reverse (`bool`, defaults to `True`):
Whether to reverse the timestep schedule.
"""
_compatibles: list[Any] = []
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0,
reverse: bool = True,
solver: str = "euler",
n_tokens: Optional[int] = None,
**kwargs,
):
sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
if not reverse:
sigmas = sigmas.flip(0)
self.sigmas = sigmas
# the value fed to model
self.timesteps = (sigmas[:-1] *
num_train_timesteps).to(dtype=torch.float32)
self._step_index: int | None = None
self._begin_index = 0
self.supported_solver = ["euler"]
if solver not in self.supported_solver:
raise ValueError(
f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
)
BaseScheduler.__init__(self)
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def set_timesteps(
self,
num_inference_steps: int,
        device: Optional[Union[str, torch.device]] = None,
n_tokens: int = 0,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
n_tokens (`int`, *optional*):
Number of tokens in the input sequence.
"""
self.num_inference_steps = num_inference_steps
sigmas = torch.linspace(1, 0, num_inference_steps + 1)
sigmas = self.sd3_time_shift(sigmas)
if not self.config.reverse:
sigmas = 1 - sigmas
self.sigmas = sigmas
self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
dtype=torch.float32, device=device)
# Reset step index
self._step_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None) -> int:
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
idx: int = indices[pos].item()
return idx
def set_shift(self, shift: float) -> None:
self.config.shift = shift
def _init_step_index(self, timestep) -> None:
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def scale_model_input(self,
sample: torch.Tensor,
timestep: Optional[int] = None) -> torch.Tensor:
return sample
def sd3_time_shift(self, t: torch.Tensor):
return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
return_dict: bool = True,
**kwargs,
) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
n_tokens (`int`, *optional*):
Number of tokens in the input sequence.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or
                tuple.
        Returns:
            [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
raise ValueError((
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."), )
if self.step_index is None:
self._init_step_index(timestep)
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
assert self.step_index is not None
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
if self.config.solver == "euler":
prev_sample = sample + model_output.to(torch.float32) * dt
else:
raise ValueError(
f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
)
# upon completion increase step index by one
assert self._step_index is not None
self._step_index += 1
if not return_dict:
return (prev_sample, )
return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps
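# A hedged sketch of a minimal denoising loop with this scheduler. The tensor
# shapes and shift value are illustrative; a real pipeline would feed DiT
# outputs instead of the dummy zeros below.
def _example_scheduler_loop() -> None:
    scheduler = FlowMatchDiscreteScheduler(shift=7.0)
    scheduler.set_timesteps(num_inference_steps=4)
    sample = torch.randn(1, 4, 8, 8)
    for t in scheduler.timesteps:
        model_output = torch.zeros_like(sample)  # stand-in for a DiT forward
        sample = scheduler.step(model_output, t, sample).prev_sample
    assert sample.shape == (1, 4, 8, 8)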