Commit 7bc1dae0 authored by Mick, committed by GitHub

WIP: initial multimodal-gen support (#12484)

Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: JiLi <leege233@gmail.com>
Co-authored-by: CHEN Xi <78632976+RubiaCx@users.noreply.github.com>
Co-authored-by: laixin <xielx@shanghaitech.edu.cn>
Co-authored-by: SolitaryThinker <wlsaidhi@gmail.com>
Co-authored-by: jzhang38 <a1286225768@gmail.com>
Co-authored-by: BrianChen1129 <yongqichcd@gmail.com>
Co-authored-by: Kevin Lin <42618777+kevin314@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: rlsu9 <r3su@ucsd.edu>
Co-authored-by: Jinzhe Pan <48981407+eigensystem@users.noreply.github.com>
Co-authored-by: foreverpiano <pianoqwz@qq.com>
Co-authored-by: RandNMR73 <notomatthew31@gmail.com>
Co-authored-by: PorridgeSwim <yz3883@columbia.edu>
Co-authored-by: Jiali Chen <90408393+gary-chenjl@users.noreply.github.com>
parent 4fe53e58
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/activation.py
"""Custom activation functions."""
import math
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
# TODO (will): remove this dependency
from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def __init__(self) -> None:
super().__init__()
def forward_cuda(self, *args, **kwargs) -> Any:
return self.forward_native(*args, **kwargs)
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
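# Illustrative usage (a sketch, not part of the original file): the last dimension
# is split in half and gated by SiLU, so a (2, 16, 256) input yields (2, 16, 128).
#
#     x = torch.randn(2, 16, 256)
#     out = SiluAndMul().forward_native(x)                 # shape (2, 16, 128)
#     assert torch.allclose(out, F.silu(x[..., :128]) * x[..., 128:])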
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
"""An activation function for GeGLU.
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d)
"""
def __init__(self, approximate: str = "none"):
super().__init__()
self.approximate = approximate
if approximate not in ("none", "tanh"):
raise ValueError(f"Unknown approximate mode: {approximate}")
def forward_cuda(self, *args, **kwargs) -> Any:
return self.forward_native(*args, **kwargs)
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
def extra_repr(self) -> str:
return f"approximate={repr(self.approximate)}"
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
def __init__(self):
super().__init__()
def forward_cuda(self, *args, **kwargs) -> Any:
return self.forward_native(*args, **kwargs)
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
c = math.sqrt(2.0 / math.pi)
return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
def __init__(self):
super().__init__()
def forward_cuda(self, *args, **kwargs) -> Any:
return self.forward_native(*args, **kwargs)
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return x * torch.sigmoid(1.702 * x)
_ACTIVATION_REGISTRY = {
"gelu": nn.GELU,
"gelu_new": NewGELU,
"gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
"relu": nn.ReLU,
"silu": nn.SiLU,
"quick_gelu": QuickGELU,
}
def get_act_fn(act_fn_name: str) -> nn.Module:
"""Get an activation function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name not in _ACTIVATION_REGISTRY:
raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
return _ACTIVATION_REGISTRY[act_fn_name]()
_ACTIVATION_AND_MUL_REGISTRY = {
"gelu": GeluAndMul,
"silu": SiluAndMul,
}
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
"""Get an activation-and-mul (i.e. SiluAndMul) function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]()
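# Illustrative usage of the two registries (a sketch, not part of the original file):
#
#     act = get_act_fn("gelu_pytorch_tanh")            # nn.GELU(approximate="tanh")
#     gated = get_act_and_mul_fn("silu")               # SiluAndMul()
#     y = act(torch.randn(4, 128))                     # shape preserved: (4, 128)
#     z = gated.forward_native(torch.randn(4, 256))    # last dim halved: (4, 128)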
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import json
import os
from collections import defaultdict
from typing import Any
import numpy as np
from sglang.multimodal_gen.utils import dict_to_3d_list
def configure_sta(
mode: str = "STA_searching",
layer_num: int = 40,
time_step_num: int = 50,
head_num: int = 40,
**kwargs,
) -> list[list[list[Any]]]:
"""
Configure Sliding Tile Attention (STA) parameters based on the specified mode.
Parameters:
----------
mode : str
The STA mode to use. Options are:
- 'STA_searching': Generate a set of mask candidates for initial search
- 'STA_tuning': Select best mask strategy based on previously saved results
- 'STA_inference': Load and use a previously tuned mask strategy
layer_num: int, number of layers
time_step_num: int, number of timesteps
head_num: int, number of heads
**kwargs : dict
Mode-specific parameters:
For 'STA_searching':
- mask_candidates: list of str, optional, mask candidates to use
- mask_selected: list of int, optional, indices of selected masks
For 'STA_tuning':
- mask_search_files_path: str, required, path to mask search results
- mask_candidates: list of str, optional, mask candidates to use
- mask_selected: list of int, optional, indices of selected masks
- skip_time_steps: int, optional, number of time steps to use full attention (default 12)
- save_dir: str, optional, directory to save mask strategy (default "mask_candidates")
For 'STA_inference':
- load_path: str, optional, path to load mask strategy (default "mask_candidates/mask_strategy.json")
"""
valid_modes = ["STA_searching", "STA_tuning", "STA_inference", "STA_tuning_cfg"]
if mode not in valid_modes:
raise ValueError(f"Mode must be one of {valid_modes}, got {mode}")
if mode == "STA_searching":
# Get parameters with defaults
mask_candidates: list[str] | None = kwargs.get("mask_candidates")
if mask_candidates is None:
raise ValueError("mask_candidates is required for STA_searching mode")
mask_selected: list[int] = kwargs.get(
"mask_selected", list(range(len(mask_candidates)))
)
# Parse selected masks
selected_masks: list[list[int]] = []
for index in mask_selected:
mask = mask_candidates[index]
masks_list = [int(x) for x in mask.split(",")]
selected_masks.append(masks_list)
# Create a 3D mask structure of shape [time_step_num][layer_num][num_selected_masks]
masks_3d: list[list[list[list[int]]]] = []
for i in range(time_step_num):  # t dimension
row = []
for j in range(layer_num):  # l dimension
row.append(selected_masks)  # Add all masks at each position
masks_3d.append(row)
return masks_3d
elif mode == "STA_tuning":
# Get required parameters
mask_search_files_path: str | None = kwargs.get("mask_search_files_path")
if not mask_search_files_path:
raise ValueError("mask_search_files_path is required for STA_tuning mode")
# Get optional parameters with defaults
mask_candidates_tuning: list[str] | None = kwargs.get("mask_candidates")
if mask_candidates_tuning is None:
raise ValueError("mask_candidates is required for STA_tuning mode")
mask_selected_tuning: list[int] = kwargs.get(
"mask_selected", list(range(len(mask_candidates_tuning)))
)
skip_time_steps_tuning: int | None = kwargs.get("skip_time_steps")
save_dir_tuning: str | None = kwargs.get("save_dir", "mask_candidates")
# Parse selected masks
selected_masks_tuning: list[list[int]] = []
for index in mask_selected_tuning:
mask = mask_candidates_tuning[index]
masks_list = [int(x) for x in mask.split(",")]
selected_masks_tuning.append(masks_list)
# Read JSON results
results = read_specific_json_files(mask_search_files_path)
averaged_results = average_head_losses(results, selected_masks_tuning)
# Add full attention mask for specific cases
full_attention_mask_tuning: list[int] | None = kwargs.get("full_attention_mask")
if full_attention_mask_tuning is not None:
selected_masks_tuning.append(full_attention_mask_tuning)
# Select best mask strategy
timesteps_tuning: int = kwargs.get("timesteps", time_step_num)
if skip_time_steps_tuning is None:
skip_time_steps_tuning = 12
mask_strategy, sparsity, strategy_counts = select_best_mask_strategy(
averaged_results,
selected_masks_tuning,
skip_time_steps_tuning,
timesteps_tuning,
head_num,
)
# Save mask strategy
if save_dir_tuning is not None:
os.makedirs(save_dir_tuning, exist_ok=True)
file_path = os.path.join(
save_dir_tuning, f"mask_strategy_s{skip_time_steps_tuning}.json"
)
with open(file_path, "w") as f:
json.dump(mask_strategy, f, indent=4)
print(f"Successfully saved mask_strategy to {file_path}")
# Print sparsity and strategy counts for information
print(f"Overall sparsity: {sparsity:.4f}")
print("\nStrategy usage counts:")
total_heads = time_step_num * layer_num * head_num  # total (timestep, layer, head) slots
for strategy, count in strategy_counts.items():
print(f"Strategy {strategy}: {count} heads ({count/total_heads*100:.2f}%)")
# Convert dictionary to 3D list with fixed dimensions
mask_strategy_3d = dict_to_3d_list(
mask_strategy, t_max=time_step_num, l_max=layer_num, h_max=head_num
)
return mask_strategy_3d
elif mode == "STA_tuning_cfg":
# Get required parameters for both positive and negative paths
mask_search_files_path_pos: str | None = kwargs.get(
"mask_search_files_path_pos"
)
mask_search_files_path_neg: str | None = kwargs.get(
"mask_search_files_path_neg"
)
save_dir_cfg: str | None = kwargs.get("save_dir")
if (
not mask_search_files_path_pos
or not mask_search_files_path_neg
or not save_dir_cfg
):
raise ValueError(
"mask_search_files_path_pos, mask_search_files_path_neg, and save_dir are required for STA_tuning_cfg mode"
)
# Get optional parameters with defaults
mask_candidates_cfg: list[str] | None = kwargs.get("mask_candidates")
if mask_candidates_cfg is None:
raise ValueError("mask_candidates is required for STA_tuning_cfg mode")
mask_selected_cfg: list[int] = kwargs.get(
"mask_selected", list(range(len(mask_candidates_cfg)))
)
skip_time_steps_cfg: int | None = kwargs.get("skip_time_steps")
# Parse selected masks
selected_masks_cfg: list[list[int]] = []
for index in mask_selected_cfg:
mask = mask_candidates_cfg[index]
masks_list = [int(x) for x in mask.split(",")]
selected_masks_cfg.append(masks_list)
# Read JSON results for both positive and negative paths
pos_results = read_specific_json_files(mask_search_files_path_pos)
neg_results = read_specific_json_files(mask_search_files_path_neg)
# Combine positive and negative results into one list
combined_results = pos_results + neg_results
# Average the combined results
averaged_results = average_head_losses(combined_results, selected_masks_cfg)
# Add full attention mask for specific cases
full_attention_mask_cfg: list[int] | None = kwargs.get("full_attention_mask")
if full_attention_mask_cfg is not None:
selected_masks_cfg.append(full_attention_mask_cfg)
timesteps_cfg: int = kwargs.get("timesteps", time_step_num)
if skip_time_steps_cfg is None:
skip_time_steps_cfg = 12
# Select best mask strategy using combined results
mask_strategy, sparsity, strategy_counts = select_best_mask_strategy(
averaged_results,
selected_masks_cfg,
skip_time_steps_cfg,
timesteps_cfg,
head_num,
)
# Save mask strategy
os.makedirs(save_dir_cfg, exist_ok=True)
file_path = os.path.join(
save_dir_cfg, f"mask_strategy_s{skip_time_steps_cfg}.json"
)
with open(file_path, "w") as f:
json.dump(mask_strategy, f, indent=4)
print(f"Successfully saved mask_strategy to {file_path}")
# Print sparsity and strategy counts for information
print(f"Overall sparsity: {sparsity:.4f}")
print("\nStrategy usage counts:")
total_heads = time_step_num * layer_num * head_num  # total (timestep, layer, head) slots
for strategy, count in strategy_counts.items():
print(f"Strategy {strategy}: {count} heads ({count/total_heads*100:.2f}%)")
# Convert dictionary to 3D list with fixed dimensions
mask_strategy_3d = dict_to_3d_list(
mask_strategy, t_max=time_step_num, l_max=layer_num, h_max=head_num
)
return mask_strategy_3d
else: # STA_inference
# Get parameters with defaults
load_path: str | None = kwargs.get(
"load_path", "mask_candidates/mask_strategy.json"
)
if load_path is None:
raise ValueError("load_path is required for STA_inference mode")
# Load previously saved mask strategy
with open(load_path) as f:
mask_strategy = json.load(f)
# Convert dictionary to 3D list with fixed dimensions
mask_strategy_3d = dict_to_3d_list(
mask_strategy, t_max=time_step_num, l_max=layer_num, h_max=head_num
)
return mask_strategy_3d
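# Illustrative usage (a sketch, not part of the original file; the window strings
# below are hypothetical): each mask candidate is a comma-separated "t,h,w" window.
#
#     masks = configure_sta(
#         mode="STA_searching",
#         layer_num=40,
#         time_step_num=50,
#         head_num=40,
#         mask_candidates=["3,3,3", "1,6,10", "5,6,10"],
#     )
#     # masks is a [time_step_num][layer_num] grid; every cell holds all parsed
#     # candidates, e.g. masks[0][0] == [[3, 3, 3], [1, 6, 10], [5, 6, 10]]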
# Helper functions
def read_specific_json_files(folder_path: str) -> list[dict[str, Any]]:
"""Read and parse JSON files containing mask search results."""
json_contents: list[dict[str, Any]] = []
# List files only in the current directory (no walk)
files = os.listdir(folder_path)
# Filter files
matching_files = [f for f in files if "mask" in f and f.endswith(".json")]
print(f"Found {len(matching_files)} matching files: {matching_files}")
for file_name in matching_files:
file_path = os.path.join(folder_path, file_name)
with open(file_path) as file:
data = json.load(file)
json_contents.append(data)
return json_contents
def average_head_losses(
results: list[dict[str, Any]], selected_masks: list[list[int]]
) -> dict[str, dict[str, np.ndarray]]:
"""Average losses across all prompts for each mask strategy."""
# Initialize a dictionary to store the averaged results
averaged_losses: dict[str, dict[str, np.ndarray]] = {}
loss_type = "L2_loss"
# Get all loss types (e.g., 'L2_loss')
averaged_losses[loss_type] = {}
for mask in selected_masks:
mask_str = str(mask)
data_shape = np.array(results[0][loss_type][mask_str]).shape
accumulated_data = np.zeros(data_shape)
# Sum across all prompts
for prompt_result in results:
accumulated_data += np.array(prompt_result[loss_type][mask_str])
# Average by dividing by number of prompts
averaged_data = accumulated_data / len(results)
averaged_losses[loss_type][mask_str] = averaged_data
return averaged_losses
def select_best_mask_strategy(
averaged_results: dict[str, dict[str, np.ndarray]],
selected_masks: list[list[int]],
skip_time_steps: int = 12,
timesteps: int = 50,
head_num: int = 40,
) -> tuple[dict[str, list[int]], float, dict[str, int]]:
"""Select the best mask strategy for each head based on loss minimization."""
best_mask_strategy: dict[str, list[int]] = {}
loss_type = "L2_loss"
# Number of layers, taken from the per-strategy loss array (shape [t][layer][head])
layers = len(averaged_results[loss_type][str(selected_masks[0])][0])
# Counter for sparsity calculation
total_tokens = 0 # total number of masked tokens
total_length = 0 # total sequence length
strategy_counts: dict[str, int] = {str(strategy): 0 for strategy in selected_masks}
full_attn_strategy = selected_masks[-1] # Last strategy is full attention
print(f"Strategy {full_attn_strategy}, skip first {skip_time_steps} steps ")
for t in range(timesteps):
for layer_idx in range(layers):
for h in range(head_num):
if t < skip_time_steps: # First steps use full attention
strategy = full_attn_strategy
else:
# Get losses for this head across all strategies
head_losses = []
for strategy in selected_masks[:-1]: # Exclude full attention
head_losses.append(
averaged_results[loss_type][str(strategy)][t][layer_idx][h]
)
# Find which strategy gives minimum loss
best_strategy_idx = np.argmin(head_losses)
strategy = selected_masks[best_strategy_idx]
best_mask_strategy[f"{t}_{layer_idx}_{h}"] = strategy
# Calculate sparsity
nums = strategy # strategy is already a list of numbers
total_tokens += (
nums[0] * nums[1] * nums[2]
) # masked tokens for chosen strategy
total_length += (
full_attn_strategy[0]
* full_attn_strategy[1]
* full_attn_strategy[2]
)
# Count strategy usage
strategy_counts[str(strategy)] += 1
overall_sparsity = 1 - total_tokens / total_length
return best_mask_strategy, overall_sparsity, strategy_counts
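# Worked example (illustrative, hypothetical window sizes): if the full-attention
# strategy is [5, 6, 10] (300 tiles) and a head is assigned the window [3, 3, 3]
# (27 tiles), that head contributes 27/300 of the full cost, i.e. a per-head
# sparsity of 1 - 27/300 = 0.91. `overall_sparsity` is the same ratio accumulated
# over all (timestep, layer, head) slots.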
def save_mask_search_results(
mask_search_final_result: list[dict[str, list[float]]],
prompt: str,
mask_strategies: list[str],
output_dir: str = "output/mask_search_result/",
) -> str | None:
if not mask_search_final_result:
print("No mask search results to save")
return None
# Create result dictionary with defaultdict for nested lists
mask_search_dict: dict[str, dict[str, list[list[float]]]] = {
"L2_loss": defaultdict(list),
"L1_loss": defaultdict(list),
}
mask_selected = list(range(len(mask_strategies)))
selected_masks: list[list[int]] = []
for index in mask_selected:
mask = mask_strategies[index]
masks_list = [int(x) for x in mask.split(",")]
selected_masks.append(masks_list)
# Process each mask strategy
for i, mask_strategy in enumerate(selected_masks):
mask_strategy_str = str(mask_strategy)
# Process L2 loss
step_results: list[list[float]] = []
for step_data in mask_search_final_result:
if isinstance(step_data, dict) and "L2_loss" in step_data:
layer_losses = [float(loss) for loss in step_data["L2_loss"]]
step_results.append(layer_losses)
mask_search_dict["L2_loss"][mask_strategy_str] = step_results
step_results = []
for step_data in mask_search_final_result:
if isinstance(step_data, dict) and "L1_loss" in step_data:
layer_losses = [float(loss) for loss in step_data["L1_loss"]]
step_results.append(layer_losses)
mask_search_dict["L1_loss"][mask_strategy_str] = step_results
# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
# Create a filename based on the first 50 characters of the prompt
filename = prompt[:50].replace(" ", "_")
filepath = os.path.join(output_dir, f"mask_search_{filename}.json")
# Save the results to a JSON file
with open(filepath, "w") as f:
json.dump(mask_search_dict, f, indent=4)
print(f"Successfully saved mask research results to {filepath}")
return filepath
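# The saved file maps each loss type to a dict keyed by the stringified mask and
# holding per-step, per-layer loss lists, roughly (illustrative sketch):
#
#     {
#         "L2_loss": {"[3, 3, 3]": [[0.0123, ...], ...], "[5, 6, 10]": [...]},
#         "L1_loss": {"[3, 3, 3]": [[0.0456, ...], ...], "[5, 6, 10]": [...]}
#     }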
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
AttentionBackend,
AttentionMetadata,
AttentionMetadataBuilder,
)
from sglang.multimodal_gen.runtime.layers.attention.layer import (
LocalAttention,
UlyssesAttention,
UlyssesAttention_VSA,
USPAttention,
)
from sglang.multimodal_gen.runtime.layers.attention.selector import get_attn_backend
__all__ = [
"USPAttention",
"LocalAttention",
"UlyssesAttention",
"UlyssesAttention_VSA",
"AttentionBackend",
"AttentionMetadata",
"AttentionMetadataBuilder",
# "AttentionState",
"get_attn_backend",
]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import aiter
import torch
from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
)
class AITerBackend(AttentionBackend):
"""
Backend for AITemplate attention implementation.
"""
@staticmethod
def get_name() -> str:
return "AITER"
@staticmethod
def get_impl_cls() -> type["AITerImpl"]:
return AITerImpl
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
# AITer backend does not require special metadata.
return AttentionMetadata
@staticmethod
def get_builder_cls() -> type["AttentionMetadataBuilder"]:
raise NotImplementedError("AITer backend does not have a metadata builder.")
class AITerImpl(AttentionImpl):
"""
Implementation of attention using AITemplate.
"""
def __init__(
self,
num_heads: int,
head_size: int,
softmax_scale: float,
causal: bool = False,
num_kv_heads: int | None = None,
prefix: str = "",
dropout_p: float = 0.0,
**extra_impl_args,
) -> None:
super().__init__(
num_heads=num_heads,
head_size=head_size,
softmax_scale=softmax_scale,
causal=causal,
num_kv_heads=num_kv_heads,
prefix=prefix,
**extra_impl_args,
)
if num_kv_heads is not None and num_kv_heads != num_heads:
raise NotImplementedError(
"AITer backend does not support Grouped Query Attention yet."
)
self.causal = causal
self.dropout_p = dropout_p
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor:
"""
Performs attention using aiter.flash_attn_func.
Args:
query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]
attn_metadata: Metadata for the attention operation (unused).
Returns:
Output tensor of shape [batch_size, num_heads, seq_len, head_dim]
"""
# aiter.flash_attn_func expects tensors in [B, H, S, D] layout,
# which is what ring_attn provides.
output, _ = aiter.flash_attn_func(
query,
key,
value,
dropout_p=self.dropout_p,
causal=self.causal,
return_attn_probs=False,
return_lse=True,
)
return output
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/attention/backends/abstract.py
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar
if TYPE_CHECKING:
pass
import torch
class AttentionBackend(ABC):
"""Abstract class for attention backends."""
# For some attention backends, we allocate an output tensor before
# calling the custom op. When piecewise cudagraph is enabled, this
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False
@staticmethod
@abstractmethod
def get_name() -> str:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_impl_cls() -> type["AttentionImpl"]:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
raise NotImplementedError
# @staticmethod
# @abstractmethod
# def get_state_cls() -> Type["AttentionState"]:
# raise NotImplementedError
# @classmethod
# def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
# return cls.get_metadata_cls()(*args, **kwargs)
@staticmethod
@abstractmethod
def get_builder_cls() -> type["AttentionMetadataBuilder"]:
raise NotImplementedError
@dataclass
class AttentionMetadata:
"""Attention metadata for prefill and decode batched together."""
# Current step of diffusion process
current_timestep: int
def asdict_zerocopy(self, skip_fields: set[str] | None = None) -> dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
if skip_fields is None:
skip_fields = set()
# Note that if we add dataclasses as fields, they will need
# similar handling.
return {
field.name: getattr(self, field.name)
for field in fields(self)
if field.name not in skip_fields
}
T = TypeVar("T", bound=AttentionMetadata)
class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""
@abstractmethod
def __init__(self) -> None:
"""Create the builder, remember some configuration and parameters."""
raise NotImplementedError
@abstractmethod
def prepare(self) -> None:
"""Prepare for one batch."""
raise NotImplementedError
@abstractmethod
def build(
self,
**kwargs: dict[str, Any],
) -> AttentionMetadata:
"""Build attention metadata with on-device tensors."""
raise NotImplementedError
class AttentionLayer(Protocol):
_k_scale: torch.Tensor
_v_scale: torch.Tensor
_k_scale_float: float
_v_scale_float: float
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ...
class AttentionImpl(ABC, Generic[T]):
@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
softmax_scale: float,
causal: bool = False,
num_kv_heads: int | None = None,
prefix: str = "",
**extra_impl_args,
) -> None:
raise NotImplementedError
def preprocess_qkv(self, qkv: torch.Tensor, attn_metadata: T) -> torch.Tensor:
"""Preprocess QKV tensor before performing attention operation.
Default implementation returns the tensor unchanged.
Subclasses can override this to implement custom preprocessing
like reshaping, tiling, scaling, or other transformations.
Called AFTER all_to_all for distributed attention
Args:
qkv: The query-key-value tensor
attn_metadata: Metadata for the attention operation
Returns:
Processed QKV tensor
"""
return qkv
def postprocess_output(
self,
output: torch.Tensor,
attn_metadata: T,
) -> torch.Tensor:
"""Postprocess the output tensor after the attention operation.
Default implementation returns the tensor unchanged.
Subclasses can override this to implement custom postprocessing
like untiling, scaling, or other transformations.
Called BEFORE all_to_all for distributed attention
Args:
output: The output tensor from the attention operation
attn_metadata: Metadata for the attention operation
Returns:
Postprocessed output tensor
"""
return output
@abstractmethod
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: T,
) -> torch.Tensor:
raise NotImplementedError
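# Minimal sketch (illustrative, not part of the original file; the names below are
# hypothetical): a backend wires get_name/get_impl_cls/get_metadata_cls together,
# and its Impl receives [batch_size, seq_len, num_heads, head_dim] tensors.
#
#     class MyBackend(AttentionBackend):
#         @staticmethod
#         def get_name() -> str:
#             return "MY_ATTN"
#         @staticmethod
#         def get_impl_cls() -> type["MyImpl"]:
#             return MyImpl
#         @staticmethod
#         def get_metadata_cls() -> type[AttentionMetadata]:
#             return AttentionMetadata
#         @staticmethod
#         def get_builder_cls() -> type[AttentionMetadataBuilder]:
#             raise NotImplementedError
#
#     class MyImpl(AttentionImpl):
#         def __init__(self, num_heads, head_size, softmax_scale, causal=False,
#                      num_kv_heads=None, prefix="", **extra_impl_args):
#             self.softmax_scale = softmax_scale
#             self.causal = causal
#         def forward(self, query, key, value, attn_metadata):
#             q, k, v = (t.transpose(1, 2) for t in (query, key, value))
#             out = torch.nn.functional.scaled_dot_product_attention(
#                 q, k, v, is_causal=self.causal, scale=self.softmax_scale)
#             return out.transpose(1, 2)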
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Any
import torch
from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionMetadata
try:
from sgl_kernel.flash_attn import flash_attn_varlen_func
# flash_attn 3 no longer has a different API; see the following commit:
# https://github.com/Dao-AILab/flash-attention/commit/ed209409acedbb2379f870bbd03abce31a7a51b7
flash_attn_func = flash_attn_varlen_func
except ImportError as e:
raise e
from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
@dataclass
class FlashAttentionMetadata:
# Sequence lengths for the forward batch
# Maximum sequence length for query
max_seqlen_q: int = 1
# Maximum sequence length for key
max_seqlen_k: int = 0
# Cumulative sequence lengths for query
cu_seqlens_q: torch.Tensor | None = None
# Cumulative sequence lengths for key
cu_seqlens_k: torch.Tensor | None = None
class FlashAttentionMetadataBuilder(AttentionMetadataBuilder):
def __init__(self):
pass
def prepare(self):
pass
def build(  # type: ignore
self,
raw_latent_shape: list | None = None,
**kwargs: dict[str, Any],
) -> FlashAttentionMetadata:
# TODO: leave max_seqlen_q/max_seqlen_k unset here; they are filled in on the
# first forward pass, since the q_len calculation can be complicated.
return FlashAttentionMetadata(max_seqlen_q=None, max_seqlen_k=None)
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "FLASH_ATTN"
@staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]:
return FlashAttentionImpl
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
raise NotImplementedError
@staticmethod
def get_builder_cls() -> type["AttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
class FlashAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
causal: bool,
softmax_scale: float,
num_kv_heads: int | None = None,
prefix: str = "",
**extra_impl_args,
) -> None:
self.causal = causal
self.softmax_scale = softmax_scale
self.attention_metadata = FlashAttentionMetadata()
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata | None = None,
*,
return_softmax_lse: bool = False,
):
attn_metadata: FlashAttentionMetadata = get_forward_context().attn_metadata
if attn_metadata is not None and attn_metadata.max_seqlen_q is None:
attn_metadata.max_seqlen_q = query.shape[1]
attn_metadata.max_seqlen_k = key.shape[1]
max_seqlen_q = attn_metadata.max_seqlen_q
max_seqlen_k = attn_metadata.max_seqlen_k
else:
max_seqlen_q = query.shape[1]
max_seqlen_k = key.shape[1]
output = flash_attn_func(
q=query, # type: ignore[no-untyped-call]
k=key,
v=value,
cu_seqlens_q=None,
cu_seqlens_k=None,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.softmax_scale,
causal=self.causal,
return_softmax_lse=return_softmax_lse,
)
return output
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import torch
from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
)
from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import (
flash_attn_func,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class FlashAttention2Backend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "FA3"
@staticmethod
def get_impl_cls() -> type["FlashAttention2Impl"]:
return FlashAttention2Impl
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
raise NotImplementedError
@staticmethod
def get_builder_cls() -> type["AttentionMetadataBuilder"]:
raise NotImplementedError
class FlashAttention2Impl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
causal: bool,
softmax_scale: float,
num_kv_heads: int | None = None,
prefix: str = "",
**extra_impl_args,
) -> None:
self.causal = causal
self.softmax_scale = softmax_scale
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata,
):
output = flash_attn_func(
q=query, # type: ignore[no-untyped-call]
k=key,
v=value,
cu_seqlens_q=None,
cu_seqlens_k=None,
max_seqlen_q=None,
max_seqlen_k=None,
softmax_scale=self.softmax_scale,
causal=self.causal,
)
return output
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import torch
from sageattention import sageattn
from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( # FlashAttentionMetadata,
AttentionBackend,
AttentionImpl,
AttentionMetadata,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class SageAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "SAGE_ATTN"
@staticmethod
def get_impl_cls() -> type["SageAttentionImpl"]:
return SageAttentionImpl
# @staticmethod
# def get_metadata_cls() -> Type["AttentionMetadata"]:
# return FlashAttentionMetadata
class SageAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
causal: bool,
softmax_scale: float,
num_kv_heads: int | None = None,
prefix: str = "",
**extra_impl_args,
) -> None:
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout = extra_impl_args.get("dropout_p", 0.0)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
output = sageattn(
query,
key,
value,
# since input is (batch_size, seq_len, head_num, head_dim)
tensor_layout="NHD",
is_causal=self.causal,
)
return output
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import torch
from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
)
from sglang.multimodal_gen.runtime.layers.attention.backends.sageattn.api import (
sageattn_blackwell,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class SageAttention3Backend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [64, 128, 256]
@staticmethod
def get_name() -> str:
return "SAGE_ATTN_THREE"
@staticmethod
def get_impl_cls() -> type["SageAttention3Impl"]:
return SageAttention3Impl
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
raise NotImplementedError
@staticmethod
def get_builder_cls() -> type["AttentionMetadataBuilder"]:
raise NotImplementedError
# @staticmethod
# def get_metadata_cls() -> Type["AttentionMetadata"]:
# return FlashAttentionMetadata
class SageAttention3Impl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
causal: bool,
softmax_scale: float,
num_kv_heads: int | None = None,
prefix: str = "",
**extra_impl_args,
) -> None:
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout = extra_impl_args.get("dropout_p", 0.0)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
output = sageattn_blackwell(query, key, value, is_causal=self.causal)
output = output.transpose(1, 2)
return output
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import torch
from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( # FlashAttentionMetadata,
AttentionBackend,
AttentionImpl,
AttentionMetadata,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class SDPABackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "SDPA"
@staticmethod
def get_impl_cls() -> type["SDPAImpl"]:
return SDPAImpl
# @staticmethod
# def get_metadata_cls() -> Type["AttentionMetadata"]:
# return FlashAttentionMetadata
class SDPAImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
causal: bool,
softmax_scale: float,
num_kv_heads: int | None = None,
prefix: str = "",
**extra_impl_args,
) -> None:
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout = extra_impl_args.get("dropout_p", 0.0)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# transpose to bs, heads, seq_len, head_dim
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
attn_kwargs = {
"attn_mask": None,
"dropout_p": self.dropout,
"is_causal": self.causal,
"scale": self.softmax_scale,
}
if query.shape[1] != key.shape[1]:
attn_kwargs["enable_gqa"] = True
output = torch.nn.functional.scaled_dot_product_attention(
query, key, value, **attn_kwargs
)
output = output.transpose(1, 2)
return output
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import json
from dataclasses import dataclass
from typing import Any
import torch
from einops import rearrange
import sglang.multimodal_gen.envs as envs
from sglang.multimodal_gen.runtime.distributed import get_sp_group
from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
)
from sglang.multimodal_gen.runtime.managers.forward_context import (
ForwardContext,
get_forward_context,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import dict_to_3d_list
try:
from st_attn import sliding_tile_attention
st_attn_backend_available = True
except Exception:
st_attn_backend_available = False
logger = init_logger(__name__)
class RangeDict(dict):
def __getitem__(self, item: int) -> str:
for key in self.keys():
if isinstance(key, tuple):
low, high = key
if low <= item <= high:
return str(super().__getitem__(key))
elif key == item:
return str(super().__getitem__(key))
raise KeyError(f"seq_len {item} not supported for STA")
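# Illustrative usage (a sketch, not part of the original file): keys may be either
# an inclusive (low, high) tuple or an exact sequence length.
#
#     rd = RangeDict({(115200, 115456): "30x48x80", 82944: "36x48x48"})
#     rd[115300]   # -> "30x48x80"  (falls inside the (115200, 115456) range)
#     rd[82944]    # -> "36x48x48"  (exact match)
#     rd[1]        # -> raises KeyError("seq_len 1 not supported for STA")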
class SlidingTileAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
# TODO(will-refactor): check this
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "SLIDING_TILE_ATTN"
@staticmethod
def get_impl_cls() -> type["SlidingTileAttentionImpl"]:
return SlidingTileAttentionImpl
@staticmethod
def get_metadata_cls() -> type["SlidingTileAttentionMetadata"]:
return SlidingTileAttentionMetadata
@staticmethod
def get_builder_cls() -> type["SlidingTileAttentionMetadataBuilder"]:
return SlidingTileAttentionMetadataBuilder
@dataclass
class SlidingTileAttentionMetadata(AttentionMetadata):
current_timestep: int
STA_param: list[
list[Any]
] # each timestep with one metadata, shape [num_layers, num_heads]
class SlidingTileAttentionMetadataBuilder(AttentionMetadataBuilder):
def __init__(self):
pass
def prepare(self):
pass
def build( # type: ignore
self,
STA_param: list[list[Any]],
current_timestep: int,
**kwargs: dict[str, Any],
) -> SlidingTileAttentionMetadata:
param = STA_param
if param is None:
return SlidingTileAttentionMetadata(
current_timestep=current_timestep, STA_param=[]
)
return SlidingTileAttentionMetadata(
current_timestep=current_timestep, STA_param=param[current_timestep]
)
class SlidingTileAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
causal: bool,
softmax_scale: float,
num_kv_heads: int | None = None,
prefix: str = "",
**extra_impl_args,
) -> None:
if not st_attn_backend_available:
raise ValueError("st attn not supported")
# TODO(will-refactor): for now this is the mask strategy, but maybe we should
# have a more general config for STA?
config_file = envs.SGL_DIFFUSION_ATTENTION_CONFIG
if config_file is None:
raise ValueError("SGL_DIFFUSION_ATTENTION_CONFIG is not set")
# TODO(kevin): get mask strategy for different STA modes
with open(config_file) as f:
mask_strategy = json.load(f)
self.mask_strategy = dict_to_3d_list(mask_strategy)
self.prefix = prefix
sp_group = get_sp_group()
self.sp_size = sp_group.world_size
# STA config
self.STA_base_tile_size = [6, 8, 8]
self.dit_seq_shape_mapping = RangeDict(
{
(115200, 115456): "30x48x80",
82944: "36x48x48",
69120: "18x48x80",
}
)
self.full_window_mapping = {
"30x48x80": [5, 6, 10],
"36x48x48": [6, 6, 6],
"18x48x80": [3, 6, 10],
}
def tile(self, x: torch.Tensor) -> torch.Tensor:
return rearrange(
x,
"b (n_t ts_t n_h ts_h n_w ts_w) h d -> b (n_t n_h n_w ts_t ts_h ts_w) h d",
n_t=self.full_window_size[0],
n_h=self.full_window_size[1],
n_w=self.full_window_size[2],
ts_t=self.STA_base_tile_size[0],
ts_h=self.STA_base_tile_size[1],
ts_w=self.STA_base_tile_size[2],
)
def untile(self, x: torch.Tensor) -> torch.Tensor:
x = rearrange(
x,
"b (n_t n_h n_w ts_t ts_h ts_w) h d -> b (n_t ts_t n_h ts_h n_w ts_w) h d",
n_t=self.full_window_size[0],
n_h=self.full_window_size[1],
n_w=self.full_window_size[2],
ts_t=self.STA_base_tile_size[0],
ts_h=self.STA_base_tile_size[1],
ts_w=self.STA_base_tile_size[2],
)
return x
def preprocess_qkv(
self,
qkv: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
img_sequence_length = qkv.shape[1]
self.dit_seq_shape_str = self.dit_seq_shape_mapping[img_sequence_length]
self.full_window_size = self.full_window_mapping[self.dit_seq_shape_str]
self.dit_seq_shape_int = list(map(int, self.dit_seq_shape_str.split("x")))
self.img_seq_length = (
self.dit_seq_shape_int[0]
* self.dit_seq_shape_int[1]
* self.dit_seq_shape_int[2]
)
return self.tile(qkv)
def postprocess_output(
self,
output: torch.Tensor,
attn_metadata: SlidingTileAttentionMetadata,
) -> torch.Tensor:
return self.untile(output)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_metadata: SlidingTileAttentionMetadata,
) -> torch.Tensor:
if self.mask_strategy is None:
raise ValueError("mask_strategy cannot be None for SlidingTileAttention")
if self.mask_strategy[0] is None:
raise ValueError("mask_strategy[0] cannot be None for SlidingTileAttention")
timestep = attn_metadata.current_timestep
forward_context: ForwardContext = get_forward_context()
forward_batch = forward_context.forward_batch
if forward_batch is None:
raise ValueError("forward_batch cannot be None")
# pattern:'.double_blocks.0.attn.impl' or '.single_blocks.0.attn.impl'
layer_idx = int(self.prefix.split(".")[-3])
if attn_metadata.STA_param is None or len(attn_metadata.STA_param) <= layer_idx:
raise ValueError("Invalid STA_param")
STA_param = attn_metadata.STA_param[layer_idx]
text_length = q.shape[1] - self.img_seq_length
has_text = text_length > 0
query = q.transpose(1, 2).contiguous()
key = k.transpose(1, 2).contiguous()
value = v.transpose(1, 2).contiguous()
head_num = query.size(1)
sp_group = get_sp_group()
current_rank = sp_group.rank_in_group
start_head = current_rank * head_num
# searching or tuning mode
if len(STA_param) < head_num * sp_group.world_size:
sparse_attn_hidden_states_all = []
full_mask_window = STA_param[-1]
for window_size in STA_param[:-1]:
sparse_hidden_states = sliding_tile_attention(
query,
key,
value,
[window_size] * head_num,
text_length,
has_text,
self.dit_seq_shape_str,
).transpose(1, 2)
sparse_attn_hidden_states_all.append(sparse_hidden_states)
hidden_states = sliding_tile_attention(
query,
key,
value,
[full_mask_window] * head_num,
text_length,
has_text,
self.dit_seq_shape_str,
).transpose(1, 2)
attn_L2_loss = []
attn_L1_loss = []
# average loss across all heads
for sparse_attn_hidden_states in sparse_attn_hidden_states_all:
# L2 loss
attn_L2_loss_ = (
torch.mean(
(sparse_attn_hidden_states.float() - hidden_states.float())
** 2,
dim=[0, 1, 3],
)
.cpu()
.numpy()
)
attn_L2_loss_ = [round(float(x), 6) for x in attn_L2_loss_]
attn_L2_loss.append(attn_L2_loss_)
# L1 loss
attn_L1_loss_ = (
torch.mean(
torch.abs(
sparse_attn_hidden_states.float() - hidden_states.float()
),
dim=[0, 1, 3],
)
.cpu()
.numpy()
)
attn_L1_loss_ = [round(float(x), 6) for x in attn_L1_loss_]
attn_L1_loss.append(attn_L1_loss_)
layer_loss_save = {"L2_loss": attn_L2_loss, "L1_loss": attn_L1_loss}
if forward_batch.is_cfg_negative:
if forward_batch.mask_search_final_result_neg is not None:
forward_batch.mask_search_final_result_neg[timestep].append(
layer_loss_save
)
else:
if forward_batch.mask_search_final_result_pos is not None:
forward_batch.mask_search_final_result_pos[timestep].append(
layer_loss_save
)
else:
windows = [STA_param[head_idx + start_head] for head_idx in range(head_num)]
hidden_states = sliding_tile_attention(
query,
key,
value,
windows,
text_length,
has_text,
self.dit_seq_shape_str,
).transpose(1, 2)
return hidden_states
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import functools
import math
from dataclasses import dataclass
import torch
try:
from vsa import video_sparse_attn
except ImportError:
video_sparse_attn = None
from typing import Any
from sglang.multimodal_gen.runtime.distributed import get_sp_group
from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
VSA_TILE_SIZE = (4, 4, 4)
@functools.lru_cache(maxsize=10)
def get_tile_partition_indices(
dit_seq_shape: tuple[int, int, int],
tile_size: tuple[int, int, int],
device: torch.device,
) -> torch.LongTensor:
T, H, W = dit_seq_shape
ts, hs, ws = tile_size
indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)
ls = []
for t in range(math.ceil(T / ts)):
for h in range(math.ceil(H / hs)):
for w in range(math.ceil(W / ws)):
ls.append(
indices[
t * ts : min(t * ts + ts, T),
h * hs : min(h * hs + hs, H),
w * ws : min(w * ws + ws, W),
].flatten()
)
index = torch.cat(ls, dim=0)
return index
@functools.lru_cache(maxsize=10)
def get_reverse_tile_partition_indices(
dit_seq_shape: tuple[int, int, int],
tile_size: tuple[int, int, int],
device: torch.device,
) -> torch.LongTensor:
return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device))
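# Worked example (illustrative, tiny shapes): with dit_seq_shape=(1, 2, 4) and
# tile_size=(1, 2, 2), tokens are numbered 0..7 in (T, H, W) raster order and the
# two 1x2x2 tiles are gathered contiguously:
#
#     get_tile_partition_indices((1, 2, 4), (1, 2, 2), torch.device("cpu"))
#     # -> tensor([0, 1, 4, 5, 2, 3, 6, 7])
#
# get_reverse_tile_partition_indices returns the argsort of that permutation, so
# applying it after the tile gather restores the original token order.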
@functools.lru_cache(maxsize=10)
def construct_variable_block_sizes(
dit_seq_shape: tuple[int, int, int],
num_tiles: tuple[int, int, int],
device: torch.device,
) -> torch.LongTensor:
"""
Compute the number of valid (non‑padded) tokens inside every
(ts_t × ts_h × ts_w) tile after padding ‑‑ flattened in the order
(t‑tile, h‑tile, w‑tile) that `rearrange` uses.
Returns
-------
torch.LongTensor # shape: [∏ full_window_size]
"""
# unpack
t, h, w = dit_seq_shape
ts_t, ts_h, ts_w = VSA_TILE_SIZE
n_t, n_h, n_w = num_tiles
def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor:
"""Vector with the size of each tile along one dimension."""
sizes = torch.full((n_tiles,), tile, dtype=torch.int, device=device)
# size of last (possibly partial) tile
remainder = dim_len - (n_tiles - 1) * tile
sizes[-1] = remainder if remainder > 0 else tile
return sizes
t_sizes = _sizes(t, ts_t, n_t) # [n_t]
h_sizes = _sizes(h, ts_h, n_h) # [n_h]
w_sizes = _sizes(w, ts_w, n_w) # [n_w]
# broadcast‑multiply to get voxels per tile, then flatten
block_sizes = (
t_sizes[:, None, None] # [n_t, 1, 1]
* h_sizes[None, :, None] # [1, n_h, 1]
* w_sizes[None, None, :] # [1, 1, n_w]
).reshape(
-1
) # [n_t * n_h * n_w]
return block_sizes
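# Worked example (illustrative): with dit_seq_shape=(5, 4, 4), VSA_TILE_SIZE=(4, 4, 4)
# and num_tiles=(2, 1, 1), the last temporal tile only covers 1 of its 4 slots, so
#
#     construct_variable_block_sizes((5, 4, 4), (2, 1, 1), torch.device("cpu"))
#     # -> tensor([64, 16])   (4*4*4 valid tokens in the first tile, 1*4*4 in the last)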
@functools.lru_cache(maxsize=10)
def get_non_pad_index(
variable_block_sizes: torch.LongTensor,
max_block_size: int,
):
n_win = variable_block_sizes.shape[0]
device = variable_block_sizes.device
starts_pad = torch.arange(n_win, device=device) * max_block_size
index_pad = (
starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :]
)
index_mask = (
torch.arange(max_block_size, device=device)[None, :]
< variable_block_sizes[:, None]
)
return index_pad[index_mask]
class VideoSparseAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [64, 128]
@staticmethod
def get_name() -> str:
return "VIDEO_SPARSE_ATTN"
@staticmethod
def get_impl_cls() -> type["VideoSparseAttentionImpl"]:
return VideoSparseAttentionImpl
@staticmethod
def get_metadata_cls() -> type["VideoSparseAttentionMetadata"]:
return VideoSparseAttentionMetadata
@staticmethod
def get_builder_cls() -> type["VideoSparseAttentionMetadataBuilder"]:
return VideoSparseAttentionMetadataBuilder
@dataclass
class VideoSparseAttentionMetadata(AttentionMetadata):
current_timestep: int
dit_seq_shape: list[int]
VSA_sparsity: float
num_tiles: list[int]
total_seq_length: int
tile_partition_indices: torch.LongTensor
reverse_tile_partition_indices: torch.LongTensor
variable_block_sizes: torch.LongTensor
non_pad_index: torch.LongTensor
# adaptation for FastWan2.1-T2V-1.3B-Diffusers
# Sequence lengths for the forward batch
# Maximum sequence length for query
max_seqlen_q: int = 1
# Maximum sequence length for key
max_seqlen_k: int = 0
class VideoSparseAttentionMetadataBuilder(AttentionMetadataBuilder):
def __init__(self):
pass
def prepare(self):
pass
def build( # type: ignore
self,
current_timestep: int,
raw_latent_shape: tuple[int, int, int],
patch_size: tuple[int, int, int],
VSA_sparsity: float,
device: torch.device,
**kwargs: dict[str, Any],
) -> VideoSparseAttentionMetadata:
dit_seq_shape = (
raw_latent_shape[0] // patch_size[0],
raw_latent_shape[1] // patch_size[1],
raw_latent_shape[2] // patch_size[2],
)
num_tiles = (
math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]),
math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]),
)
total_seq_length = math.prod(dit_seq_shape)
tile_partition_indices = get_tile_partition_indices(
dit_seq_shape, VSA_TILE_SIZE, device
)
reverse_tile_partition_indices = get_reverse_tile_partition_indices(
dit_seq_shape, VSA_TILE_SIZE, device
)
variable_block_sizes = construct_variable_block_sizes(
dit_seq_shape, num_tiles, device
)
non_pad_index = get_non_pad_index(
variable_block_sizes, math.prod(VSA_TILE_SIZE)
)
return VideoSparseAttentionMetadata(
current_timestep=current_timestep,
dit_seq_shape=dit_seq_shape, # type: ignore
VSA_sparsity=VSA_sparsity, # type: ignore
num_tiles=num_tiles, # type: ignore
total_seq_length=total_seq_length, # type: ignore
tile_partition_indices=tile_partition_indices, # type: ignore
reverse_tile_partition_indices=reverse_tile_partition_indices,
variable_block_sizes=variable_block_sizes,
non_pad_index=non_pad_index,
)
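# Illustrative example (hypothetical latent shape): with raw_latent_shape=(21, 60, 104)
# and patch_size=(1, 2, 2), the DiT token grid is dit_seq_shape=(21, 30, 52), so
# num_tiles = (ceil(21/4), ceil(30/4), ceil(52/4)) = (6, 8, 13) and
# total_seq_length = 21 * 30 * 52 = 32760. The tile-partition index tensors and the
# per-tile valid-token counts are built once per (shape, device) via the lru_cache'd
# helpers above.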
class VideoSparseAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
causal: bool,
softmax_scale: float,
num_kv_heads: int | None = None,
prefix: str = "",
**extra_impl_args,
) -> None:
self.prefix = prefix
sp_group = get_sp_group()
self.sp_size = sp_group.world_size
def tile(
self,
x: torch.Tensor,
num_tiles: list[int],
tile_partition_indices: torch.LongTensor,
non_pad_index: torch.LongTensor,
) -> torch.Tensor:
t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]
x_padded = torch.zeros(
(
x.shape[0],
t_padded_size * h_padded_size * w_padded_size,
x.shape[-2],
x.shape[-1],
),
device=x.device,
dtype=x.dtype,
)
x_padded[:, non_pad_index] = x[:, tile_partition_indices]
return x_padded
def untile(
self,
x: torch.Tensor,
reverse_tile_partition_indices: torch.LongTensor,
non_pad_index: torch.LongTensor,
) -> torch.Tensor:
x = x[:, non_pad_index][:, reverse_tile_partition_indices]
return x
def preprocess_qkv(
self,
qkv: torch.Tensor,
attn_metadata: VideoSparseAttentionMetadata,
) -> torch.Tensor:
return self.tile(
qkv,
attn_metadata.num_tiles,
attn_metadata.tile_partition_indices,
attn_metadata.non_pad_index,
)
def postprocess_output(
self,
output: torch.Tensor,
attn_metadata: VideoSparseAttentionMetadata,
) -> torch.Tensor:
return self.untile(
output,
attn_metadata.reverse_tile_partition_indices,
attn_metadata.non_pad_index,
)
def forward( # type: ignore[override]
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
gate_compress: torch.Tensor,
attn_metadata: VideoSparseAttentionMetadata,
) -> torch.Tensor:
query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
gate_compress = gate_compress.transpose(1, 2).contiguous()
VSA_sparsity = attn_metadata.VSA_sparsity
cur_topk = math.ceil(
(1 - VSA_sparsity)
* (attn_metadata.total_seq_length / math.prod(VSA_TILE_SIZE))
)
if video_sparse_attn is None:
raise NotImplementedError("video_sparse_attn is not installed")
hidden_states = video_sparse_attn(
query,
key,
value,
variable_block_sizes=attn_metadata.variable_block_sizes,
topk=cur_topk,
block_size=VSA_TILE_SIZE,
compress_attn_weight=gate_compress,
).transpose(1, 2)
return hidden_states
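# Worked example (illustrative numbers): with total_seq_length = 32760 tokens and
# VSA_TILE_SIZE = (4, 4, 4) there are 32760 / 64 = 511.875 tiles' worth of tokens;
# at VSA_sparsity = 0.9 each query tile attends to
# cur_topk = ceil(0.1 * 511.875) = 52 key tiles.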
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import re
from dataclasses import dataclass
import torch
from einops import rearrange
from kernel.attn.vmoba_attn.vmoba import (
moba_attn_varlen,
process_moba_input,
process_moba_output,
)
from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class VMOBAAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "VMOBA_ATTN"
@staticmethod
def get_impl_cls() -> type["VMOBAAttentionImpl"]:
return VMOBAAttentionImpl
@staticmethod
def get_metadata_cls() -> type["VideoMobaAttentionMetadata"]:
return VideoMobaAttentionMetadata
@staticmethod
def get_builder_cls() -> type["VideoMobaAttentionMetadataBuilder"]:
return VideoMobaAttentionMetadataBuilder
@dataclass
class VideoMobaAttentionMetadata(AttentionMetadata):
current_timestep: int
temporal_chunk_size: int
temporal_topk: int
spatial_chunk_size: tuple[int, int]
spatial_topk: int
st_chunk_size: tuple[int, int, int]
st_topk: int
moba_select_mode: str
moba_threshold: float
moba_threshold_type: str
patch_resolution: list[int]
first_full_step: int = 12
first_full_layer: int = 0
# temporal_layer -> spatial_layer -> st_layer
temporal_layer: int = 1
spatial_layer: int = 1
st_layer: int = 1
def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected by the attention mask.
indices: (total_nnz), the indices of the non-masked tokens in the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.
Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[1:]
output = torch.zeros(
(batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype
)
output[indices] = hidden_states
return rearrange(output, "(b s) ... -> b s ...", b=batch)
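# Worked example (illustrative): with batch=2, seqlen=3 and four selected tokens at
# flat positions indices=[0, 1, 3, 5], pad_input scatters the rows back into a zero
# tensor of shape (2, 3, ...); row 2 of sample 0 and row 1 of sample 1 stay zero.
#
#     h = torch.arange(4.0).reshape(4, 1)             # (total_nnz, dim)
#     pad_input(h, torch.tensor([0, 1, 3, 5]), 2, 3)
#     # -> tensor([[[0.], [1.], [0.]],
#     #            [[2.], [0.], [3.]]])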
class VideoMobaAttentionMetadataBuilder(AttentionMetadataBuilder):
def __init__(self):
pass
def prepare(self):
pass
def build( # type: ignore
self,
current_timestep: int,
raw_latent_shape: tuple[int, int, int],
patch_size: tuple[int, int, int],
temporal_chunk_size: int,
temporal_topk: int,
spatial_chunk_size: tuple[int, int],
spatial_topk: int,
st_chunk_size: tuple[int, int, int],
st_topk: int,
moba_select_mode: str = "threshold",
moba_threshold: float = 0.25,
moba_threshold_type: str = "query_head",
device: torch.device = None,
first_full_layer: int = 0,
first_full_step: int = 12,
temporal_layer: int = 1,
spatial_layer: int = 1,
st_layer: int = 1,
**kwargs,
) -> VideoMobaAttentionMetadata:
if device is None:
device = torch.device("cpu")
assert (
raw_latent_shape[0] % patch_size[0] == 0
and raw_latent_shape[1] % patch_size[1] == 0
and raw_latent_shape[2] % patch_size[2] == 0
), f"spatial patch_resolution {raw_latent_shape} should be divisible by patch_size {patch_size}"
patch_resolution = [
t // pt for t, pt in zip(raw_latent_shape, patch_size, strict=False)
]
return VideoMobaAttentionMetadata(
current_timestep=current_timestep,
temporal_chunk_size=temporal_chunk_size,
temporal_topk=temporal_topk,
spatial_chunk_size=spatial_chunk_size,
spatial_topk=spatial_topk,
st_chunk_size=st_chunk_size,
st_topk=st_topk,
moba_select_mode=moba_select_mode,
moba_threshold=moba_threshold,
moba_threshold_type=moba_threshold_type,
patch_resolution=patch_resolution,
first_full_layer=first_full_layer,
first_full_step=first_full_step,
temporal_layer=temporal_layer,
spatial_layer=spatial_layer,
st_layer=st_layer,
)
class VMOBAAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads,
head_size,
softmax_scale,
causal=False,
num_kv_heads=None,
prefix="",
**extra_impl_args,
) -> None:
self.prefix = prefix
self.layer_idx = self._get_layer_idx(prefix)
self.pad_input = pad_input
def _get_layer_idx(self, prefix: str) -> int | None:
match = re.search(r"blocks\.(\d+)", prefix)
if not match:
raise ValueError(f"Invalid prefix: {prefix}")
return int(match.group(1))
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
"""
query: [B, L, H, D]
key: [B, L, H, D]
value: [B, L, H, D]
attn_metadata: AttentionMetadata
"""
batch_size, sequence_length, num_heads, head_dim = query.shape
# select chunk type according to layer idx:
loop_layer_num = (
attn_metadata.temporal_layer
+ attn_metadata.spatial_layer
+ attn_metadata.st_layer
)
moba_layer = self.layer_idx - attn_metadata.first_full_layer
if moba_layer % loop_layer_num < attn_metadata.temporal_layer:
moba_chunk_size = attn_metadata.temporal_chunk_size
moba_topk = attn_metadata.temporal_topk
elif (
moba_layer % loop_layer_num
< attn_metadata.temporal_layer + attn_metadata.spatial_layer
):
moba_chunk_size = attn_metadata.spatial_chunk_size
moba_topk = attn_metadata.spatial_topk
elif (
moba_layer % loop_layer_num
< attn_metadata.temporal_layer
+ attn_metadata.spatial_layer
+ attn_metadata.st_layer
):
moba_chunk_size = attn_metadata.st_chunk_size
moba_topk = attn_metadata.st_topk
query, chunk_size = process_moba_input(
query, attn_metadata.patch_resolution, moba_chunk_size
)
key, chunk_size = process_moba_input(
key, attn_metadata.patch_resolution, moba_chunk_size
)
value, chunk_size = process_moba_input(
value, attn_metadata.patch_resolution, moba_chunk_size
)
max_seqlen = query.shape[1]
indices_q = torch.arange(
0, query.shape[0] * query.shape[1], device=query.device
)
cu_seqlens = torch.arange(
0,
query.shape[0] * query.shape[1] + 1,
query.shape[1],
dtype=torch.int32,
device=query.device,
)
query = rearrange(query, "b s ... -> (b s) ...")
key = rearrange(key, "b s ... -> (b s) ...")
value = rearrange(value, "b s ... -> (b s) ...")
# current_timestep=attn_metadata.current_timestep
hidden_states = moba_attn_varlen(
query,
key,
value,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
moba_chunk_size=chunk_size,
moba_topk=moba_topk,
select_mode=attn_metadata.moba_select_mode,
simsum_threshold=attn_metadata.moba_threshold,
threshold_type=attn_metadata.moba_threshold_type,
)
hidden_states = self.pad_input(
hidden_states, indices_q, batch_size, sequence_length
)
hidden_states = process_moba_output(
hidden_states, attn_metadata.patch_resolution, moba_chunk_size
)
return hidden_states
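# Minimal sketch (illustrative, not part of the original source; the helper name is
# hypothetical): how the modulo logic in VMOBAAttentionImpl.forward selects a chunk
# type from the per-layer cycle, given the configured counts of temporal, spatial,
# and spatio-temporal layers.
def _demo_vmoba_chunk_type(
    layer_idx: int,
    first_full_layer: int = 0,
    temporal_layer: int = 1,
    spatial_layer: int = 1,
    st_layer: int = 1,
) -> str:
    moba_layer = layer_idx - first_full_layer
    pos = moba_layer % (temporal_layer + spatial_layer + st_layer)
    if pos < temporal_layer:
        return "temporal"
    if pos < temporal_layer + spatial_layer:
        return "spatial"
    return "spatio-temporal"
# With the defaults above, blocks 0, 1, 2, 3, ... map to
# temporal, spatial, spatio-temporal, temporal, ...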
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from typing import Type
import torch
import torch.nn as nn
from sglang.multimodal_gen.runtime.distributed.communication_op import (
sequence_model_parallel_all_gather,
sequence_model_parallel_all_to_all_4D,
)
from sglang.multimodal_gen.runtime.distributed.parallel_state import (
get_ring_parallel_world_size,
get_sequence_parallel_world_size,
get_sp_parallel_rank,
get_sp_world_size,
get_ulysses_parallel_world_size,
)
from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
AttentionImpl,
)
from sglang.multimodal_gen.runtime.layers.attention.selector import (
backend_name_to_enum,
get_attn_backend,
)
from sglang.multimodal_gen.runtime.layers.usp import (
_usp_input_all_to_all,
_usp_output_all_to_all,
ring_attn,
)
from sglang.multimodal_gen.runtime.managers.forward_context import (
ForwardContext,
get_forward_context,
)
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
from sglang.multimodal_gen.utils import get_compute_dtype
class UlyssesAttention(nn.Module):
"""Ulysses-style SequenceParallelism attention layer."""
def __init__(
self,
num_heads: int,
head_size: int,
num_kv_heads: int | None = None,
softmax_scale: float | None = None,
causal: bool = False,
supported_attention_backends: set[AttentionBackendEnum] | None = None,
prefix: str = "",
**extra_impl_args,
) -> None:
super().__init__()
if softmax_scale is None:
self.softmax_scale = head_size**-0.5
else:
self.softmax_scale = softmax_scale
if num_kv_heads is None:
num_kv_heads = num_heads
dtype = get_compute_dtype()
attn_backend = get_attn_backend(
head_size, dtype, supported_attention_backends=supported_attention_backends
)
impl_cls = attn_backend.get_impl_cls()
self.attn_impl = impl_cls(
num_heads=num_heads,
head_size=head_size,
causal=causal,
softmax_scale=self.softmax_scale,
num_kv_heads=num_kv_heads,
prefix=f"{prefix}.impl",
**extra_impl_args,
)
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.backend = backend_name_to_enum(attn_backend.get_name())
self.dtype = dtype
@torch.compiler.disable
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
replicated_q: torch.Tensor | None = None,
replicated_k: torch.Tensor | None = None,
replicated_v: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Forward pass for distributed attention.
Args:
q (torch.Tensor): Query tensor [batch_size, seq_len, num_heads, head_dim]
k (torch.Tensor): Key tensor [batch_size, seq_len, num_heads, head_dim]
v (torch.Tensor): Value tensor [batch_size, seq_len, num_heads, head_dim]
replicated_q (Optional[torch.Tensor]): Replicated query tensor, typically for text tokens
replicated_k (Optional[torch.Tensor]): Replicated key tensor
replicated_v (Optional[torch.Tensor]): Replicated value tensor
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]: A tuple containing:
- o (torch.Tensor): Output tensor after attention for the main sequence
- replicated_o (Optional[torch.Tensor]): Output tensor for replicated tokens, if provided
"""
# Check input shapes
assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "Expected 4D tensors"
batch_size, seq_len, num_heads, head_dim = q.shape
local_rank = get_sp_parallel_rank()
world_size = get_sp_world_size()
forward_context: ForwardContext = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata
        # Concatenate QKV along the batch dimension
        qkv = torch.cat([q, k, v], dim=0)  # [3 * batch_size, seq_len, num_heads, head_dim]
# Redistribute heads across sequence dimension
qkv = sequence_model_parallel_all_to_all_4D(qkv, scatter_dim=2, gather_dim=1)
# Apply backend-specific preprocess_qkv
qkv = self.attn_impl.preprocess_qkv(qkv, ctx_attn_metadata)
# Concatenate with replicated QKV if provided
if replicated_q is not None:
assert replicated_k is not None and replicated_v is not None
replicated_qkv = torch.cat(
[replicated_q, replicated_k, replicated_v], dim=0
            )  # [3 * batch_size, replicated_seq_len, num_heads, head_dim]
heads_per_rank = num_heads // world_size
replicated_qkv = replicated_qkv[
:, :, local_rank * heads_per_rank : (local_rank + 1) * heads_per_rank
]
qkv = torch.cat([qkv, replicated_qkv], dim=1)
q, k, v = qkv.chunk(3, dim=0)
output = self.attn_impl.forward(q, k, v, ctx_attn_metadata)
# Redistribute back if using sequence parallelism
replicated_output = None
if replicated_q is not None:
replicated_output = output[:, seq_len * world_size :]
output = output[:, : seq_len * world_size]
# TODO: make this asynchronous
replicated_output = sequence_model_parallel_all_gather(
replicated_output.contiguous(), dim=2
)
# Apply backend-specific postprocess_output
output = self.attn_impl.postprocess_output(output, ctx_attn_metadata)
output = sequence_model_parallel_all_to_all_4D(
output, scatter_dim=1, gather_dim=2
)
return output, replicated_output
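# Single-process sketch (illustrative, not part of the original source; the helper
# name is hypothetical) of what the Ulysses all-to-all above achieves logically:
# each rank starts with the full head dimension but only a slice of the sequence,
# and ends with the full sequence but only a slice of the heads. Plain tensor ops
# stand in for the real collective.
def _demo_ulysses_reshard(shards: list[torch.Tensor]) -> list[torch.Tensor]:
    # shards[i]: [B, S / P, H, D] held by rank i; returns per-rank [B, S, H / P, D].
    world_size = len(shards)
    full_sequence = torch.cat(shards, dim=1)  # [B, S, H, D]
    return list(full_sequence.chunk(world_size, dim=2))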
class UlyssesAttention_VSA(UlyssesAttention):
"""Distributed attention layer with VSA support."""
@torch.compiler.disable
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
replicated_q: torch.Tensor | None = None,
replicated_k: torch.Tensor | None = None,
replicated_v: torch.Tensor | None = None,
gate_compress: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Forward pass for distributed attention.
Args:
q (torch.Tensor): Query tensor [batch_size, seq_len, num_heads, head_dim]
k (torch.Tensor): Key tensor [batch_size, seq_len, num_heads, head_dim]
v (torch.Tensor): Value tensor [batch_size, seq_len, num_heads, head_dim]
gate_compress (torch.Tensor): Gate compress tensor [batch_size, seq_len, num_heads, head_dim]
replicated_q (Optional[torch.Tensor]): Replicated query tensor, typically for text tokens
replicated_k (Optional[torch.Tensor]): Replicated key tensor
replicated_v (Optional[torch.Tensor]): Replicated value tensor
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]: A tuple containing:
- o (torch.Tensor): Output tensor after attention for the main sequence
- replicated_o (Optional[torch.Tensor]): Output tensor for replicated tokens, if provided
"""
# Check text tokens are not supported for VSA now
assert (
replicated_q is None and replicated_k is None and replicated_v is None
), "Replicated QKV is not supported for VSA now"
# Check input shapes
assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "Expected 4D tensors"
forward_context: ForwardContext = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata
        assert gate_compress is not None, "gate_compress is required for VSA"
        # Concatenate QKV and the gate along the batch dimension
        qkvg = torch.cat(
            [q, k, v, gate_compress], dim=0
        )  # [4 * batch_size, seq_len, num_heads, head_dim]
# Redistribute heads across sequence dimension
qkvg = sequence_model_parallel_all_to_all_4D(qkvg, scatter_dim=2, gather_dim=1)
qkvg = self.attn_impl.preprocess_qkv(qkvg, ctx_attn_metadata)
q, k, v, gate_compress = qkvg.chunk(4, dim=0)
output = self.attn_impl.forward(
q, k, v, gate_compress=gate_compress, attn_metadata=ctx_attn_metadata
) # type: ignore[call-arg]
# Redistribute back if using sequence parallelism
replicated_output = None
# Apply backend-specific postprocess_output
output = self.attn_impl.postprocess_output(output, ctx_attn_metadata)
output = sequence_model_parallel_all_to_all_4D(
output, scatter_dim=1, gather_dim=2
)
return output, replicated_output
class LocalAttention(nn.Module):
"""Attention layer."""
def __init__(
self,
num_heads: int,
head_size: int,
num_kv_heads: int | None = None,
softmax_scale: float | None = None,
causal: bool = False,
supported_attention_backends: set[AttentionBackendEnum] | None = None,
**extra_impl_args,
) -> None:
super().__init__()
if softmax_scale is None:
self.softmax_scale = head_size**-0.5
else:
self.softmax_scale = softmax_scale
if num_kv_heads is None:
num_kv_heads = num_heads
dtype = get_compute_dtype()
attn_backend = get_attn_backend(
head_size, dtype, supported_attention_backends=supported_attention_backends
)
impl_cls = attn_backend.get_impl_cls()
self.attn_impl = impl_cls(
num_heads=num_heads,
head_size=head_size,
softmax_scale=self.softmax_scale,
num_kv_heads=num_kv_heads,
causal=causal,
**extra_impl_args,
)
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.backend = backend_name_to_enum(attn_backend.get_name())
self.dtype = dtype
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) -> torch.Tensor:
"""
Apply local attention between query, key and value tensors.
Args:
q (torch.Tensor): Query tensor of shape [batch_size, seq_len, num_heads, head_dim]
k (torch.Tensor): Key tensor of shape [batch_size, seq_len, num_heads, head_dim]
v (torch.Tensor): Value tensor of shape [batch_size, seq_len, num_heads, head_dim]
Returns:
torch.Tensor: Output tensor after local attention
"""
# Check input shapes
assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "Expected 4D tensors"
forward_context: ForwardContext = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata
output = self.attn_impl.forward(q, k, v, attn_metadata=ctx_attn_metadata)
return output
class USPAttention(nn.Module):
"""
Ulysses Sequence Parallelism with Ring Attention.
This class implements the USP algorithm, which is a combination of
Ulysses-style all-to-all communication for sequence-head dimension sharding
and Ring Attention for fine-grained sequence parallelism within subgroups.
"""
def __init__(
self,
num_heads: int,
head_size: int,
num_kv_heads: int | None = None,
softmax_scale: float | None = None,
causal: bool = False,
supported_attention_backends: set[AttentionBackendEnum] | None = None,
prefix: str = "",
dropout_p: float = 0.0,
**extra_impl_args,
) -> None:
super().__init__()
if softmax_scale is None:
self.softmax_scale = head_size**-0.5
else:
self.softmax_scale = softmax_scale
if num_kv_heads is None:
num_kv_heads = num_heads
dtype = get_compute_dtype()
attn_backend = get_attn_backend(
head_size, dtype, supported_attention_backends=supported_attention_backends
)
impl_cls: Type["AttentionImpl"] = attn_backend.get_impl_cls()
self.attn_impl = impl_cls(
num_heads=num_heads,
head_size=head_size,
causal=causal,
softmax_scale=self.softmax_scale,
num_kv_heads=num_kv_heads,
prefix=f"{prefix}.impl",
**extra_impl_args,
)
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.backend = backend_name_to_enum(attn_backend.get_name())
self.dtype = dtype
self.causal = causal
self.dropout_p = dropout_p
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
replicated_q: torch.Tensor | None = None,
replicated_k: torch.Tensor | None = None,
replicated_v: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Forward pass for USPAttention.
q, k, v: [B, S_local, H, D]
Note: Replicated tensors are not supported in this implementation.
"""
assert (
replicated_q is None and replicated_k is None and replicated_v is None
), "USPAttention does not support replicated_qkv."
forward_context: ForwardContext = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata
if get_sequence_parallel_world_size() == 1:
# No sequence parallelism, just run local attention.
out = self.attn_impl.forward(q, k, v, ctx_attn_metadata)
return out, None
# Ulysses-style All-to-All for sequence/head sharding
if get_ulysses_parallel_world_size() > 1:
# -> [B, S, H_local, D]
q = _usp_input_all_to_all(q, head_dim=2)
k = _usp_input_all_to_all(k, head_dim=2)
v = _usp_input_all_to_all(v, head_dim=2)
# Ring Attention within subgroups or local attention
if get_ring_parallel_world_size() > 1:
out = ring_attn(
q,
k,
v,
attn_impl=self.attn_impl,
is_causal=self.causal,
dropout_p=self.dropout_p,
)
else:
# -> [B, S, H_local, D]
out = self.attn_impl.forward(q, k, v, ctx_attn_metadata)
# Ulysses-style All-to-All to restore original sharding
if get_ulysses_parallel_world_size() > 1:
# -> [B, S_local, H, D]
out = _usp_output_all_to_all(out, head_dim=2)
return out, None
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/attention/selector.py
import os
from collections.abc import Generator
from contextlib import contextmanager
from functools import cache
from typing import cast
import torch
from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
AttentionBackend,
)
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
from sglang.multimodal_gen.runtime.server_args import get_global_server_args
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
logger = init_logger(__name__)
def backend_name_to_enum(backend_name: str) -> AttentionBackendEnum | None:
"""
    Convert a string backend name to an AttentionBackendEnum value.
    Returns:
    * AttentionBackendEnum: enum value if backend_name is a valid in-tree type
    * None: otherwise (an invalid in-tree name, or an out-of-tree platform is
      loaded).
"""
assert backend_name is not None
return (
AttentionBackendEnum[backend_name]
if backend_name in AttentionBackendEnum.__members__
else None
)
def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
"""
Get the backend override specified by the sgl-diffusion attention
backend environment variable, if one is specified.
Returns:
* _Backend enum value if an override is specified
* None otherwise
"""
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
return None if backend_name is None else backend_name_to_enum(backend_name)
# Global state allows a particular choice of backend
# to be forced, overriding the logic which auto-selects
# a backend based on system & workload configuration
# (default behavior if this variable is None)
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# SGL_DIFFUSION_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: AttentionBackendEnum | None = None
def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
"""
Force all attention operations to use a specified backend.
Passing `None` for the argument re-enables automatic
    backend selection.
Arguments:
* attn_backend: backend selection (None to revert to auto)
"""
global forced_attn_backend
forced_attn_backend = attn_backend
def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
"""
Get the currently-forced choice of attention backend,
or None if auto-selection is currently enabled.
"""
return forced_attn_backend
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
supported_attention_backends: set[AttentionBackendEnum] | None = None,
) -> type[AttentionBackend]:
if supported_attention_backends is not None:
# Sort the backend names to ensure consistent cache key
be_tuple = tuple(
sorted(list(supported_attention_backends), key=lambda b: b.name)
)
else:
be_tuple = None
return _cached_get_attn_backend(head_size, dtype, be_tuple)
@cache
def _cached_get_attn_backend(
head_size: int,
dtype: torch.dtype,
    supported_attention_backends: tuple[AttentionBackendEnum, ...] | None = None,
) -> type[AttentionBackend]:
# Check whether a particular choice of backend was
# previously forced.
#
# THIS SELECTION OVERRIDES THE SGL_DIFFUSION_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
from sglang.multimodal_gen.runtime.platforms import current_platform
    if not supported_attention_backends:
        raise ValueError("supported_attention_backends must be provided and non-empty")
    supported_attention_backends = set(supported_attention_backends)
selected_backend = None
backend_by_global_setting: AttentionBackendEnum | None = (
get_global_forced_attn_backend()
)
if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting
else:
# Check the server arguments for a backend override
server_args = get_global_server_args()
if server_args.attention_backend is not None:
try:
selected_backend = AttentionBackendEnum[
server_args.attention_backend.upper()
]
except KeyError:
raise ValueError(
f"Invalid attention backend '{server_args.attention_backend}' specified via command line. "
f"Available options are: {[e.name.lower() for e in AttentionBackendEnum]}"
)
# get device-specific attn_backend
if selected_backend is None:
logger.debug(f"Attention backend not specified")
elif (
not supported_attention_backends
or selected_backend not in supported_attention_backends
):
supported_attention_backends_str = [
supported_attention_backend.__str__()
for supported_attention_backend in supported_attention_backends
]
logger.debug(
f"Selected attention backend: '{selected_backend}' not in supported attention backends: {supported_attention_backends_str}"
)
selected_backend = None
attention_cls = current_platform.get_attn_backend_cls_str(
selected_backend, head_size, dtype
)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}"
)
return cast(type[AttentionBackend], resolve_obj_by_qualname(attention_cls))
@contextmanager
def global_force_attn_backend_context_manager(
attn_backend: AttentionBackendEnum,
) -> Generator[None, None, None]:
"""
Globally force a sgl-diffusion attention backend override within a
context manager, reverting the global attention backend
override to its prior state upon exiting the context
manager.
Arguments:
* attn_backend: attention backend to force
Returns:
* Generator
"""
# Save the current state of the global backend override (if any)
original_value = get_global_forced_attn_backend()
# Globally force the new backend override
global_force_attn_backend(attn_backend)
# Yield control back to the enclosed code block
try:
yield
finally:
# Revert the original global backend override, if any
global_force_attn_backend(original_value)
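# Hypothetical usage sketch (illustrative only; the enum member below is an
# assumption and may not exist in AttentionBackendEnum):
#
#   with global_force_attn_backend_context_manager(AttentionBackendEnum.FLASH_ATTN):
#       backend_cls = get_attn_backend(
#           head_size=64,
#           dtype=torch.bfloat16,
#           supported_attention_backends={AttentionBackendEnum.FLASH_ATTN},
#       )
#
# Outside the `with` block the previously forced backend (or automatic selection)
# is restored.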
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/custom_op.py
from collections.abc import Callable
from typing import Any
import torch.nn as nn
from sglang.multimodal_gen.runtime.utils.common import (
is_cpu,
is_cuda,
is_hip,
is_npu,
is_xpu,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu()
_is_npu = is_npu()
_is_xpu = is_xpu()
class CustomOp(nn.Module):
"""
Base class for custom ops.
Dispatches the forward method to the appropriate backend.
"""
def __init__(self) -> None:
super().__init__()
self._forward_method = self.dispatch_forward()
def forward(self, *args, **kwargs) -> Any:
return self._forward_method(*args, **kwargs)
def forward_native(self, *args, **kwargs) -> Any:
"""PyTorch-native implementation of the forward method.
This method is optional. If implemented, it can be used with compilers
such as torch.compile or PyTorch XLA. Also, it can be used for testing
purposes.
"""
raise NotImplementedError
def forward_cuda(self, *args, **kwargs) -> Any:
raise NotImplementedError
def forward_cpu(self, *args, **kwargs) -> Any:
# By default, we assume that CPU ops are compatible with CUDA ops.
return self.forward_cuda(*args, **kwargs)
def forward_tpu(self, *args, **kwargs) -> Any:
# By default, we assume that TPU ops are compatible with the
# PyTorch-native implementation.
# NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_native(*args, **kwargs)
    def forward_hip(self, *args, **kwargs) -> Any:
        # By default, we assume that HIP ops are compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)
    def forward_npu(self, *args, **kwargs) -> Any:
        # By default, we assume that NPU ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)
    def forward_xpu(self, *args, **kwargs) -> Any:
        # By default, we assume that XPU ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)
    def forward_oot(self, *args, **kwargs) -> Any:
        # By default, we assume that OOT ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)
    def dispatch_forward(self) -> Callable:
        if _is_cuda:
            return self.forward_cuda
        elif _is_hip:
            return self.forward_hip
        elif _is_npu:
            return self.forward_npu
        elif _is_xpu:
            return self.forward_xpu
        else:
            return self.forward_native
@classmethod
def enabled(cls) -> bool:
# since we are not using Inductor, we always return True
return True
@staticmethod
def default_on() -> bool:
"""
On by default if level < CompilationLevel.PIECEWISE
Specifying 'all' or 'none' in custom_op takes precedence.
"""
raise NotImplementedError
# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
# Examples:
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: dict[str, type["CustomOp"]] = {}
# Decorator to register custom ops.
@classmethod
def register(cls, name: str) -> Callable:
def decorator(op_cls):
assert name not in cls.op_registry, f"Duplicate op name: {name}"
op_cls.name = name
cls.op_registry[name] = op_cls
return op_cls
return decorator
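# Illustrative sketch (not part of the original source): registering a new op. The
# decorator records the class in CustomOp.op_registry and dispatch_forward() picks
# the platform-specific path at construction time. Kept commented out so importing
# this module does not register the example op.
#
# @CustomOp.register("scaled_tanh")
# class ScaledTanh(CustomOp):
#     def forward_native(self, x, scale: float = 1.0):
#         return torch.tanh(x) * scale
#
#     def forward_cuda(self, x, scale: float = 1.0):
#         return self.forward_native(x, scale)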
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/layernorm.py
"""Custom normalization layers."""
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp
from sglang.multimodal_gen.runtime.layers.triton_ops import (
fuse_scale_shift_kernel,
norm_infer,
rms_norm_fn,
)
from sglang.multimodal_gen.runtime.utils.common import (
get_bool_env_var,
is_cpu,
is_cuda,
is_hip,
is_npu,
is_xpu,
)
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu = is_cpu()
_is_xpu = is_xpu()
if _is_cuda or _is_hip:
    from sgl_kernel import fused_add_rmsnorm, rmsnorm
# Copied and adapted from sglang
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
"""Root mean square normalization.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
Refer to https://arxiv.org/abs/1910.07467
"""
    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        dtype: torch.dtype = torch.float32,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
    ) -> None:
        super().__init__()
        self.variance_epsilon = eps
        self.hidden_size = hidden_size
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        self.has_weight = has_weight
        if has_weight:
            self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype))
        else:
            # Non-learnable all-ones weight so the forward paths can multiply by
            # `self.weight` unconditionally; persistent=False keeps it out of the
            # state dict.
            self.register_buffer(
                "weight", torch.ones(hidden_size, dtype=dtype), persistent=False
            )
        if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
            self._forward_method = self.forward_native
def forward_triton(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
return rms_norm_fn(
x, self.weight, bias=None, residual=residual, eps=self.variance_epsilon
)
def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
shape = x.shape
x = x.view(-1, shape[-1])
if residual is not None:
residual_shape = residual.shape
residual = residual.view(-1, shape[-1])
if x.dtype == torch.float:
# fp32
out = self.forward_triton(x, residual)
elif self.variance_size_override is not None:
return self.forward_native(x, residual)
elif residual is not None:
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
return x.view(shape), residual.view(residual_shape)
else:
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
out = out.view(shape)
return out
def forward_native(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
x = x.contiguous()
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError(
"Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}"
)
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}"
)
x_var = x[..., : self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = (x * self.weight).to(orig_dtype)
if residual is None:
return x
else:
return x, residual
def forward_cpu(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
return self.forward_native(x, residual)
def extra_repr(self) -> str:
s = f"hidden_size={self.weight.data.size(0)}"
s += f", eps={self.variance_epsilon}"
return s
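# Reference sketch (illustrative, not part of the original source): the native path
# of RMSNorm computes weight * x / sqrt(mean(x**2) + eps) in fp32 and casts back,
# so with the default all-ones weight the following holds:
#
#   layer = RMSNorm(hidden_size=8, eps=1e-6)
#   x = torch.randn(2, 4, 8)
#   ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
#   torch.testing.assert_close(layer.forward_native(x), ref, atol=1e-5, rtol=1e-5)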
# Copied and adapted from sglang
@CustomOp.register("layer_norm")
class LayerNorm(CustomOp):
def __init__(
self,
hidden_size: int,
eps=1e-5,
bias: bool = True,
elementwise_affine=True,
device=None,
dtype=None,
) -> None:
super().__init__()
self.eps = eps
factory_kwargs = {"device": device, "dtype": dtype}
self.hidden_size = hidden_size
if elementwise_affine:
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.bias = (
torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
if bias
else None
)
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
# Lazy cache for ones vector (not a registered buffer to avoid FSDP/meta issues)
self._weight_fallback_cache = None
def _get_weight_fallback(self, x: torch.Tensor) -> torch.Tensor:
wf = getattr(self, "_weight_fallback_cache", None)
if (
wf is None
or wf.device != x.device
or wf.dtype != x.dtype
or wf.numel() != self.hidden_size
):
wf = torch.ones(self.hidden_size, device=x.device, dtype=x.dtype)
self._weight_fallback_cache = wf
return wf
    def forward_triton(self, x: torch.Tensor):
        # Fast inference kernel without residual/dropout branches. When the layer
        # has no affine parameters, fall back to a cached all-ones weight so the
        # kernel always receives a valid weight tensor.
        weight = self.weight if self.weight is not None else self._get_weight_fallback(x)
        return norm_infer(
            x.view(-1, self.hidden_size),
            weight,
            self.bias,
            eps=self.eps,
            is_rms_norm=False,
        ).view(x.shape)
def forward_cuda(
self,
x: torch.Tensor,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
shape = x.shape
x = x.view(-1, self.hidden_size)
return self.forward_triton(x).view(shape)
@torch.compile(backend="inductor")
def forward_native(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
input_dtype = x.dtype
mean = x.mean(-1, keepdim=True)
variance = (x - mean).pow(2).mean(-1, keepdim=True)
x = (x - mean) * torch.rsqrt(variance + self.eps)
if self.weight is not None:
x = self.weight * x
# if no affine, this is a no-op
if self.bias is not None:
x = x + self.bias
return x.to(input_dtype)
def forward_cpu(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
return self.forward_native(x, residual)
    def extra_repr(self) -> str:
        s = f"hidden_size={self.hidden_size}"
        s += f", eps={self.eps}"
        s += f", elementwise_affine={self.weight is not None}"
        return s
class ScaleResidual(nn.Module):
"""
Applies gated residual connection.
"""
def __init__(self, prefix: str = ""):
super().__init__()
def forward(
self, residual: torch.Tensor, x: torch.Tensor, gate: torch.Tensor
) -> torch.Tensor:
"""Apply gated residual connection."""
# x.shape: [batch_size, seq_len, inner_dim]
if gate.dim() == 4:
# gate.shape: [batch_size, num_frames, 1, inner_dim]
num_frames = gate.shape[1]
frame_seqlen = x.shape[1] // num_frames
return residual + (
x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * gate
).flatten(1, 2)
else:
# gate.shape: [batch_size, 1, inner_dim]
return residual + x * gate
# adapted from Diffusers: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/normalization.py
# NOTE(will): Needed to match behavior of diffusers and wan2.1 even while using
# FSDP's MixedPrecisionPolicy
class FP32LayerNorm(nn.LayerNorm):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
return F.layer_norm(
inputs.float(),
self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
).to(origin_dtype)
class ScaleResidualLayerNormScaleShift(nn.Module):
"""
Fused operation that combines:
1. Gated residual connection
2. LayerNorm
3. Scale and shift operations
This reduces memory bandwidth by combining memory-bound operations.
"""
def __init__(
self,
hidden_size: int,
norm_type: str = "rms",
eps: float = 1e-6,
elementwise_affine: bool = False,
dtype: torch.dtype = torch.float32,
compute_dtype: torch.dtype | None = None,
prefix: str = "",
):
super().__init__()
if norm_type == "rms":
self.norm = RMSNorm(
hidden_size, has_weight=elementwise_affine, eps=eps, dtype=dtype
)
elif norm_type == "layer":
if compute_dtype == torch.float32:
self.norm = FP32LayerNorm(
hidden_size, elementwise_affine=elementwise_affine, eps=eps
)
else:
self.norm = LayerNorm(
hidden_size,
elementwise_affine=elementwise_affine,
eps=eps,
dtype=dtype,
)
else:
raise NotImplementedError(f"Norm type {norm_type} not implemented")
def forward(
self,
residual: torch.Tensor,
x: torch.Tensor,
gate: torch.Tensor | int,
shift: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply gated residual connection, followed by layernorm and
scale/shift in a single fused operation.
Returns:
Tuple containing:
- normalized and modulated output of shape: [batch_size, seq_len, inner_dim]
- residual value (value after residual connection
but before normalization)
"""
# x.shape: [batch_size, seq_len, inner_dim]
# Apply residual connection with gating
if isinstance(gate, int):
# used by cross-attention, should be 1
assert gate == 1
residual_output = residual + x
elif isinstance(gate, torch.Tensor):
if gate.dim() == 4:
# gate.shape: [batch_size, num_frames, 1, inner_dim]
num_frames = gate.shape[1]
frame_seqlen = x.shape[1] // num_frames
residual_output = residual + (
x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * gate
).flatten(1, 2)
else:
# used by bidirectional self attention
# gate.shape: [batch_size, 1, inner_dim]
residual_output = residual + x * gate
else:
raise ValueError(f"Gate type {type(gate)} not supported")
# residual_output.shape: [batch_size, seq_len, inner_dim]
# Apply normalization
normalized = self.norm(residual_output)
# modulated = fused_scale_shift(
# normalized,
# scale,
# shift,
# )
modulated = fuse_scale_shift_kernel(
normalized,
scale,
shift,
)
return modulated, residual_output
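# Minimal sketch (illustrative, not part of the original source; the helper name is
# hypothetical): how a 4D gate of shape [B, num_frames, 1, D] is broadcast over a
# flattened [B, num_frames * tokens_per_frame, D] sequence in the gated-residual
# paths above.
def _demo_frame_gate(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    # x: [B, num_frames * tokens_per_frame, D]; gate: [B, num_frames, 1, D]
    num_frames = gate.shape[1]
    tokens_per_frame = x.shape[1] // num_frames
    gated = x.unflatten(dim=1, sizes=(num_frames, tokens_per_frame)) * gate
    return gated.flatten(1, 2)  # back to [B, num_frames * tokens_per_frame, D]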
class LayerNormScaleShift(nn.Module):
"""
Fused operation that combines LayerNorm with scale and shift operations.
This reduces memory bandwidth by combining memory-bound operations.
"""
def __init__(
self,
hidden_size: int,
norm_type: str = "rms",
eps: float = 1e-6,
elementwise_affine: bool = False,
dtype: torch.dtype = torch.float32,
compute_dtype: torch.dtype | None = None,
prefix: str = "",
):
super().__init__()
self.compute_dtype = compute_dtype
if norm_type == "rms":
self.norm = RMSNorm(hidden_size, has_weight=elementwise_affine, eps=eps)
elif norm_type == "layer":
if self.compute_dtype == torch.float32:
self.norm = FP32LayerNorm(
hidden_size, elementwise_affine=elementwise_affine, eps=eps
)
else:
self.norm = nn.LayerNorm(
hidden_size,
elementwise_affine=elementwise_affine,
eps=eps,
dtype=dtype,
)
else:
raise NotImplementedError(f"Norm type {norm_type} not implemented")
def forward(
self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
"""Apply ln followed by scale and shift in a single fused operation."""
# x.shape: [batch_size, seq_len, inner_dim]
normalized = self.norm(x)
if self.compute_dtype == torch.float32:
normalized = normalized.float()
if scale.dim() == 4:
# scale.shape: [batch_size, num_frames, 1, inner_dim]
num_frames = scale.shape[1]
frame_seqlen = normalized.shape[1] // num_frames
output = (
normalized.unflatten(dim=1, sizes=(num_frames, frame_seqlen))
* (1.0 + scale)
+ shift
).flatten(1, 2)
else:
# scale.shape: [batch_size, 1, inner_dim]
# shift.shape: [batch_size, 1, inner_dim]
output = normalized * (1.0 + scale) + shift
if self.compute_dtype == torch.float32:
output = output.to(x.dtype)
return output
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/linear.py
from abc import abstractmethod
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from sglang.multimodal_gen.runtime.distributed import (
divide,
get_tp_rank,
get_tp_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.multimodal_gen.runtime.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
# yapf: disable
from sglang.multimodal_gen.runtime.models.parameter import (
BasevLLMParameter,
BlockQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
RowvLLMParameter,
)
# yapf: enable
from sglang.multimodal_gen.runtime.models.utils import set_weight_attrs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod",
"AWQMarlinLinearMethod",
"AWQLinearMethod",
"GPTQMarlinLinearMethod",
"Fp8LinearMethod",
"MarlinLinearMethod",
"QQQLinearMethod",
"GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod",
"GPTQLinearMethod",
"FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod",
"IPEXAWQLinearMethod",
"IPEXGPTQLinearMethod",
"HQQMarlinMethod",
"QuarkLinearMethod",
]
def adjust_scalar_to_fused_array(
param: torch.Tensor, loaded_weight: torch.Tensor, shard_id: str | int
) -> tuple[torch.Tensor, torch.Tensor]:
"""For fused modules (QKV and MLP) we have an array of length
N that holds 1 scale for each "logical" matrix. So the param
is an array of length N. The loaded_weight corresponds to
one of the shards on disk. Here, we slice the param based on
the shard_id for loading.
"""
qkv_idxs = {"q": 0, "k": 1, "v": 2}
if isinstance(shard_id, str):
shard_id = qkv_idxs[shard_id]
elif not isinstance(shard_id, int):
raise ValueError(f"Unknown Shard Id {shard_id}")
# AutoFP8 scales do not have a shape
# compressed-tensors scales do have a shape
if len(loaded_weight.shape) != 0:
assert loaded_weight.shape[0] == 1
loaded_weight = loaded_weight[0]
return param[shard_id], loaded_weight
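# Illustrative sketch (not part of the original source): loading the "k" scale of a
# fused QKV scale array. shard_id "k" maps to index 1, and a shaped scalar loaded
# from disk is squeezed to a 0-d tensor before being copied into that slot.
#
#   param = torch.zeros(3)
#   loaded = torch.tensor([1.5])  # compressed-tensors style scale with shape (1,)
#   slot, value = adjust_scalar_to_fused_array(param, loaded, "k")
#   slot.copy_(value)  # param is now tensor([0.0, 1.5, 0.0])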
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@abstractmethod
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
"""Create weights for a linear layer.
The weights will be set as attributes of the layer.
Args:
layer: The layer that is using the LinearMethodBase factory.
input_size_per_partition: Size of the weight input dim on rank X.
output_partition_sizes: Sizes of the output dim of each logical
weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
input_size: Size of the input dim of the weight across all ranks.
output_size: Size of the output dim of the weight across all ranks.
params_dtype: Datatype of the parameters.
"""
raise NotImplementedError
@abstractmethod
def apply(
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization."""
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
weight = Parameter(
torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def apply(
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
        # NOTE: on CUDA we rely on autocast to reconcile dtypes; on backends
        # without autocast support (e.g. MPS) the bias is explicitly cast to the
        # input dtype.
        output = (
            F.linear(x, layer.weight, bias)
            if torch.cuda.is_available() or bias is None
            else F.linear(x, layer.weight, bias.to(x.dtype))
        )
return output
class LinearBase(torch.nn.Module):
"""Base linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
"""
def __init__(
self,
input_size: int,
output_size: int,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.skip_bias_add = skip_bias_add
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.quant_config = quant_config
self.prefix = prefix
if quant_config is None:
self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, Parameter | None]:
raise NotImplementedError
class ReplicatedLinear(LinearBase):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__(
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
)
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(
self,
self.input_size,
[self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader,
)
if bias:
self.bias = Parameter(
torch.empty(
self.output_size,
dtype=self.params_dtype,
)
)
set_weight_attrs(
self.bias,
{
"output_dim": 0,
"weight_loader": self.weight_loader,
},
)
else:
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor) -> None:
# If the weight on disk does not have a shape, give it one
# (such scales for AutoFp8).
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param.size() == loaded_weight.size(), (
f"Tried to load weights of size {loaded_weight.size()}"
f"to a parameter of size {param.size()}"
)
param.data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
return s
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Args:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias.
gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
output_sizes: list[int] | None = None,
prefix: str = "",
):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tp_world_size()
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size) for output_size in self.output_sizes
]
super().__init__(
input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
)
self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
else self.weight_loader
),
)
if bias:
self.bias = Parameter(
torch.empty(
self.output_size_per_partition,
dtype=params_dtype,
)
)
set_weight_attrs(
self.bias,
{
"output_dim": 0,
"weight_loader": self.weight_loader,
},
)
else:
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor) -> None:
tp_rank = get_tp_rank()
output_dim = getattr(param, "output_dim", None)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
param_data = param.data
if output_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[output_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor) -> None:
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_column_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size_per_partition}"
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tp_world_size()}"
s += f", gather_output={self.gather_output}"
return s
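# Illustrative sketch (not part of the original source): with tensor-parallel size 4,
# a ColumnParallelLinear(input_size=1024, output_size=4096) materializes a per-rank
# weight shard of shape [4096 // 4, 1024] = [1024, 1024], and weight_loader narrows
# the full checkpoint tensor along the output dim before copying:
#
#   shard_size = param.data.shape[output_dim]                # 1024 on each rank
#   start_idx = tp_rank * shard_size
#   shard = loaded_weight.narrow(output_dim, start_idx, shard_size)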
class MergedColumnParallelLinear(ColumnParallelLinear):
"""Packed linear layers with column parallelism.
Similar to ColumnParallelLinear, but the weight matrix is concatenated
along the output dimension. When the weight matrix is loaded, the
different partitions are sharded separately.
Args:
input_size: input dimension of the linear layer.
output_sizes: list of output dimensions of the linear layer.
bias: If true, add bias.
gather_output: If true, call all-gather on output and make the output
available to all GPUs, otherwise, every GPU will have
its own output.
skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
self.output_sizes = output_sizes
tp_size = get_tp_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(
input_size=input_size,
output_size=sum(output_sizes),
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
)
def weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: int | None = None,
) -> None:
param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
# Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None:
# Loaded weight is already fused on disk (mlp).
# (e.g., Phi-3's gate_up_proj).
if output_dim is None:
if needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, 0
)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
current_shard_offset = 0
shard_offsets: list[tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
return
assert loaded_shard_id < len(self.output_sizes)
tp_rank = get_tp_rank()
tp_size = get_tp_world_size()
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # bitsandbytes loads the weights of the specific portion already;
            # no need to narrow such pre-sharded weights.
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
start_idx = tp_rank * shard_size
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)
# Special case for per-tensor scales in fused case.
elif needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, loaded_shard_id
)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
logger.warning(
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions."
)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def _load_fused_module_from_checkpoint(
self, param: BasevLLMParameter, loaded_weight: torch.Tensor
) -> None:
"""
Handle special case for models where MLP layers are already
fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
the weight loader using the shard id.
An example of a model with these fused layers:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
"""
current_shard_offset = 0
shard_offsets: list[tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if (
isinstance(param, PackedColumnParameter | PackedvLLMParameter)
and param.packed_dim == param.output_dim
):
shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset
)
loaded_weight_shard = loaded_weight.narrow(
param.output_dim, shard_offset, shard_size
)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
def weight_loader_v2(
self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: int | None = None,
) -> None:
if loaded_shard_id is None:
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
assert loaded_shard_id < len(self.output_sizes)
tp_size = get_tp_world_size()
if isinstance(param, BlockQuantScaleParameter):
raise NotImplementedError("FP8 is not implemented yet")
# FIXME(will): add fp8 support
# from vllm.model_executor.layers.quantization.fp8 import (
# Fp8LinearMethod, Fp8MoEMethod)
# assert self.quant_method is not None
# assert isinstance(self.quant_method,
# (Fp8LinearMethod, Fp8MoEMethod))
# weight_block_size = self.quant_method.quant_config.weight_block_size
# assert weight_block_size is not None
# block_n, _ = weight_block_size[0], weight_block_size[1]
# shard_offset = (
# (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
# block_n) // tp_size
# shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
# block_n // tp_size)
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
param.load_merged_column_weight(
loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
)
class QKVParallelLinear(ColumnParallelLinear):
"""Linear layers for the attention's QKV transformation.
Linear layers for the linear transformation of the query, key, and value
vectors in the attention layer. The weight matrix is concatenated along
the output dimension. The layer is parallelized along the head dimension.
When the number of key/value heads is smaller than the number of query
heads (e.g., multi-query/grouped-query attention), the key/value head may
be replicated while the query heads are partitioned.
Args:
hidden_size: input hidden state size of the transformer.
head_size: size of each attention head.
total_num_heads: total number of attention query heads.
total_num_kv_heads: total number of attention key/value heads. If
None, assume total_num_kv_heads = total_num_heads.
bias: If true, add bias.
skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: int | None = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size = get_tp_world_size()
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
else:
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
self.num_kv_head_replicas = 1
input_size = self.hidden_size
output_size = (
(self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
)
self.output_sizes = [
self.num_heads * self.head_size * tp_size, # q_proj
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj
]
super().__init__(
input_size=input_size,
output_size=output_size,
bias=bias,
gather_output=False,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
)
def _get_shard_offset_mapping(self, loaded_shard_id: str) -> int | None:
shard_offset_mapping = {
"q": 0,
"k": self.num_heads * self.head_size,
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
"total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
}
return shard_offset_mapping.get(loaded_shard_id)
def _get_shard_size_mapping(self, loaded_shard_id: str) -> int | None:
shard_size_mapping = {
"q": self.num_heads * self.head_size,
"k": self.num_kv_heads * self.head_size,
"v": self.num_kv_heads * self.head_size,
}
return shard_size_mapping.get(loaded_shard_id)
def _load_fused_module_from_checkpoint(
self, param: BasevLLMParameter, loaded_weight: torch.Tensor
):
"""
Handle special case for models where QKV layers are already
fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
the weight loader using the shard id.
An example of a model with these fused layers:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
"""
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
(
"k",
self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size,
),
(
"v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.head_size,
),
]
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if (
isinstance(param, PackedColumnParameter | PackedvLLMParameter)
and param.packed_dim == param.output_dim
):
shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset
)
loaded_weight_shard = loaded_weight.narrow(
param.output_dim, shard_offset, shard_size
)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
def weight_loader_v2(
self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: str | None = None,
):
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_qkv_weight(loaded_weight=loaded_weight)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
assert loaded_shard_id in ["q", "k", "v"]
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)
param.load_qkv_weight(
loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
)
def weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: str | None = None,
):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
# Special case for per-tensor scales in fused case.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv).
# (e.g., Phi-3's qkv_proj).
if output_dim is None:
if needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, 0
)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
(
"k",
self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size,
),
(
"v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.head_size,
),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
return
tp_rank = get_tp_rank()
assert loaded_shard_id in ["q", "k", "v"]
# If output dim is defined, use the default loading process.
if output_dim is not None:
if loaded_shard_id == "q":
shard_offset = 0
shard_size = self.num_heads * self.head_size
elif loaded_shard_id == "k":
shard_offset = self.num_heads * self.head_size
shard_size = self.num_kv_heads * self.head_size
elif loaded_shard_id == "v":
shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.head_size
            # bitsandbytes loads only this rank's portion of the weight,
            # so there is no need to narrow it again.
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
shard_idx = 0
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
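            # Pick which slice of the checkpoint this rank reads: q shards map
            # one-to-one to TP ranks, while ranks that replicate the same kv
            # head (num_kv_head_replicas > 1) all read the same k/v slice.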
if loaded_shard_id == "q":
shard_idx = tp_rank
else:
shard_idx = tp_rank // self.num_kv_head_replicas
start_idx = shard_idx * shard_size
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        # Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size, shard_size)
# Special case for per-tensor scales in fused case.
elif needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, loaded_shard_id
)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
logger.warning(
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions."
)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
- -
| A_1 |
| . |
A = | . | X = [X_1, ..., X_p]
| . |
| A_p |
- -
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
skip_bias_add: This was added to enable performance optimization where
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
reduce_results: bool = True,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tp_rank()
self.tp_size = get_tp_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
super().__init__(
input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
else self.weight_loader
),
)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError(
"When not reduce the results, adding bias to the "
"results can lead to incorrect results"
)
if bias:
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(
self.bias,
{
"output_dim": 0,
"weight_loader": self.weight_loader,
},
)
else:
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tp_rank()
input_dim = getattr(param, "input_dim", None)
        # bitsandbytes loads only this rank's portion of the weight,
        # so there is no need to narrow it again.
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
param_data = param.data
if input_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_) -> tuple[torch.Tensor, Parameter | None]:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tp_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size
)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:
s = f"input_features={self.input_size_per_partition}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
s += f", tp_size={self.tp_size}"
s += f", reduce_results={self.reduce_results}"
return s
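# Usage sketch (illustrative; the MLP module, sizes, and activation choice are
# hypothetical and not part of this file):
#
#   class ParallelMLP(nn.Module):
#       def __init__(self, hidden: int, intermediate: int):
#           super().__init__()
#           # Column-parallel up projection keeps its output sharded ...
#           self.up = ColumnParallelLinear(hidden, intermediate, bias=False,
#                                          gather_output=False)
#           # ... so the row-parallel down projection can consume it directly
#           # and all-reduce the final result across TP ranks.
#           self.down = RowParallelLinear(intermediate, hidden, bias=False,
#                                         input_is_parallel=True)
#
#       def forward(self, x):
#           y, _ = self.up(x)
#           out, _ = self.down(torch.nn.functional.silu(y))
#           return out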
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Code adapted from SGLang https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/lora/layers.py
import math
import torch
from torch import nn
from torch.distributed._composable.fsdp import (
CPUOffloadPolicy,
OffloadPolicy,
fully_shard,
)
from torch.distributed.tensor import DTensor
from sglang.multimodal_gen.runtime.distributed import (
get_local_torch_device,
get_tp_rank,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.multimodal_gen.runtime.layers.linear import (
ColumnParallelLinear,
LinearBase,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.multimodal_gen.utils import get_mixed_precision_state
torch._dynamo.config.recompile_limit = 16
class BaseLayerWithLoRA(nn.Module):
def __init__(
self,
base_layer: nn.Module,
lora_rank: int | None = None,
lora_alpha: int | None = None,
training_mode: bool = False,
):
super().__init__()
self.base_layer: nn.Module = base_layer
self.merged: bool = False
self.cpu_weight = base_layer.weight.to("cpu")
        # Indicates that the adapter weights do not contain this layer
        # (this should not normally happen, but we keep it distinct from the
        # case of an erroneous merge).
self.disable_lora: bool = False
self.lora_rank = lora_rank
self.lora_alpha = lora_alpha
self.training_mode = training_mode
self.lora_path: str | None = None
if training_mode:
assert (
self.lora_rank is not None
), "LoRA rank must be set for training mode"
            if self.lora_alpha is None:
                # Default alpha to the rank so the LoRA scale is 1.0.
                self.lora_alpha = self.lora_rank
self.base_layer.requires_grad_(False)
in_dim = self.base_layer.weight.shape[1]
out_dim = self.base_layer.weight.shape[0]
self.lora_A = nn.Parameter(
torch.zeros(
self.lora_rank,
in_dim,
device=self.base_layer.weight.device,
dtype=self.base_layer.weight.dtype,
)
)
self.lora_B = nn.Parameter(
torch.zeros(
out_dim,
self.lora_rank,
device=self.base_layer.weight.device,
dtype=self.base_layer.weight.dtype,
)
)
torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_B)
else:
self.lora_A = None
self.lora_B = None
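    # When the adapter is not merged, forward() below adds the low-rank update
    # on the fly: y = base(x) + (lora_alpha / lora_rank) * x @ A^T @ B^T.
    # Once merge_lora_weights() has folded B @ A into the base weight, the
    # extra matmuls are skipped and only the base layer runs.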
@torch.compile()
def forward(self, x: torch.Tensor) -> torch.Tensor:
lora_A = self.lora_A
lora_B = self.lora_B
if isinstance(self.lora_B, DTensor):
lora_B = self.lora_B.to_local()
lora_A = self.lora_A.to_local()
if not self.merged and not self.disable_lora:
lora_A_sliced = self.slice_lora_a_weights(lora_A.to(x, non_blocking=True))
lora_B_sliced = self.slice_lora_b_weights(lora_B.to(x, non_blocking=True))
delta = x @ lora_A_sliced.T @ lora_B_sliced.T
if self.lora_alpha != self.lora_rank:
                delta = delta * (self.lora_alpha / self.lora_rank)  # type: ignore
out, output_bias = self.base_layer(x)
return out + delta, output_bias
else:
out, output_bias = self.base_layer(x)
return out.to(x), output_bias
def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor:
return A
def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor:
return B
def set_lora_weights(
self,
A: torch.Tensor,
B: torch.Tensor,
training_mode: bool = False,
lora_path: str | None = None,
) -> None:
self.lora_A = torch.nn.Parameter(
A
) # share storage with weights in the pipeline
self.lora_B = torch.nn.Parameter(B)
self.disable_lora = False
if not training_mode:
self.merge_lora_weights()
self.lora_path = lora_path
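    # merge_lora_weights() folds the adapter into the base weight in place
    # (W <- W + slice(B) @ slice(A)). For FSDP-sharded (DTensor) weights the
    # full tensor is gathered on the local device, merged, and re-wrapped with
    # fully_shard; a pristine CPU copy of the weight is kept so the merge can
    # be undone without accumulating rounding error.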
@torch.no_grad()
def merge_lora_weights(self) -> None:
if self.disable_lora:
return
if self.merged:
self.unmerge_lora_weights()
assert (
self.lora_A is not None and self.lora_B is not None
), "LoRA weights not set. Please set them first."
if isinstance(self.base_layer.weight, DTensor):
mesh = self.base_layer.weight.data.device_mesh
unsharded_base_layer = ReplicatedLinear(
input_size=self.base_layer.input_size,
output_size=self.base_layer.output_size,
bias=getattr(self.base_layer, "bias", None) is not None,
skip_bias_add=self.base_layer.skip_bias_add,
params_dtype=self.base_layer.params_dtype,
quant_config=self.base_layer.quant_config,
prefix=self.base_layer.prefix,
)
            # With CPU offload the parameter lives on CPU; current_device records
            # where to put the merged weight back (CPU -> GPU -> merge -> CPU).
current_device = self.base_layer.weight.data.device
data = self.base_layer.weight.data.to(
get_local_torch_device()
).full_tensor()
data += self.slice_lora_b_weights(self.lora_B).to(
data
) @ self.slice_lora_a_weights(self.lora_A).to(data)
unsharded_base_layer.weight = nn.Parameter(data.to(current_device))
if isinstance(getattr(self.base_layer, "bias", None), DTensor):
unsharded_base_layer.bias = nn.Parameter(
self.base_layer.bias.to(get_local_torch_device(), non_blocking=True)
.full_tensor()
.to(current_device)
)
offload_policy = (
CPUOffloadPolicy() if "cpu" in str(current_device) else OffloadPolicy()
)
mp_policy = get_mixed_precision_state().mp_policy
self.base_layer = fully_shard(
unsharded_base_layer,
mesh=mesh,
mp_policy=mp_policy,
offload_policy=offload_policy,
)
else:
current_device = self.base_layer.weight.data.device
data = self.base_layer.weight.data.to(get_local_torch_device())
data += self.slice_lora_b_weights(
self.lora_B.to(data)
) @ self.slice_lora_a_weights(self.lora_A.to(data))
self.base_layer.weight.data = data.to(current_device, non_blocking=True)
self.merged = True
@torch.no_grad()
# @torch.compile(dynamic=True)
def unmerge_lora_weights(self) -> None:
if self.disable_lora:
return
if not self.merged:
raise ValueError(
"LoRA weights not merged. Please merge them first before unmerging."
)
        # Restore the pristine CPU copy instead of subtracting B @ A, to avoid
        # accumulating precision loss across repeated merge/unmerge cycles.
if isinstance(self.base_layer.weight, DTensor):
device = self.base_layer.weight.data.device
self.base_layer.weight = nn.Parameter(
self.cpu_weight.to(device, non_blocking=True)
)
else:
self.base_layer.weight.data = self.cpu_weight.data.to(
self.base_layer.weight, non_blocking=True
)
self.merged = False
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
"""
Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation).
    Note: LoRA is not implemented for this layer yet; forward() currently
    raises NotImplementedError. Future versions will integrate LoRA support
    for efficient parameter fine-tuning.
"""
def __init__(
self,
base_layer: VocabParallelEmbedding,
) -> None:
super().__init__(base_layer)
def forward(self, input_: torch.Tensor) -> torch.Tensor:
raise NotImplementedError(
"We don't support VocabParallelEmbeddingWithLoRA yet."
)
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self,
base_layer: ColumnParallelLinear,
lora_rank: int | None = None,
lora_alpha: int | None = None,
training_mode: bool = False,
) -> None:
super().__init__(base_layer, lora_rank, lora_alpha, training_mode)
def forward(self, input_: torch.Tensor) -> torch.Tensor:
# duplicate the logic in ColumnParallelLinear
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
output_parallel = self.base_layer.quant_method.apply(
self.base_layer, input_, bias
)
if self.base_layer.gather_output:
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
return output, output_bias
def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor:
return A
def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor:
tp_rank = get_tp_rank()
shard_size = self.base_layer.output_partition_sizes[0]
start_idx = tp_rank * shard_size
end_idx = (tp_rank + 1) * shard_size
B = B[start_idx:end_idx, :]
return B
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self,
base_layer: MergedColumnParallelLinear,
lora_rank: int | None = None,
lora_alpha: int | None = None,
training_mode: bool = False,
) -> None:
super().__init__(base_layer, lora_rank, lora_alpha, training_mode)
def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor:
return A.to(self.base_layer.weight)
def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor:
tp_rank = get_tp_rank()
        # The gate and up projections have identical output partition sizes,
        # so the first entry is representative for both.
shard_size = self.base_layer.output_partition_sizes[0]
start_idx = tp_rank * shard_size
end_idx = (tp_rank + 1) * shard_size
return B[:, start_idx:end_idx, :]
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self,
base_layer: QKVParallelLinear,
lora_rank: int | None = None,
lora_alpha: int | None = None,
training_mode: bool = False,
) -> None:
super().__init__(base_layer, lora_rank, lora_alpha, training_mode)
def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor:
return A
def slice_lora_b_weights(
self, B: list[torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor]:
tp_rank = get_tp_rank()
B_q, B_kv = B
base_layer = self.base_layer
q_proj_shard_size = base_layer.q_proj_shard_size
kv_proj_shard_size = base_layer.kv_proj_shard_size
num_kv_head_replicas = base_layer.num_kv_head_replicas
q_start_idx = q_proj_shard_size * tp_rank
q_end_idx = q_start_idx + q_proj_shard_size
kv_shard_id = tp_rank // num_kv_head_replicas
kv_start_idx = kv_proj_shard_size * kv_shard_id
kv_end_idx = kv_start_idx + kv_proj_shard_size
return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :]
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self,
base_layer: RowParallelLinear,
lora_rank: int | None = None,
lora_alpha: int | None = None,
training_mode: bool = False,
) -> None:
super().__init__(base_layer, lora_rank, lora_alpha, training_mode)
def forward(self, input_: torch.Tensor):
# duplicate the logic in RowParallelLinear
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tp_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size
)
input_parallel = splitted_input[tp_rank].contiguous()
output_parallel = self.base_layer.quant_method.apply(
self.base_layer, input_parallel
)
        # `set_lora`/`apply_lora` carry over from the srt LoRA layers and are
        # not defined here, so guard the lookup to avoid an AttributeError.
        if getattr(self, "set_lora", False):
            output_parallel = self.apply_lora(output_parallel, input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.base_layer.skip_bias_add:
output = (
output_ + self.base_layer.bias
if self.base_layer.bias is not None
else output_
)
output_bias = None
else:
output = output_
output_bias = self.base_layer.bias
return output, output_bias
def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor:
tp_rank = get_tp_rank()
shard_size = self.base_layer.input_size_per_partition
start_idx = tp_rank * shard_size
end_idx = (tp_rank + 1) * shard_size
A = A[:, start_idx:end_idx].contiguous()
return A
def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor:
return B
def get_lora_layer(
layer: nn.Module,
lora_rank: int | None = None,
lora_alpha: int | None = None,
training_mode: bool = False,
) -> BaseLayerWithLoRA | None:
supported_layer_types: dict[type[LinearBase], type[BaseLayerWithLoRA]] = {
# the order matters
# VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
QKVParallelLinear: QKVParallelLinearWithLoRA,
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
RowParallelLinear: RowParallelLinearWithLoRA,
ReplicatedLinear: BaseLayerWithLoRA,
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
        if isinstance(layer, src_layer_type):
ret = lora_layer_type(
layer,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
training_mode=training_mode,
)
return ret
return None
# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9
def replace_submodule(
model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
"""Replace a submodule in a model with a new module."""
parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
target_name = module_name.split(".")[-1]
setattr(parent, target_name, new_module)
return new_module
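# Usage sketch (illustrative; the model object and hyperparameters are
# hypothetical, not part of this file):
#
#   wrapped = []
#   for name, module in model.named_modules():
#       lora_layer = get_lora_layer(module, lora_rank=16, lora_alpha=32,
#                                   training_mode=True)
#       if lora_layer is not None:
#           wrapped.append((name, lora_layer))
#   for name, lora_layer in wrapped:
#       replace_submodule(model, name, lora_layer)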