Commit c07946d8 authored by hepj

dit & video

# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/envs.py
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
if TYPE_CHECKING:
FASTVIDEO_RINGBUFFER_WARNING_INTERVAL: int = 60
FASTVIDEO_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
FASTVIDEO_CACHE_ROOT: str = os.path.expanduser("~/.cache/fastvideo")
FASTVIDEO_CONFIG_ROOT: str = os.path.expanduser("~/.config/fastvideo")
FASTVIDEO_CONFIGURE_LOGGING: int = 1
FASTVIDEO_LOGGING_LEVEL: str = "INFO"
FASTVIDEO_LOGGING_PREFIX: str = ""
FASTVIDEO_LOGGING_CONFIG_PATH: Optional[str] = None
FASTVIDEO_TRACE_FUNCTION: int = 0
FASTVIDEO_ATTENTION_BACKEND: Optional[str] = None
FASTVIDEO_ATTENTION_CONFIG: Optional[str] = None
FASTVIDEO_WORKER_MULTIPROC_METHOD: str = "fork"
FASTVIDEO_TARGET_DEVICE: str = "cuda"
MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None
CMAKE_BUILD_TYPE: Optional[str] = None
VERBOSE: bool = False
FASTVIDEO_SERVER_DEV_MODE: bool = False
def get_default_cache_root() -> str:
return os.getenv(
"XDG_CACHE_HOME",
os.path.join(os.path.expanduser("~"), ".cache"),
)
def get_default_config_root() -> str:
return os.getenv(
"XDG_CONFIG_HOME",
os.path.join(os.path.expanduser("~"), ".config"),
)
def maybe_convert_int(value: Optional[str]) -> Optional[int]:
if value is None:
return None
return int(value)
# The begin-* and end-* markers here are used by the documentation generator
# to extract the used env vars.
# begin-env-vars-definition
environment_variables: Dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ==================
# Target device of FastVideo, supporting [cuda (by default),
# rocm, neuron, cpu, openvino]
"FASTVIDEO_TARGET_DEVICE":
lambda: os.getenv("FASTVIDEO_TARGET_DEVICE", "cuda"),
# Maximum number of compilation jobs to run in parallel.
# By default this is the number of CPUs
"MAX_JOBS":
lambda: os.getenv("MAX_JOBS", None),
# Number of threads to use for nvcc
# By default this is 1.
# If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU.
"NVCC_THREADS":
lambda: os.getenv("NVCC_THREADS", None),
# If set, fastvideo will use precompiled binaries (*.so)
"FASTVIDEO_USE_PRECOMPILED":
lambda: bool(os.environ.get("FASTVIDEO_USE_PRECOMPILED")) or bool(
os.environ.get("FASTVIDEO_PRECOMPILED_WHEEL_LOCATION")),
# CMake build type
# If not set, defaults to "Debug" or "RelWithDebInfo"
# Available options: "Debug", "Release", "RelWithDebInfo"
"CMAKE_BUILD_TYPE":
lambda: os.getenv("CMAKE_BUILD_TYPE"),
# If set, fastvideo will print verbose logs during installation
"VERBOSE":
lambda: bool(int(os.getenv('VERBOSE', '0'))),
# Root directory for FASTVIDEO configuration files
# Defaults to `~/.config/fastvideo` unless `XDG_CONFIG_HOME` is set
# Note that this not only affects how fastvideo finds its configuration files
# during runtime, but also affects how fastvideo installs its configuration
# files during **installation**.
"FASTVIDEO_CONFIG_ROOT":
lambda: os.path.expanduser(
os.getenv(
"FASTVIDEO_CONFIG_ROOT",
os.path.join(get_default_config_root(), "fastvideo"),
)),
# ================== Runtime Env Vars ==================
# Root directory for FASTVIDEO cache files
# Defaults to `~/.cache/fastvideo` unless `XDG_CACHE_HOME` is set
"FASTVIDEO_CACHE_ROOT":
lambda: os.path.expanduser(
os.getenv(
"FASTVIDEO_CACHE_ROOT",
os.path.join(get_default_cache_root(), "fastvideo"),
)),
# Interval in seconds to log a warning message when the ring buffer is full
"FASTVIDEO_RINGBUFFER_WARNING_INTERVAL":
lambda: int(os.environ.get("FASTVIDEO_RINGBUFFER_WARNING_INTERVAL", "60")),
# Path to the NCCL library file. It is needed because nccl>=2.19 brought
# by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234
"FASTVIDEO_NCCL_SO_PATH":
lambda: os.environ.get("FASTVIDEO_NCCL_SO_PATH", None),
# when `FASTVIDEO_NCCL_SO_PATH` is not set, fastvideo will try to find the nccl
# library file in the locations specified by `LD_LIBRARY_PATH`
"LD_LIBRARY_PATH":
lambda: os.environ.get("LD_LIBRARY_PATH", None),
# Internal flag to enable Dynamo fullgraph capture
"FASTVIDEO_TEST_DYNAMO_FULLGRAPH_CAPTURE":
lambda: bool(
os.environ.get("FASTVIDEO_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK":
lambda: int(os.environ.get("LOCAL_RANK", "0")),
# used to control the visible devices in the distributed setting
"CUDA_VISIBLE_DEVICES":
lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None),
# timeout for each iteration in the engine
"FASTVIDEO_ENGINE_ITERATION_TIMEOUT_S":
lambda: int(os.environ.get("FASTVIDEO_ENGINE_ITERATION_TIMEOUT_S", "60")),
# Logging configuration
# If set to 0, fastvideo will not configure logging
# If set to 1, fastvideo will configure logging using the default configuration
# or the configuration file specified by FASTVIDEO_LOGGING_CONFIG_PATH
"FASTVIDEO_CONFIGURE_LOGGING":
lambda: int(os.getenv("FASTVIDEO_CONFIGURE_LOGGING", "1")),
"FASTVIDEO_LOGGING_CONFIG_PATH":
lambda: os.getenv("FASTVIDEO_LOGGING_CONFIG_PATH"),
# this is used for configuring the default logging level
"FASTVIDEO_LOGGING_LEVEL":
lambda: os.getenv("FASTVIDEO_LOGGING_LEVEL", "INFO"),
# if set, FASTVIDEO_LOGGING_PREFIX will be prepended to all log messages
"FASTVIDEO_LOGGING_PREFIX":
lambda: os.getenv("FASTVIDEO_LOGGING_PREFIX", ""),
# Trace function calls
# If set to 1, fastvideo will trace function calls
# Useful for debugging
"FASTVIDEO_TRACE_FUNCTION":
lambda: int(os.getenv("FASTVIDEO_TRACE_FUNCTION", "0")),
# Backend for attention computation
# Available options:
# - "TORCH_SDPA": use torch.nn.MultiheadAttention
# - "FLASH_ATTN": use FlashAttention
# - "SLIDING_TILE_ATTN" : use Sliding Tile Attention
# - "SAGE_ATTN": use Sage Attention
"FASTVIDEO_ATTENTION_BACKEND":
lambda: os.getenv("FASTVIDEO_ATTENTION_BACKEND", None),
# Path to the attention configuration file. Only used for sliding tile
# attention for now.
"FASTVIDEO_ATTENTION_CONFIG":
lambda: (None if os.getenv("FASTVIDEO_ATTENTION_CONFIG", None) is None else
os.path.expanduser(os.getenv("FASTVIDEO_ATTENTION_CONFIG", "."))),
# Use dedicated multiprocess context for workers.
# Both spawn and fork work
"FASTVIDEO_WORKER_MULTIPROC_METHOD":
lambda: os.getenv("FASTVIDEO_WORKER_MULTIPROC_METHOD", "fork"),
# Enables torch profiler if set. Path to the directory where torch profiler
# traces are saved. Note that it must be an absolute path.
"FASTVIDEO_TORCH_PROFILER_DIR":
lambda: (None
if os.getenv("FASTVIDEO_TORCH_PROFILER_DIR", None) is None else os.
path.expanduser(os.getenv("FASTVIDEO_TORCH_PROFILER_DIR", "."))),
# If set, fastvideo will run in development mode, which will enable
# some additional endpoints for developing and debugging,
# e.g. `/reset_prefix_cache`
"FASTVIDEO_SERVER_DEV_MODE":
lambda: bool(int(os.getenv("FASTVIDEO_SERVER_DEV_MODE", "0"))),
}
# end-env-vars-definition
def __getattr__(name: str):
# lazy evaluation of environment variables
if name in environment_variables:
return environment_variables[name]()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__():
return list(environment_variables.keys())
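# Usage sketch (illustrative only, never called): because of the module-level
# __getattr__ above, environment variables are read lazily at attribute access
# time, so changes to os.environ after import are picked up. The import path
# `fastvideo.v1.envs` is an assumption about where this module lives.
def _example_lazy_env_access() -> None:
    import fastvideo.v1.envs as envs  # assumed module path
    os.environ["FASTVIDEO_LOGGING_LEVEL"] = "DEBUG"
    assert envs.FASTVIDEO_LOGGING_LEVEL == "DEBUG"  # resolved on access
    # Accessing an unknown name raises AttributeError via __getattr__ above.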
# Basic Video Generation Tutorial
The `VideoGenerator` class provides the primary Python interface for offline video generation, i.e. interacting with a diffusion pipeline directly without a separate inference API server.
## Usage
The first script in this example shows the most basic usage of FastVideo. If you are new to Python and FastVideo, you should start here.
```bash
python fastvideo/v1/examples/inference/basic/basic.py
```
## Basic Walkthrough
Generating videos on multiple GPUs with state-of-the-art diffusion pipelines takes only the following few lines:
```python
from fastvideo import VideoGenerator
generator = VideoGenerator.from_pretrained(
"FastVideo/FastHunyuan-Diffusers",
num_gpus=2,
)
prompt = "A beautiful woman in a red dress walking down a street"
video = generator.generate_video(prompt)
```
More to come! These examples and APIs are still under construction!
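The example script also references a `SamplingParam` helper (commented out in `basic.py`) for overriding per-generation settings. A minimal sketch of how it could be used, assuming the import path and fields shown in those comments:

```python
from fastvideo import VideoGenerator
# Import path and fields taken from the commented-out lines in basic.py;
# treat them as illustrative rather than a stable API.
from fastvideo.v1.configs.sample import SamplingParam

generator = VideoGenerator.from_pretrained(
    "FastVideo/FastHunyuan-diffusers",
    num_gpus=2,
)

sampling_param = SamplingParam.from_pretrained("FastVideo/FastHunyuan-diffusers")
sampling_param.num_frames = 45

video = generator.generate_video(
    "A beautiful woman in a red dress walking down a street",
    sampling_param=sampling_param,
    output_path="outputs/",
)
```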
from fastvideo import VideoGenerator
# from fastvideo.v1.configs.sample import SamplingParam
def main():
# FastVideo will automatically use the optimal default arguments for the
# model.
# If a local path is provided, FastVideo will make a best effort
# attempt to identify the optimal arguments.
generator = VideoGenerator.from_pretrained(
"FastVideo/FastHunyuan-diffusers",
# if num_gpus > 1, FastVideo will automatically handle distributed setup
num_gpus=4,
)
# sampling_param = SamplingParam.from_pretrained("/workspace/data/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers")
# sampling_param.num_frames = 45
# sampling_param.image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
# Generate videos with the same simple API, regardless of GPU count
prompt = "A beautiful woman in a red dress walking down a street"
video = generator.generate_video(prompt)
# video = generator.generate_video(prompt, sampling_param=sampling_param, output_path="wan_t2v_videos/")
# Generate another video with a different prompt, without reloading the
# model!
prompt2 = "A beautiful woman in a blue dress walking down a street"
video2 = generator.generate_video(prompt2)
if __name__ == "__main__":
main()
# FastVideo VideoGenerator Gradio Demo
This is a Gradio-based web interface for generating videos using the FastVideo framework. The demo allows users to create videos from text prompts with various customization options.
## Overview
The demo uses the FastVideo framework to generate videos based on text prompts. It provides a simple web interface built with Gradio that allows users to:
- Enter text prompts to generate videos
- Customize video parameters (dimensions, number of frames, etc.)
- Use negative prompts to guide the generation process
- Set or randomize seeds for reproducibility
---
## Requirements
- Linux-based OS
- Python 3.10
- CUDA 12.4
- FastVideo
## Installation
```bash
pip install fastvideo
```
## Usage
Run the demo with:
```bash
python fastvideo/v1/examples/inference/gradio/gradio_demo.py
```
This will start a web server at `http://0.0.0.0:7860` where you can access the interface.
---
## Model Initialization
```python
args = FastVideoArgs(model_path="FastVideo/FastHunyuan-Diffusers", num_gpus=2)
generator = VideoGenerator.from_pretrained(
model_path=args.model_path,
num_gpus=args.num_gpus
)
```
This demo initializes a `VideoGenerator` with the minimum required arguments for inference. Users can seamlessly adjust inference options between generations, including prompts, resolution, video length, or even the number of inference steps, *without ever needing to reload the model*.
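Each request copies the model's default sampling parameters and overrides only what the user changed, so consecutive generations can use different settings against the same loaded model. A condensed sketch of that pattern (the full version is in the demo script below):

```python
from copy import deepcopy
from fastvideo.v1.configs.sample.base import SamplingParam

# Defaults are resolved once from the model path used above.
default_params = SamplingParam.from_pretrained(args.model_path)

params = deepcopy(default_params)
params.height, params.width = 512, 512
params.num_inference_steps = 30

# The already-loaded generator is reused; nothing is reloaded between calls.
generator.generate_video(prompt="A vintage train in the mountains", sampling_param=params)
```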
## Video Generation
The core functionality is in the `generate_video` function, which:
1. Processes user inputs
2. Uses the FastVideo VideoGenerator from earlier to run inference (`generator.generate_video()`)
3. Returns an output path that Gradio uses to display the generated video
## Gradio Interface
The interface is built with several components:
- A text input for the prompt
- A video display for the result
- Inference options in a collapsible accordion:
- Height and width sliders
- Number of frames slider
- Guidance scale slider
- Inference steps slider
- Negative prompt options
- Seed controls
### Inference Options
- **Height/Width**: Control the resolution of the generated video
- **Number of Frames**: Set how many frames to generate
- **Guidance Scale**: Control how closely the generation follows the prompt
- **Inference Steps**: More steps can improve quality but take longer
- **Negative Prompt**: Specify what you don't want to see in the video
- **Seed**: Control randomness for reproducible results
import os
import gradio as gr
import torch
import argparse
from copy import deepcopy
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo import VideoGenerator
from fastvideo.v1.configs.sample.base import SamplingParam
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="FastVideo Gradio Demo")
parser.add_argument("--model_path", type=str, default="FastVideo/FastHunyuan-diffusers", help="Path to the model")
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use")
parser.add_argument("--output_path", type=str, default="outputs", help="Path to save generated videos")
parsed_args = parser.parse_args()
# args = FastVideoArgs(model_path="FastVideo/FastHunyuan-Diffusers", num_gpus=2)
generator = VideoGenerator.from_pretrained(
model_path=parsed_args.model_path,
num_gpus=parsed_args.num_gpus
)
default_params = SamplingParam.from_pretrained(parsed_args.model_path)
def generate_video(
prompt,
negative_prompt,
use_negative_prompt,
seed,
guidance_scale,
num_frames,
height,
width,
num_inference_steps,
randomize_seed=False,
):
params = deepcopy(default_params)
params.prompt = prompt
params.negative_prompt = negative_prompt
params.seed = seed
params.guidance_scale = guidance_scale
params.num_frames = num_frames
params.height = height
params.width = width
params.num_inference_steps = num_inference_steps
if randomize_seed:
params.seed = torch.randint(0, 1000000, (1, )).item()
if not use_negative_prompt:
params.negative_prompt = None
generator.generate_video(
prompt=prompt,
sampling_param=params
)
output_path = os.path.join(parsed_args.output_path, f"{params.prompt[:100]}.mp4")
return output_path, params.seed
examples = [
"A hand enters the frame, pulling a sheet of plastic wrap over three balls of dough placed on a wooden surface. The plastic wrap is stretched to cover the dough more securely. The hand adjusts the wrap, ensuring that it is tight and smooth over the dough. The scene focuses on the hand’s movements as it secures the edges of the plastic wrap. No new objects appear, and the camera remains stationary, focusing on the action of covering the dough.",
"A vintage train snakes through the mountains, its plume of white steam rising dramatically against the jagged peaks. The cars glint in the late afternoon sun, their deep crimson and gold accents lending a touch of elegance. The tracks carve a precarious path along the cliffside, revealing glimpses of a roaring river far below. Inside, passengers peer out the large windows, their faces lit with awe as the landscape unfolds.",
"A crowded rooftop bar buzzes with energy, the city skyline twinkling like a field of stars in the background. Strings of fairy lights hang above, casting a warm, golden glow over the scene. Groups of people gather around high tables, their laughter blending with the soft rhythm of live jazz. The aroma of freshly mixed cocktails and charred appetizers wafts through the air, mingling with the cool night breeze.",
]
with gr.Blocks() as demo:
gr.Markdown("# FastVideo Inference Demo")
with gr.Group():
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button("Run", scale=0)
result = gr.Video(label="Result", show_label=False)
with gr.Accordion("Advanced options", open=False):
with gr.Group():
with gr.Row():
height = gr.Slider(
label="Height",
minimum=256,
maximum=1024,
step=32,
value=default_params.height,
)
width = gr.Slider(
label="Width",
minimum=256,
maximum=1024, step=32, value=default_params.width)
with gr.Row():
num_frames = gr.Slider(
label="Number of Frames",
minimum=21,
maximum=163,
value=default_params.num_frames,
)
guidance_scale = gr.Slider(
label="Guidance Scale",
minimum=1,
maximum=12,
value=default_params.guidance_scale,
)
num_inference_steps = gr.Slider(
label="Inference Steps",
minimum=4,
maximum=100,
value=default_params.num_inference_steps,
)
with gr.Row():
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
negative_prompt = gr.Text(
label="Negative prompt",
max_lines=1,
placeholder="Enter a negative prompt",
visible=False,
)
seed = gr.Slider(label="Seed", minimum=0, maximum=1000000, step=1, value=default_params.seed)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
seed_output = gr.Number(label="Used Seed")
gr.Examples(examples=examples, inputs=prompt)
use_negative_prompt.change(
fn=lambda x: gr.update(visible=x),
inputs=use_negative_prompt,
outputs=negative_prompt,
)
run_button.click(
fn=generate_video,
inputs=[
prompt,
negative_prompt,
use_negative_prompt,
seed,
guidance_scale,
num_frames,
height,
width,
num_inference_steps,
randomize_seed,
],
outputs=[result, seed_output],
)
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)
# SPDX-License-Identifier: Apache-2.0
# Inspired by SGLang: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py
"""The arguments of FastVideo Inference."""
import argparse
import dataclasses
from contextlib import contextmanager
from dataclasses import field
from typing import Any, Callable, List, Optional, Tuple
from fastvideo.v1.configs.models import DiTConfig, EncoderConfig, VAEConfig
from fastvideo.v1.logger import init_logger
from fastvideo.v1.utils import FlexibleArgumentParser
logger = init_logger(__name__)
def preprocess_text(prompt: str) -> str:
return prompt
def postprocess_text(output: Any) -> Any:
raise NotImplementedError
@dataclasses.dataclass
class FastVideoArgs:
# Model and path configuration
model_path: str
# Distributed executor backend
distributed_executor_backend: str = "mp"
inference_mode: bool = True # if False == training mode
# HuggingFace specific parameters
trust_remote_code: bool = False
revision: Optional[str] = None
# Parallelism
num_gpus: int = 1
tp_size: Optional[int] = None
sp_size: Optional[int] = None
dist_timeout: Optional[int] = None # timeout for torch.distributed
# Video generation parameters
embedded_cfg_scale: float = 6.0
flow_shift: Optional[float] = None
output_type: str = "pil"
# DiT configuration
dit_config: DiTConfig = field(default_factory=DiTConfig)
precision: str = "bf16"
# VAE configuration
vae_precision: str = "fp16"
vae_tiling: bool = True # Might change in between forward passes
vae_sp: bool = False # Might change in between forward passes
# vae_scale_factor: Optional[int] = None # Deprecated
vae_config: VAEConfig = field(default_factory=VAEConfig)
# Image encoder configuration
image_encoder_precision: str = "fp32"
image_encoder_config: EncoderConfig = field(default_factory=EncoderConfig)
# Text encoder configuration
text_encoder_precisions: Tuple[str, ...] = field(
default_factory=lambda: ("fp16", ))
text_encoder_configs: Tuple[EncoderConfig, ...] = field(
default_factory=lambda: (EncoderConfig(), ))
preprocess_text_funcs: Tuple[Callable[[str], str], ...] = field(
default_factory=lambda: (preprocess_text, ))
postprocess_text_funcs: Tuple[Callable[[Any], Any], ...] = field(
default_factory=lambda: (postprocess_text, ))
# STA (Spatial-Temporal Attention) parameters
mask_strategy_file_path: Optional[str] = None
enable_torch_compile: bool = False
use_cpu_offload: bool = False
disable_autocast: bool = False
# Logging
log_level: str = "info"
# Inference parameters
device_str: Optional[str] = None
device = None
def __post_init__(self):
pass
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
# Model and path configuration
parser.add_argument(
"--model-path",
type=str,
required=True,
help=
"The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
)
parser.add_argument(
"--dit-weight",
type=str,
help="Path to the DiT model weights",
)
parser.add_argument(
"--model-dir",
type=str,
help="Directory containing StepVideo model",
)
# distributed_executor_backend
parser.add_argument(
"--distributed-executor-backend",
type=str,
choices=["mp"],
default=FastVideoArgs.distributed_executor_backend,
help="The distributed executor backend to use",
)
# HuggingFace specific parameters
parser.add_argument(
"--trust-remote-code",
action="store_true",
default=FastVideoArgs.trust_remote_code,
help="Trust remote code when loading HuggingFace models",
)
parser.add_argument(
"--revision",
type=str,
default=FastVideoArgs.revision,
help=
"The specific model version to use (can be a branch name, tag name, or commit id)",
)
# Parallelism
parser.add_argument(
"--num-gpus",
type=int,
default=FastVideoArgs.num_gpus,
help="The number of GPUs to use.",
)
parser.add_argument(
"--tensor-parallel-size",
"--tp-size",
type=int,
default=FastVideoArgs.tp_size,
help="The tensor parallelism size.",
)
parser.add_argument(
"--sequence-parallel-size",
"--sp-size",
type=int,
default=FastVideoArgs.sp_size,
help="The sequence parallelism size.",
)
parser.add_argument(
"--dist-timeout",
type=int,
default=FastVideoArgs.dist_timeout,
help="Set timeout for torch.distributed initialization.",
)
parser.add_argument(
"--embedded-cfg-scale",
type=float,
default=FastVideoArgs.embedded_cfg_scale,
help="Embedded CFG scale",
)
parser.add_argument(
"--flow-shift",
"--shift",
type=float,
default=FastVideoArgs.flow_shift,
help="Flow shift parameter",
)
parser.add_argument(
"--output-type",
type=str,
default=FastVideoArgs.output_type,
choices=["pil"],
help="Output type for the generated video",
)
parser.add_argument(
"--precision",
type=str,
default=FastVideoArgs.precision,
choices=["fp32", "fp16", "bf16"],
help="Precision for the model",
)
# VAE configuration
parser.add_argument(
"--vae-precision",
type=str,
default=FastVideoArgs.vae_precision,
choices=["fp32", "fp16", "bf16"],
help="Precision for VAE",
)
parser.add_argument(
"--vae-tiling",
action="store_true",
default=FastVideoArgs.vae_tiling,
help="Enable VAE tiling",
)
parser.add_argument(
"--vae-sp",
action="store_true",
help="Enable VAE spatial parallelism",
)
parser.add_argument(
"--text-encoder-precision",
nargs="+",
type=str,
default=FastVideoArgs.text_encoder_precisions,
choices=["fp32", "fp16", "bf16"],
help="Precision for each text encoder",
)
# Image encoder config
parser.add_argument(
"--image-encoder-precision",
type=str,
default=FastVideoArgs.image_encoder_precision,
choices=["fp32", "fp16", "bf16"],
help="Precision for image encoder",
)
# STA (Spatial-Temporal Attention) parameters
parser.add_argument(
"--mask-strategy-file-path",
type=str,
help="Path to mask strategy JSON file for STA",
)
parser.add_argument(
"--enable-torch-compile",
action="store_true",
help=
"Use torch.compile for speeding up STA inference without teacache",
)
parser.add_argument(
"--use-cpu-offload",
action="store_true",
help="Use CPU offload for the model load",
)
parser.add_argument(
"--disable-autocast",
action="store_true",
help=
"Disable autocast for denoising loop and vae decoding in pipeline sampling",
)
# Logging
parser.add_argument(
"--log-level",
type=str,
default=FastVideoArgs.log_level,
help="The logging level of all loggers.",
)
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":
args.tp_size = args.tensor_parallel_size
args.sp_size = args.sequence_parallel_size
args.flow_shift = getattr(args, "shift", args.flow_shift)
# Get all fields from the dataclass
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Create a dictionary of attribute values, with defaults for missing attributes
kwargs = {}
for attr in attrs:
# Handle renamed attributes or those with multiple CLI names
if attr == 'tp_size' and hasattr(args, 'tensor_parallel_size'):
kwargs[attr] = args.tensor_parallel_size
elif attr == 'sp_size' and hasattr(args, 'sequence_parallel_size'):
kwargs[attr] = args.sequence_parallel_size
elif attr == 'flow_shift' and hasattr(args, 'shift'):
kwargs[attr] = args.shift
# Use getattr with default value from the dataclass for potentially missing attributes
else:
default_value = getattr(cls, attr, None)
kwargs[attr] = getattr(args, attr, default_value)
return cls(**kwargs)
def check_fastvideo_args(self) -> None:
"""Validate inference arguments for consistency"""
if self.tp_size is None:
self.tp_size = self.num_gpus
if self.sp_size is None:
self.sp_size = self.num_gpus
if self.num_gpus < max(self.tp_size, self.sp_size):
self.num_gpus = max(self.tp_size, self.sp_size)
if self.tp_size != self.sp_size:
raise ValueError(
f"tp_size ({self.tp_size}) must be equal to sp_size ({self.sp_size})"
)
# Validate VAE spatial parallelism with VAE tiling
if self.vae_sp and not self.vae_tiling:
raise ValueError(
"Currently enabling vae_sp requires enabling vae_tiling, please set --vae-tiling to True."
)
if len(self.text_encoder_configs) != len(self.text_encoder_precisions):
raise ValueError(
f"Length of text encoder configs ({len(self.text_encoder_configs)}) must be equal to length of text encoder precisions ({len(self.text_encoder_precisions)})"
)
if len(self.text_encoder_configs) != len(self.preprocess_text_funcs):
raise ValueError(
f"Length of text encoder configs ({len(self.text_encoder_configs)}) must be equal to length of text preprocessing functions ({len(self.preprocess_text_funcs)})"
)
if len(self.preprocess_text_funcs) != len(self.postprocess_text_funcs):
raise ValueError(
f"Length of text postprocess functions ({len(self.postprocess_text_funcs)}) must be equal to length of text preprocessing functions ({len(self.preprocess_text_funcs)})"
)
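# Validation sketch (illustrative only, never called): with only num_gpus set,
# check_fastvideo_args() fills tp_size/sp_size from num_gpus and enforces that
# they match, per the logic above. The model path is just an example string.
def _example_check_fastvideo_args() -> None:
    args = FastVideoArgs(model_path="FastVideo/FastHunyuan-diffusers", num_gpus=4)
    args.check_fastvideo_args()
    assert args.tp_size == args.sp_size == 4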
_current_fastvideo_args = None
def prepare_fastvideo_args(argv: List[str]) -> FastVideoArgs:
"""
Prepare the inference arguments from the command line arguments.
Args:
argv: The command line arguments. Typically, it should be `sys.argv[1:]`
to ensure compatibility with `parse_args` when no arguments are passed.
Returns:
The inference arguments.
"""
parser = FlexibleArgumentParser()
FastVideoArgs.add_cli_args(parser)
raw_args = parser.parse_args(argv)
fastvideo_args = FastVideoArgs.from_cli_args(raw_args)
fastvideo_args.check_fastvideo_args()
global _current_fastvideo_args
_current_fastvideo_args = fastvideo_args
return fastvideo_args
@contextmanager
def set_current_fastvideo_args(fastvideo_args: FastVideoArgs):
"""
Temporarily set the current fastvideo config.
Used during model initialization.
We save the current fastvideo config in a global variable,
so that all modules can access it, e.g. custom ops
can access the fastvideo config to determine how to dispatch.
"""
global _current_fastvideo_args
old_fastvideo_args = _current_fastvideo_args
try:
_current_fastvideo_args = fastvideo_args
yield
finally:
_current_fastvideo_args = old_fastvideo_args
def get_current_fastvideo_args() -> FastVideoArgs:
if _current_fastvideo_args is None:
# in ci, usually when we test custom ops/modules directly,
# we don't set the fastvideo config. In that case, we set a default
# config.
# TODO(will): may need to handle this for CI.
raise ValueError("Current fastvideo args is not set.")
return _current_fastvideo_args
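# Usage sketch (illustrative only, never called): modules initialized inside
# the context can read the active config via get_current_fastvideo_args().
def _example_set_current_fastvideo_args(fastvideo_args: FastVideoArgs) -> None:
    with set_current_fastvideo_args(fastvideo_args):
        assert get_current_fastvideo_args() is fastvideo_args
    # On exit the previous value is restored; if nothing was set before,
    # get_current_fastvideo_args() raises ValueError again.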
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/forward_context.py
import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import torch
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.logger import init_logger
if TYPE_CHECKING:
from fastvideo.v1.attention import AttentionMetadata
logger = init_logger(__name__)
# TODO(will): check if this is needed
# track_batchsize: bool = envs.FASTVIDEO_LOG_BATCHSIZE_INTERVAL >= 0
track_batchsize: bool = False
last_logging_time: float = 0
forward_start_time: float = 0
# batchsize_logging_interval: float = envs.FASTVIDEO_LOG_BATCHSIZE_INTERVAL
batchsize_logging_interval: float = 1000
batchsize_forward_time: defaultdict = defaultdict(list)
#
@dataclass
class ForwardContext:
# TODO(will): check this arg
# copy from vllm_config.compilation_config.static_forward_context
# attn_layers: Dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
_forward_context: Optional[ForwardContext] = None
def get_forward_context() -> ForwardContext:
"""Get the current forward context."""
assert _forward_context is not None, (
"Forward context is not set. "
"Please use `set_forward_context` to set the forward context.")
return _forward_context
# TODO(will): finalize the interface
@contextmanager
def set_forward_context(current_timestep,
attn_metadata,
fastvideo_args: Optional[FastVideoArgs] = None):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
"""
global forward_start_time
need_to_track_batchsize = track_batchsize and attn_metadata is not None
if need_to_track_batchsize:
forward_start_time = time.perf_counter()
global _forward_context
prev_context = _forward_context
_forward_context = ForwardContext(attn_metadata=attn_metadata)
try:
yield
finally:
global last_logging_time, batchsize_logging_interval
if need_to_track_batchsize:
if hasattr(attn_metadata, "num_prefill_tokens"):
# for v0 attention backends
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
# for v1 attention backends
batchsize = attn_metadata.num_input_tokens
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
torch.cuda.synchronize()
now = time.perf_counter()
# time measurement is in milliseconds
batchsize_forward_time[batchsize].append(
(now - forward_start_time) * 1000)
if now - last_logging_time > batchsize_logging_interval:
last_logging_time = now
forward_stats = []
for bs, times in batchsize_forward_time.items():
if len(times) <= 1:
# can be cudagraph / profiling run
continue
median = torch.quantile(torch.tensor(times), q=0.5).item()
median = round(median, 2)
forward_stats.append((bs, len(times), median))
forward_stats.sort(key=lambda x: x[1], reverse=True)
if forward_stats:
logger.info(("Batchsize forward time stats "
"(batchsize, count, median_time(ms)): %s"),
forward_stats)
_forward_context = prev_context
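# Usage sketch (illustrative only, never called): a model runner wraps each
# forward pass so that layers can fetch the per-step attention metadata via
# get_forward_context() instead of threading it through every call. `model`,
# `latents`, `timestep`, and `attn_metadata` are hypothetical placeholders.
def _example_forward_pass(model, latents, timestep, attn_metadata) -> None:
    with set_forward_context(timestep, attn_metadata):
        assert get_forward_context().attn_metadata is attn_metadata
        model(latents, timestep)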
# type: ignore
# SPDX-License-Identifier: Apache-2.0
"""
Inference module for diffusion models.
This module provides classes and functions for running inference with diffusion models.
"""
import time
from typing import Any, Dict
import torch
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.logger import init_logger
from fastvideo.v1.pipelines import (ComposedPipelineBase, ForwardBatch,
build_pipeline)
# TODO(will): remove, check if this is hunyuan specific
from fastvideo.v1.utils import align_to
logger = init_logger(__name__)
class InferenceEngine:
"""
Engine for running inference with diffusion models.
"""
def __init__(
self,
pipeline: ComposedPipelineBase,
fastvideo_args: FastVideoArgs,
):
"""
Initialize the inference engine.
Args:
pipeline: The pipeline to use for inference.
fastvideo_args: The inference arguments.
"""
self.pipeline = pipeline
self.fastvideo_args = fastvideo_args
@classmethod
def create_engine(
cls,
fastvideo_args: FastVideoArgs,
) -> "InferenceEngine":
"""
Create an inference engine with the specified arguments.
Args:
fastvideo_args: The inference arguments.
Returns:
The created inference engine.
Raises:
ValueError: If the model type is not recognized or if the pipeline type
is not recognized.
"""
logger.info("Building pipeline...")
# TODO(will): I don't really like this api.
# it should be something closer to pipeline_cls.from_pretrained(...)
# this way for training we can just do pipeline_cls.from_pretrained(
# checkpoint_path) and have it handle everything.
# TODO(Peiyuan): Then maybe we should only pass in model path and device, not the entire inference args?
pipeline = build_pipeline(fastvideo_args)
logger.info("Pipeline Ready")
# Create the inference engine
return cls(pipeline, fastvideo_args)
def run(
self,
prompt: str,
fastvideo_args: FastVideoArgs,
) -> Dict[str, Any]:
"""
Run inference with the pipeline.
Args:
prompt: The prompt to use for generation.
fastvideo_args: The inference arguments (resolution, frame count, steps, seed, negative prompt, etc.).
Returns:
A dictionary containing the generated videos and metadata.
"""
out_dict: Dict[str, Any] = dict()
num_videos_per_prompt = fastvideo_args.num_videos
seed = fastvideo_args.seed
height = fastvideo_args.height
width = fastvideo_args.width
video_length = fastvideo_args.num_frames
negative_prompt = fastvideo_args.neg_prompt
infer_steps = fastvideo_args.num_inference_steps
guidance_scale = fastvideo_args.guidance_scale
flow_shift = fastvideo_args.flow_shift
embedded_guidance_scale = fastvideo_args.embedded_cfg_scale
image_path = fastvideo_args.image_path
# ========================================================================
# Arguments: target_width, target_height, target_video_length
# ========================================================================
if width <= 0 or height <= 0 or video_length <= 0:
raise ValueError(
f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={video_length}"
)
if (video_length - 1) % 4 != 0:
raise ValueError(
f"`video_length-1` must be a multiple of 4, got {video_length}")
target_height = align_to(height, 16)
target_width = align_to(width, 16)
target_video_length = video_length
out_dict["size"] = (target_height, target_width, target_video_length)
# ========================================================================
# Arguments: prompt, new_prompt, negative_prompt
# ========================================================================
if not isinstance(prompt, str):
raise TypeError(
f"`prompt` must be a string, but got {type(prompt)}")
prompt = prompt.strip()
# negative prompt
if negative_prompt is not None:
negative_prompt = negative_prompt.strip()
# TODO(PY): move to hunyuan stage
latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
n_tokens = latents_size[0] * latents_size[1] * latents_size[2]
# ========================================================================
# Print infer args
# ========================================================================
debug_str = f"""
height: {target_height}
width: {target_width}
video_length: {target_video_length}
prompt: {prompt}
neg_prompt: {negative_prompt}
seed: {seed}
infer_steps: {infer_steps}
num_videos_per_prompt: {num_videos_per_prompt}
guidance_scale: {guidance_scale}
n_tokens: {n_tokens}
flow_shift: {flow_shift}
embedded_guidance_scale: {embedded_guidance_scale}"""
logger.info(debug_str)
# return
# sp_group = get_sp_group()
# local_rank = sp_group.rank
device = torch.device(fastvideo_args.device_str)
batch = ForwardBatch(
image_path=image_path,
prompt=prompt,
negative_prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
height=fastvideo_args.height,
width=fastvideo_args.width,
num_frames=fastvideo_args.num_frames,
num_inference_steps=fastvideo_args.num_inference_steps,
guidance_scale=fastvideo_args.guidance_scale,
# generator=generator,
eta=0.0,
n_tokens=n_tokens,
data_type="video" if fastvideo_args.num_frames > 1 else "image",
device=device,
extra={}, # Any additional parameters
)
print('===============================================')
print(batch)
print('===============================================')
print('===============================================')
print(fastvideo_args)
# ========================================================================
# Pipeline inference
# ========================================================================
start_time = time.time()
samples = self.pipeline.forward(
batch=batch,
fastvideo_args=fastvideo_args,
).output
# TODO(will): fix and move to hunyuan stage
# out_dict["seeds"] = batch.seeds
out_dict["samples"] = samples
out_dict["prompts"] = prompt
gen_time = time.time() - start_time
logger.info("Success, time: %s", gen_time)
return out_dict
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/activation.py
"""Custom activation functions."""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# TODO (will): remove this dependency
from fastvideo.v1.layers.custom_op import CustomOp
from fastvideo.v1.platforms import current_platform
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def __init__(self) -> None:
super().__init__()
if current_platform.is_cuda_alike() or current_platform.is_cpu():
self.op = torch.ops._C.silu_and_mul
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
self.op(out, x)
return out
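# Usage sketch (illustrative only, never called): the native path splits the
# last dimension in half and multiplies silu(first half) by the second half.
# Constructing the layer assumes the __init__ check above succeeds on the
# current platform.
def _example_silu_and_mul() -> None:
    layer = SiluAndMul()
    x = torch.randn(2, 8)                      # (num_tokens, 2 * d) with d = 4
    out = layer.forward_native(x)              # (num_tokens, d)
    ref = F.silu(x[..., :4]) * x[..., 4:]
    assert out.shape == (2, 4) and torch.allclose(out, ref)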
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
"""An activation function for GeGLU.
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d)
"""
def __init__(self, approximate: str = "none"):
super().__init__()
self.approximate = approximate
if approximate not in ("none", "tanh"):
raise ValueError(f"Unknown approximate mode: {approximate}")
if current_platform.is_cuda_alike() or current_platform.is_cpu():
if approximate == "none":
self.op = torch.ops._C.gelu_and_mul
elif approximate == "tanh":
self.op = torch.ops._C.gelu_tanh_and_mul
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
self.op(out, x)
return out
def extra_repr(self) -> str:
return f'approximate={repr(self.approximate)}'
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
def __init__(self):
super().__init__()
if current_platform.is_cuda_alike() or current_platform.is_cpu():
self.op = torch.ops._C.gelu_new
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
c = math.sqrt(2.0 / math.pi)
return 0.5 * x * (1.0 + torch.tanh(c *
(x + 0.044715 * torch.pow(x, 3.0))))
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
self.op(out, x)
return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return self.op(x)
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
def __init__(self):
super().__init__()
if current_platform.is_cuda_alike() or current_platform.is_cpu():
self.op = torch.ops._C.gelu_quick
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return x * torch.sigmoid(1.702 * x)
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
self.op(out, x)
return out
_ACTIVATION_REGISTRY = {
"gelu": nn.GELU,
"gelu_new": NewGELU,
"gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
"relu": nn.ReLU,
"silu": nn.SiLU,
"quick_gelu": QuickGELU,
}
def get_act_fn(act_fn_name: str) -> nn.Module:
"""Get an activation function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name not in _ACTIVATION_REGISTRY:
raise ValueError(
f"Activation function {act_fn_name!r} is not supported.")
return _ACTIVATION_REGISTRY[act_fn_name]()
_ACTIVATION_AND_MUL_REGISTRY = {
"gelu": GeluAndMul,
"silu": SiluAndMul,
}
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
"""Get an activation-and-mul (i.e. SiluAndMul) function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
raise ValueError(
f"Activation function {act_fn_name!r} is not supported.")
return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]()
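# Usage sketch (illustrative only, never called): activations are looked up by
# name from the registries above.
def _example_activation_lookup() -> None:
    act = get_act_fn("gelu_pytorch_tanh")           # nn.GELU(approximate="tanh")
    y = act(torch.randn(4, 16))                     # shape preserved: (4, 16)
    gated = get_act_and_mul_fn("silu")              # SiluAndMul instance
    z = gated.forward_native(torch.randn(4, 16))    # last dim halved: (4, 8)
    assert y.shape == (4, 16) and z.shape == (4, 8)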
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/custom_op.py
from typing import Any, Callable, Dict, Type
import torch.nn as nn
from fastvideo.v1.logger import init_logger
logger = init_logger(__name__)
class CustomOp(nn.Module):
"""
Base class for custom ops.
Dispatches the forward method to the appropriate backend.
"""
def __init__(self) -> None:
super().__init__()
self._forward_method = self.dispatch_forward()
def forward(self, *args, **kwargs) -> Any:
return self._forward_method(*args, **kwargs)
def forward_native(self, *args, **kwargs) -> Any:
"""PyTorch-native implementation of the forward method.
This method is optional. If implemented, it can be used with compilers
such as torch.compile or PyTorch XLA. Also, it can be used for testing
purposes.
"""
raise NotImplementedError
def forward_cuda(self, *args, **kwargs) -> Any:
raise NotImplementedError
def forward_cpu(self, *args, **kwargs) -> Any:
# By default, we assume that CPU ops are compatible with CUDA ops.
return self.forward_cuda(*args, **kwargs)
def forward_tpu(self, *args, **kwargs) -> Any:
# By default, we assume that TPU ops are compatible with the
# PyTorch-native implementation.
# NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_native(*args, **kwargs)
def forward_oot(self, *args, **kwargs) -> Any:
# By default, we assume that OOT ops are compatible with the
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)
def dispatch_forward(self) -> Callable:
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
enabled = self.enabled()
if not enabled:
return self.forward_native
return self.forward_cuda
@classmethod
def enabled(cls) -> bool:
# since we are not using Inductor, we always return True
return True
@staticmethod
def default_on() -> bool:
"""
On by default if level < CompilationLevel.PIECEWISE
Specifying 'all' or 'none' in custom_op takes precedence.
"""
raise NotImplementedError
# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
# Examples:
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: Dict[str, Type['CustomOp']] = {}
# Decorator to register custom ops.
@classmethod
def register(cls, name: str) -> Callable:
def decorator(op_cls):
assert name not in cls.op_registry, f"Duplicate op name: {name}"
op_cls.name = name
cls.op_registry[name] = op_cls
return op_cls
return decorator
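# Registration sketch (illustrative only): new ops follow the same pattern as
# the activation layers in this repo -- subclass CustomOp, register a name,
# and implement forward_native (plus forward_cuda where a fused kernel exists).
# This example op has no fused kernel and simply falls back to the native path.
@CustomOp.register("_example_scale")
class _ExampleScale(CustomOp):

    def forward_native(self, x):
        return 2.0 * x

    def forward_cuda(self, x):
        # No dedicated kernel in this sketch; reuse the native implementation.
        return self.forward_native(x)
# The class is now reachable via CustomOp.op_registry["_example_scale"], and
# calling an instance dispatches through dispatch_forward().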
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/layernorm.py
"""Custom normalization layers."""
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from fastvideo.v1.layers.custom_op import CustomOp
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
"""Root mean square normalization.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
Refer to https://arxiv.org/abs/1910.07467
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
dtype: torch.dtype = torch.float32,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.variance_epsilon = eps
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)
self.has_weight = has_weight
self.weight = torch.ones(hidden_size)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
def forward_native(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError("Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}")
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")
x_var = x[:, :, :self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
if residual is None:
return x
else:
return x, residual
def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
from vllm import _custom_ops as ops
if residual is not None:
ops.fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
ops.rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)
return out
def extra_repr(self) -> str:
s = f"hidden_size={self.weight.data.size(0)}"
s += f", eps={self.variance_epsilon}"
return s
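# Usage sketch (illustrative only, never called): the native path normalizes
# over the last dimension and can optionally fuse a residual add.
def _example_rms_norm() -> None:
    norm = RMSNorm(hidden_size=16)
    x = torch.randn(2, 4, 16)
    y = norm.forward_native(x)                                   # same shape as x
    y2, res = norm.forward_native(x, residual=torch.zeros_like(x))
    assert y.shape == x.shape and torch.allclose(y, y2)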
class ScaleResidual(nn.Module):
"""
Applies gated residual connection.
"""
def __init__(self, prefix: str = ""):
super().__init__()
def forward(self, residual: torch.Tensor, x: torch.Tensor,
gate: torch.Tensor) -> torch.Tensor:
"""Apply gated residual connection."""
return residual + x * gate
class ScaleResidualLayerNormScaleShift(nn.Module):
"""
Fused operation that combines:
1. Gated residual connection
2. LayerNorm
3. Scale and shift operations
This reduces memory bandwidth by combining memory-bound operations.
"""
def __init__(
self,
hidden_size: int,
norm_type: str = "rms",
eps: float = 1e-6,
elementwise_affine: bool = False,
dtype: torch.dtype = torch.float32,
prefix: str = "",
):
super().__init__()
if norm_type == "rms":
self.norm = RMSNorm(hidden_size,
has_weight=elementwise_affine,
eps=eps,
dtype=dtype)
elif norm_type == "layer":
self.norm = nn.LayerNorm(hidden_size,
elementwise_affine=elementwise_affine,
eps=eps,
dtype=dtype)
else:
raise NotImplementedError(f"Norm type {norm_type} not implemented")
def forward(self, residual: torch.Tensor, x: torch.Tensor,
gate: torch.Tensor, shift: torch.Tensor,
scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply gated residual connection, followed by layernorm and
scale/shift in a single fused operation.
Returns:
Tuple containing:
- normalized and modulated output
- residual value (value after residual connection
but before normalization)
"""
# Apply residual connection with gating
residual_output = residual + x * gate
# Apply normalization
normalized = self.norm(residual_output)
# Apply scale and shift
modulated = normalized * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1)
return modulated, residual_output
class LayerNormScaleShift(nn.Module):
"""
Fused operation that combines LayerNorm with scale and shift operations.
This reduces memory bandwidth by combining memory-bound operations.
"""
def __init__(
self,
hidden_size: int,
norm_type: str = "rms",
eps: float = 1e-6,
elementwise_affine: bool = False,
dtype: torch.dtype = torch.float32,
prefix: str = "",
):
super().__init__()
if norm_type == "rms":
self.norm = RMSNorm(hidden_size,
has_weight=elementwise_affine,
eps=eps)
elif norm_type == "layer":
self.norm = nn.LayerNorm(hidden_size,
elementwise_affine=elementwise_affine,
eps=eps,
dtype=dtype)
else:
raise NotImplementedError(f"Norm type {norm_type} not implemented")
def forward(self, x: torch.Tensor, shift: torch.Tensor,
scale: torch.Tensor) -> torch.Tensor:
"""Apply ln followed by scale and shift in a single fused operation."""
normalized = self.norm(x)
return normalized * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1)
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/linear.py
from abc import abstractmethod
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
# TODO(will): remove this import by copying the definition from vLLM then
# manually import each quantization method we want to use. Refer to SGLang
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from fastvideo.v1.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from fastvideo.v1.logger import init_logger
# yapf: disable
from fastvideo.v1.models.parameter import (BasevLLMParameter,
BlockQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
RowvLLMParameter)
# yapf: enable
from fastvideo.v1.models.utils import set_weight_attrs
logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", "AWQLinearMethod",
"GPTQMarlinLinearMethod", "Fp8LinearMethod", "MarlinLinearMethod",
"QQQLinearMethod", "GPTQMarlin24LinearMethod", "TPUInt8LinearMethod",
"GPTQLinearMethod", "FBGEMMFp8LinearMethod", "ModelOptFp8LinearMethod",
"IPEXAWQLinearMethod", "IPEXGPTQLinearMethod", "HQQMarlinMethod",
"QuarkLinearMethod"
]
def adjust_scalar_to_fused_array(
param: torch.Tensor, loaded_weight: torch.Tensor,
shard_id: Union[str, int]) -> tuple[torch.Tensor, torch.Tensor]:
"""For fused modules (QKV and MLP) we have an array of length
N that holds 1 scale for each "logical" matrix. So the param
is an array of length N. The loaded_weight corresponds to
one of the shards on disk. Here, we slice the param based on
the shard_id for loading.
"""
qkv_idxs = {"q": 0, "k": 1, "v": 2}
if isinstance(shard_id, str):
shard_id = qkv_idxs[shard_id]
elif not isinstance(shard_id, int):
raise ValueError(f"Unknown Shard Id {shard_id}")
# AutoFP8 scales do not have a shape
# compressed-tensors scales do have a shape
if len(loaded_weight.shape) != 0:
assert loaded_weight.shape[0] == 1
loaded_weight = loaded_weight[0]
return param[shard_id], loaded_weight
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@abstractmethod
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs) -> None:
"""Create weights for a linear layer.
The weights will be set as attributes of the layer.
Args:
layer: The layer that is using the LinearMethodBase factory.
input_size_per_partition: Size of the weight input dim on rank X.
output_partition_sizes: Sizes of the output dim of each logical
weight on rank X. E.g., output_partition_sizes for QKVLinear
is a list containing the widths of Wq, Wk, Wv on rank X.
input_size: Size of the input dim of the weight across all ranks.
output_size: Size of the output dim of the weight across all ranks.
params_dtype: Datatype of the parameters.
"""
raise NotImplementedError
@abstractmethod
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization."""
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs) -> None:
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return F.linear(x, layer.weight, bias)
class LinearBase(torch.nn.Module):
"""Base linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.skip_bias_add = skip_bias_add
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if quant_config is None:
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
def forward(self,
x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
raise NotImplementedError
class ReplicatedLinear(LinearBase):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix)
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size, [self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter,
loaded_weight: torch.Tensor) -> None:
# If the weight on disk does not have a shape, give it one
# (such as scales for AutoFP8).
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param.size() == loaded_weight.size(), (
f"Tried to load weights of size {loaded_weight.size()}"
f"to a parameter of size {param.size()}")
param.data.copy_(loaded_weight)
def forward(self,
x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
return s
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Args:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias.
gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. We
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[list[int]] = None,
prefix: str = ""):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter,
loaded_weight: torch.Tensor) -> None:
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
param_data = param.data
if output_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[output_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: Parameter,
loaded_weight: torch.Tensor) -> None:
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_column_parallel_weight(loaded_weight=loaded_weight)
def forward(
self,
input_: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size_per_partition}"
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
s += f", gather_output={self.gather_output}"
return s
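# Illustrative sketch (added for clarity; not part of the adapted vLLM code):
# column parallelism in plain torch. Splitting A along its output (second)
# dimension and concatenating the per-shard results reproduces the unsharded
# Y = X @ A, which is what gather_output=True achieves across ranks via
# tensor_model_parallel_all_gather. Uncalled helper for illustration only.
def _example_column_parallel_math() -> None:
    import torch
    x = torch.randn(4, 8)
    a = torch.randn(8, 6)
    a_shards = torch.chunk(a, 2, dim=1)  # A = [A_1, A_2]
    y = torch.cat([x @ shard for shard in a_shards], dim=1)
    assert torch.allclose(y, x @ a, atol=1e-5)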
class MergedColumnParallelLinear(ColumnParallelLinear):
"""Packed linear layers with column parallelism.
Similar to ColumnParallelLinear, but the weight matrix is concatenated
along the output dimension. When the weight matrix is loaded, the
different partitions are sharded separately.
Args:
input_size: input dimension of the linear layer.
output_sizes: list of output dimensions of the linear layer.
bias: If true, add bias.
gather_output: If true, call all-gather on output and make the output
available to all GPUs, otherwise, every GPU will have
its own output.
skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size=input_size,
output_size=sum(output_sizes),
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None) -> None:
param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
# Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None:
# Loaded weight is already fused on disk (mlp).
# (e.g., Phi-3's gate_up_proj).
if output_dim is None:
if needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, 0)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
current_shard_offset = 0
shard_offsets: list[tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
return
assert loaded_shard_id < len(self.output_sizes)
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
start_idx = tp_rank * shard_size
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)
# Special case for per-tensor scales in fused case.
elif needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, loaded_shard_id)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
logger.warning(
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor) -> None:
"""
Handle special case for models where MLP layers are already
fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
the weight loader using the shard id.
An example of a model with these fused layers:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
"""
current_shard_offset = 0
shard_offsets: list[tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if isinstance(
param,
(PackedColumnParameter,
PackedvLLMParameter)) and param.packed_dim == param.output_dim:
shard_size, shard_offset = \
param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)
loaded_weight_shard = loaded_weight.narrow(param.output_dim,
shard_offset, shard_size)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
def weight_loader_v2(self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None) -> None:
if loaded_shard_id is None:
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=0)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
assert loaded_shard_id < len(self.output_sizes)
tp_size = get_tensor_model_parallel_world_size()
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
assert self.quant_method is not None
assert isinstance(self.quant_method,
(Fp8LinearMethod, Fp8MoEMethod))
weight_block_size = self.quant_method.quant_config.weight_block_size
assert weight_block_size is not None
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
block_n) // tp_size
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n // tp_size)
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
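# Illustrative sketch (added for clarity; not part of the adapted vLLM code):
# how merged shards map to per-rank offsets in weight_loader above. For
# output_sizes=[1024, 1024] (e.g. fused gate/up projections) and tp_size=4,
# checkpoint shard i lands at offset sum(output_sizes[:i]) // tp_size with
# size output_sizes[i] // tp_size in the local parameter. Uncalled helper.
def _example_merged_column_offsets() -> None:
    output_sizes = [1024, 1024]
    tp_size = 4
    for shard_id in range(len(output_sizes)):
        shard_offset = sum(output_sizes[:shard_id]) // tp_size
        shard_size = output_sizes[shard_id] // tp_size
        assert (shard_offset, shard_size) == (shard_id * 256, 256)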
class QKVParallelLinear(ColumnParallelLinear):
"""Linear layers for the attention's QKV transformation.
Linear layers for the linear transformation of the query, key, and value
vectors in the attention layer. The weight matrix is concatenated along
the output dimension. The layer is parallelized along the head dimension.
When the number of key/value heads is smaller than the number of query
heads (e.g., multi-query/grouped-query attention), the key/value head may
be replicated while the query heads are partitioned.
Args:
hidden_size: input hidden state size of the transformer.
head_size: size of each attention head.
total_num_heads: total number of attention query heads.
total_num_kv_heads: total number of attention key/value heads. If
None, assume total_num_kv_heads = total_num_heads.
bias: If true, add bias.
skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
else:
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
self.num_kv_head_replicas = 1
input_size = self.hidden_size
output_size = (self.num_heads +
2 * self.num_kv_heads) * tp_size * self.head_size
self.output_sizes = [
self.num_heads * self.head_size * tp_size, # q_proj
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj
]
super().__init__(input_size=input_size,
output_size=output_size,
bias=bias,
gather_output=False,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
def _get_shard_offset_mapping(self, loaded_shard_id: str) -> Optional[int]:
shard_offset_mapping = {
"q": 0,
"k": self.num_heads * self.head_size,
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
"total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
}
return shard_offset_mapping.get(loaded_shard_id)
def _get_shard_size_mapping(self, loaded_shard_id: str) -> Optional[int]:
shard_size_mapping = {
"q": self.num_heads * self.head_size,
"k": self.num_kv_heads * self.head_size,
"v": self.num_kv_heads * self.head_size,
}
return shard_size_mapping.get(loaded_shard_id)
def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
"""
Handle special case for models where QKV layers are already
fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
the weight loader using the shard id.
An example of a model with these fused layers:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
"""
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
("k", self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
("v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.head_size),
]
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if isinstance(
param,
(PackedColumnParameter,
PackedvLLMParameter)) and param.packed_dim == param.output_dim:
shard_size, shard_offset = \
param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)
loaded_weight_shard = loaded_weight.narrow(param.output_dim,
shard_offset, shard_size)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
def weight_loader_v2(self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_qkv_weight(loaded_weight=loaded_weight)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
assert loaded_shard_id in ["q", "k", "v"]
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)
param.load_qkv_weight(loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
# Special case for per-tensor scales in fused case.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv).
# (e.g., Phi-3's qkv_proj).
if output_dim is None:
if needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, 0)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
("k", self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
("v", (self.total_num_heads + self.total_num_kv_heads) *
self.head_size, self.total_num_kv_heads * self.head_size),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
return
tp_rank = get_tensor_model_parallel_rank()
assert loaded_shard_id in ["q", "k", "v"]
# If output dim is defined, use the default loading process.
if output_dim is not None:
if loaded_shard_id == "q":
shard_offset = 0
shard_size = self.num_heads * self.head_size
elif loaded_shard_id == "k":
shard_offset = self.num_heads * self.head_size
shard_size = self.num_kv_heads * self.head_size
elif loaded_shard_id == "v":
shard_offset = (self.num_heads +
self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.head_size
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
shard_idx = 0
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
if loaded_shard_id == "q":
shard_idx = tp_rank
else:
shard_idx = tp_rank // self.num_kv_head_replicas
start_idx = shard_idx * shard_size
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
        # Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size,
shard_size)
# Special case for per-tensor scales in fused case.
elif needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, loaded_shard_id)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
logger.warning(
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
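# Illustrative sketch (added for clarity; not part of the adapted vLLM code):
# QKV head partitioning for grouped-query attention, mirroring __init__ above.
# With more ranks than KV heads, each rank keeps a single KV head and several
# ranks share (replicate) it. Uncalled helper for illustration only.
def _example_qkv_head_partitioning() -> None:
    total_num_heads, total_num_kv_heads = 32, 8
    for tp_size, expected in [(4, (8, 2, 1)), (16, (2, 1, 2))]:
        num_heads = total_num_heads // tp_size
        if tp_size >= total_num_kv_heads:
            num_kv_heads, replicas = 1, tp_size // total_num_kv_heads
        else:
            num_kv_heads, replicas = total_num_kv_heads // tp_size, 1
        assert (num_heads, num_kv_heads, replicas) == expected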
class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
                -   -
               | A_1 |
               | .   |
           A = | .   |        X = [X_1, ..., X_p]
               | .   |
               | A_p |
                -   -
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
skip_bias_add: This was added to enable performance optimization where
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def __init__(self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
input_dim = getattr(param, "input_dim", None)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
param_data = param.data
if input_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:
s = f"input_features={self.input_size_per_partition}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
s += f", tp_size={self.tp_size}"
s += f", reduce_results={self.reduce_results}"
return s
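# Illustrative sketch (added for clarity; not part of the adapted vLLM code):
# row parallelism in plain torch. Splitting X along its last dimension and A
# along its first, then summing the partial products, reproduces the unsharded
# Y = X @ A; the sum is what tensor_model_parallel_all_reduce performs across
# ranks when reduce_results=True. Uncalled helper for illustration only.
def _example_row_parallel_math() -> None:
    import torch
    x = torch.randn(4, 8)
    a = torch.randn(8, 6)
    x_shards = torch.chunk(x, 2, dim=1)  # X = [X_1, X_2]
    a_shards = torch.chunk(a, 2, dim=0)  # A stacked along its first dim
    y = sum(xs @ ws for xs, ws in zip(x_shards, a_shards))
    assert torch.allclose(y, x @ a, atol=1e-5)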
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import torch.nn as nn
from fastvideo.v1.layers.activation import get_act_fn
from fastvideo.v1.layers.linear import ReplicatedLinear
class MLP(nn.Module):
"""
MLP for DiT blocks, NO gated linear units
"""
def __init__(
self,
input_dim: int,
mlp_hidden_dim: int,
output_dim: Optional[int] = None,
bias: bool = True,
act_type: str = "gelu_pytorch_tanh",
dtype: Optional[torch.dtype] = None,
prefix: str = "",
):
super().__init__()
self.fc_in = ReplicatedLinear(
input_dim,
            mlp_hidden_dim,  # For activation functions like SiLU that need 2x width
bias=bias,
params_dtype=dtype)
self.act = get_act_fn(act_type)
if output_dim is None:
output_dim = input_dim
self.fc_out = ReplicatedLinear(mlp_hidden_dim,
output_dim,
bias=bias,
params_dtype=dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.fc_in(x)
x = self.act(x)
x, _ = self.fc_out(x)
return x
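# Illustrative sketch (added for clarity; not part of the original module):
# the DiT-style MLP above is a plain fc_in -> activation -> fc_out stack, so
# with output_dim left as None the output shape matches the input shape. This
# is a hypothetical, uncalled helper and assumes the default act_type resolves
# via get_act_fn exactly as it does inside MLP.__init__.
def _example_mlp_shapes() -> None:
    mlp = MLP(input_dim=16, mlp_hidden_dim=64)
    x = torch.randn(2, 3, 16)
    assert mlp(x).shape == (2, 3, 16)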
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/rotary_embedding.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from fastvideo.v1.distributed.parallel_state import get_sp_group
from fastvideo.v1.layers.custom_op import CustomOp
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)
def _apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
# cos = cos.unsqueeze(-2).to(x.dtype)
# sin = sin.unsqueeze(-2).to(x.dtype)
cos = cos.unsqueeze(-2)
sin = sin.unsqueeze(-2)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = (x1.float() * cos - x2.float() * sin).type_as(x)
o2 = (x2.float() * cos + x1.float() * sin).type_as(x)
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
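# Illustrative sketch (added for clarity; not part of the adapted vLLM code):
# applying rotary embeddings rotates each (x1, x2) pair by its position's
# angle, so the per-token norm is preserved. Uncalled helper for illustration.
def _example_rotary_preserves_norm() -> None:
    num_tokens, num_heads, head_size = 3, 2, 8
    x = torch.randn(num_tokens, num_heads, head_size)
    angles = torch.randn(num_tokens, head_size // 2)
    out = _apply_rotary_emb(x, angles.cos(), angles.sin(), is_neox_style=True)
    assert torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-5)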
@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: Union[int, float],
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
cache = self._compute_cos_sin_cache()
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
ops.batched_rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style,
self.rotary_dim, offsets)
else:
ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
return query, key
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
return s
def _to_tuple(x: Union[int, Tuple[int, ...]], dim: int = 2) -> Tuple[int, ...]:
if isinstance(x, int):
return (x, ) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start: Union[int, Tuple[int, ...]],
*args: Union[int, Tuple[int, ...]],
dim: int = 2) -> torch.Tensor:
"""
Get n-D meshgrid with start, stop and num.
Args:
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
n-tuples.
*args: See above.
dim (int): Dimension of the meshgrid. Defaults to 2.
Returns:
        grid (torch.Tensor): [dim, ...]
"""
if len(args) == 0:
# start is grid_size
num = _to_tuple(start, dim=dim)
start = (0, ) * dim
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = tuple(stop[i] - start[i] for i in range(dim))
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
return grid
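# Illustrative sketch (added for clarity; not part of the original module):
# with a single tuple argument, get_meshgrid_nd treats it as the number of
# samples per axis and returns an integer-valued grid that starts at 0 and is
# stacked along a leading axis of length dim. Uncalled helper for illustration.
def _example_meshgrid_nd() -> None:
    grid = get_meshgrid_nd((4, 6, 8), dim=3)
    assert grid.shape == (3, 4, 6, 8)
    assert grid[0, 3, 0, 0] == 3 and grid[2, 0, 0, 5] == 5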
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.FloatTensor, int],
theta: float = 10000.0,
theta_rescale_factor: float = 1.0,
interpolation_factor: float = 1.0,
dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the given position indices 'pos'. The 'theta' parameter scales the frequencies.
Args:
dim (int): Dimension of the frequency tensor.
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
interpolation_factor (float, optional): Factor to scale positions. Defaults to 1.0.
Returns:
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
"""
if isinstance(pos, int):
pos = torch.arange(pos).float()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor**(dim / (dim - 2))
freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].to(dtype) / dim)
) # [D/2]
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
freqs_cos = freqs.cos() # [S, D/2]
freqs_sin = freqs.sin() # [S, D/2]
return freqs_cos, freqs_sin
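# Illustrative sketch (added for clarity; not part of the original module):
# an integer `pos` is expanded to positions 0..pos-1, and the returned cos/sin
# tables each have shape [S, D/2]; position 0 always maps to angle 0.
# Uncalled helper for illustration only.
def _example_1d_rope_shapes() -> None:
    cos, sin = get_1d_rotary_pos_embed(dim=8, pos=16)
    assert cos.shape == sin.shape == (16, 4)
    assert torch.allclose(cos[0], torch.ones(4))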
def get_nd_rotary_pos_embed(
rope_dim_list,
start,
*args,
theta=10000.0,
theta_rescale_factor: Union[float, List[float]] = 1.0,
interpolation_factor: Union[float, List[float]] = 1.0,
shard_dim: int = 0,
sp_rank: int = 0,
sp_world_size: int = 1,
dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
Supports sequence parallelism by allowing sharding of a specific dimension.
Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n.
            sum(rope_dim_list) should equal the head_dim of the attention layer.
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
*args: See above.
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
interpolation_factor (float): Factor to scale positions. Defaults to 1.0.
shard_dim (int): Which dimension to shard for sequence parallelism. Defaults to 0.
sp_rank (int): Rank in the sequence parallel group. Defaults to 0.
sp_world_size (int): World size of the sequence parallel group. Defaults to 1.
Returns:
Tuple[torch.Tensor, torch.Tensor]: (cos, sin) tensors of shape [HW, D/2]
"""
# Get the full grid
full_grid = get_meshgrid_nd(
start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
# Shard the grid if using sequence parallelism (sp_world_size > 1)
assert shard_dim < len(
rope_dim_list
), f"shard_dim {shard_dim} must be less than number of dimensions {len(rope_dim_list)}"
if sp_world_size > 1:
# Get the shape of the full grid
grid_shape = list(full_grid.shape[1:])
# Ensure the dimension to shard is divisible by sp_world_size
assert grid_shape[shard_dim] % sp_world_size == 0, (
f"Dimension {shard_dim} with size {grid_shape[shard_dim]} is not divisible "
f"by sequence parallel world size {sp_world_size}")
# Compute the start and end indices for this rank's shard
shard_size = grid_shape[shard_dim] // sp_world_size
start_idx = sp_rank * shard_size
end_idx = (sp_rank + 1) * shard_size
# Create slicing indices for each dimension
slice_indices = [slice(None) for _ in range(len(grid_shape))]
slice_indices[shard_dim] = slice(start_idx, end_idx)
# Shard the grid
# Update grid shape for the sharded dimension
grid_shape[shard_dim] = grid_shape[shard_dim] // sp_world_size
grid = torch.empty((len(rope_dim_list), ) + tuple(grid_shape),
dtype=full_grid.dtype)
for i in range(len(rope_dim_list)):
grid[i] = full_grid[i][tuple(slice_indices)]
else:
grid = full_grid
if isinstance(theta_rescale_factor, (int, float)):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor,
list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
assert len(theta_rescale_factor) == len(
rope_dim_list
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
if isinstance(interpolation_factor, (int, float)):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor,
list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
assert len(interpolation_factor) == len(
rope_dim_list
), "len(interpolation_factor) should equal to len(rope_dim_list)"
# use 1/ndim of dimensions to encode grid_axis
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i],
grid[i].reshape(-1),
theta,
theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
dtype=dtype,
) # 2 x [WHD, rope_dim_list[i]]
embs.append(emb)
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
return cos, sin
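# Illustrative sketch (added for clarity; not part of the original module):
# for a (t, h, w) = (2, 4, 4) latent grid with per-axis rope dims summing to
# the attention head_dim, every grid position gets one row of cos/sin of width
# head_dim // 2. Runs with sp_world_size=1, so no distributed initialization
# is required. Uncalled helper for illustration only.
def _example_nd_rope_shapes() -> None:
    cos, sin = get_nd_rotary_pos_embed([16, 24, 24], (2, 4, 4))
    assert cos.shape == sin.shape == (2 * 4 * 4, (16 + 24 + 24) // 2)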
def get_rotary_pos_embed(
rope_sizes,
hidden_size,
heads_num,
rope_dim_list,
rope_theta,
theta_rescale_factor=1.0,
interpolation_factor=1.0,
shard_dim: int = 0,
dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generate rotary positional embeddings for the given sizes.
Args:
rope_sizes: Tuple of dimensions (t, h, w)
hidden_size: Hidden dimension size
heads_num: Number of attention heads
rope_dim_list: List of dimensions for each axis, or None
rope_theta: Base for frequency calculations
theta_rescale_factor: Rescale factor for theta. Defaults to 1.0
interpolation_factor: Factor to scale positions. Defaults to 1.0
shard_dim: Which dimension to shard for sequence parallelism. Defaults to 0.
Returns:
Tuple of (cos, sin) tensors for rotary embeddings
"""
target_ndim = 3
head_dim = hidden_size // heads_num
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(
rope_dim_list
    ) == head_dim, "sum(rope_dim_list) should equal the head_dim of the attention layer"
# Get SP info
sp_group = get_sp_group()
sp_rank = sp_group.rank_in_group
sp_world_size = sp_group.world_size
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=rope_theta,
theta_rescale_factor=theta_rescale_factor,
interpolation_factor=interpolation_factor,
shard_dim=shard_dim,
sp_rank=sp_rank,
sp_world_size=sp_world_size,
dtype=dtype,
)
return freqs_cos, freqs_sin
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
def get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
base: Union[int, float],
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in rope_scaling.items()
}
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling_args, dtype)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if rope_scaling is None:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
else:
raise ValueError(f"Unknown RoPE scaling {rope_scaling}")
_ROPE_DICT[key] = rotary_emb
return rotary_emb
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/utils.py
"""Utility methods for model layers."""
from typing import Tuple
import torch
def get_token_bin_counts_and_mask(
tokens: torch.Tensor,
vocab_size: int,
num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Compute the bin counts for the tokens.
# vocab_size + 1 for padding.
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
dtype=torch.long,
device=tokens.device)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:, :vocab_size]
mask = bin_counts > 0
return bin_counts, mask
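# Illustrative sketch (added for clarity; not part of the adapted vLLM code):
# the bin counts are per-sequence token histograms over the vocabulary, and
# the mask marks which tokens appeared at least once. Uncalled helper.
def _example_token_bin_counts() -> None:
    tokens = torch.tensor([[1, 1, 3], [0, 2, 2]])
    counts, mask = get_token_bin_counts_and_mask(tokens,
                                                 vocab_size=4,
                                                 num_seqs=2)
    assert counts.tolist() == [[0, 2, 0, 1], [1, 0, 2, 0]]
    assert mask.tolist() == [[False, True, False, True],
                             [True, False, True, False]]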
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Optional
import torch
import torch.nn as nn
from fastvideo.v1.layers.activation import get_act_fn
from fastvideo.v1.layers.linear import ReplicatedLinear
from fastvideo.v1.layers.mlp import MLP
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding
Image to Patch Embedding using Conv2d
A convolution based approach to patchifying a 2D image w/ embedding projection.
Based on the impl in https://github.com/google-research/vision_transformer
Hacked together by / Copyright 2020 Ross Wightman
Remove the _assert function in forward function to be compatible with multi-resolution images.
"""
def __init__(self,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
dtype=None,
prefix: str = ""):
super().__init__()
        # Normalize patch_size: a scalar or 1-element tuple is expanded to a pair
if isinstance(patch_size, (list, tuple)):
if len(patch_size) == 1:
patch_size = (patch_size[0], patch_size[0])
else:
patch_size = (patch_size, patch_size)
self.patch_size = patch_size
self.flatten = flatten
self.proj = nn.Conv3d(in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
dtype=dtype)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(
self,
hidden_size,
act_layer="silu",
frequency_embedding_size=256,
max_period=10000,
dtype=None,
freq_dtype=torch.float32,
prefix: str = "",
):
super().__init__()
self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
self.mlp = MLP(frequency_embedding_size,
hidden_size,
hidden_size,
act_type=act_layer,
dtype=dtype)
self.freq_dtype = freq_dtype
def forward(self, t: torch.Tensor) -> torch.Tensor:
t_freq = timestep_embedding(t,
self.frequency_embedding_size,
self.max_period,
dtype=self.freq_dtype).to(
self.mlp.fc_in.weight.dtype)
# t_freq = t_freq.to(self.mlp.fc_in.weight.dtype)
t_emb = self.mlp(t_freq)
return t_emb
def timestep_embedding(t: torch.Tensor,
dim: int,
max_period: int = 10000,
dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""
Create sinusoidal timestep embeddings.
Args:
t: Tensor of shape [B] with timesteps
dim: Embedding dimension
max_period: Controls the minimum frequency of the embeddings
Returns:
Tensor of shape [B, dim] with embeddings
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) *
torch.arange(start=0, end=half, dtype=dtype) /
half).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
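# Illustrative sketch (added for clarity; not part of the original module):
# timestep_embedding produces standard sinusoidal features, so t = 0 maps to
# cos(0) = 1 in the first half and sin(0) = 0 in the second half.
# Uncalled helper for illustration only.
def _example_timestep_embedding() -> None:
    emb = timestep_embedding(torch.tensor([0.0, 10.0, 500.0]), dim=256)
    assert emb.shape == (3, 256)
    assert torch.allclose(emb[0, :128], torch.ones(128))
    assert torch.allclose(emb[0, 128:], torch.zeros(128))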
class ModulateProjection(nn.Module):
"""Modulation layer for DiT blocks."""
def __init__(
self,
hidden_size: int,
factor: int = 2,
act_layer: str = "silu",
dtype: Optional[torch.dtype] = None,
prefix: str = "",
):
super().__init__()
self.factor = factor
self.hidden_size = hidden_size
self.linear = ReplicatedLinear(hidden_size,
hidden_size * factor,
bias=True,
params_dtype=dtype)
self.act = get_act_fn(act_layer)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.act(x)
x, _ = self.linear(x)
return x
def unpatchify(x, t, h, w, patch_size, channels) -> torch.Tensor:
"""
Convert patched representation back to image space.
Args:
x: Tensor of shape [B, T*H*W, C*P_t*P_h*P_w]
t, h, w: Temporal and spatial dimensions
Returns:
Unpatchified tensor of shape [B, C, T*P_t, H*P_h, W*P_w]
"""
assert x.ndim == 3, f"x.ndim: {x.ndim}"
assert len(patch_size) == 3, f"patch_size: {patch_size}"
assert t * h * w == x.shape[
1], f"t * h * w: {t * h * w}, x.shape[1]: {x.shape[1]}"
c = channels
pt, ph, pw = patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
x = torch.einsum("nthwcopq->nctohpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
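# Illustrative sketch (added for clarity; not part of the original module):
# patchify a video by reshaping it into per-patch rows, then check that
# unpatchify restores the original tensor exactly. Uncalled helper for
# illustration only; shapes here are arbitrary small examples.
def _example_unpatchify_roundtrip() -> None:
    b, c, (pt, ph, pw), (t, h, w) = 1, 4, (1, 2, 2), (3, 8, 8)
    video = torch.randn(b, c, t * pt, h * ph, w * pw)
    x = video.reshape(b, c, t, pt, h, ph, w, pw)
    x = torch.einsum("nctohpwq->nthwcopq", x).reshape(b, t * h * w, -1)
    assert torch.equal(unpatchify(x, t, h, w, (pt, ph, pw), c), video)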
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from fastvideo.v1.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from fastvideo.v1.models.parameter import BasevLLMParameter
from fastvideo.v1.models.utils import set_weight_attrs
from fastvideo.v1.platforms import current_platform
DEFAULT_VOCAB_PADDING_SIZE = 64
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
"""Unquantized method for embeddings."""
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""Create weights for embedding layer."""
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return F.linear(x, layer.weight, bias)
def embedding(self, layer: torch.nn.Module,
input_: torch.Tensor) -> torch.Tensor:
return F.embedding(input_, layer.weight)
def pad_vocab_size(vocab_size: int,
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
"""Pad the vocab size to the given value."""
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
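# Illustrative sketch (added for clarity; not part of the adapted vLLM code):
# padding rounds the vocabulary up to the next multiple of pad_to; the
# 1010-token example from the VocabParallelEmbedding docstring below pads to
# 1024. Uncalled helper for illustration only.
def _example_pad_vocab_size() -> None:
    assert pad_vocab_size(1010) == 1024
    assert pad_vocab_size(1024) == 1024
    assert pad_vocab_size(1025, pad_to=128) == 1152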
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int,
rank: int,
offset: int = 0) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f + offset, index_l + offset
def vocab_range_from_global_vocab_size(global_vocab_size: int,
rank: int,
world_size: int,
offset: int = 0) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
rank,
offset=offset)
@dataclass
class VocabParallelEmbeddingShardIndices:
"""Indices for a shard of a vocab parallel embedding."""
padded_org_vocab_start_index: int
padded_org_vocab_end_index: int
padded_added_vocab_start_index: int
padded_added_vocab_end_index: int
org_vocab_start_index: int
org_vocab_end_index: int
added_vocab_start_index: int
added_vocab_end_index: int
@property
def num_org_elements(self) -> int:
return self.org_vocab_end_index - self.org_vocab_start_index
@property
def num_added_elements(self) -> int:
return self.added_vocab_end_index - self.added_vocab_start_index
@property
def num_org_elements_padded(self) -> int:
return (self.padded_org_vocab_end_index -
self.padded_org_vocab_start_index)
@property
def num_added_elements_padded(self) -> int:
return (self.padded_added_vocab_end_index -
self.padded_added_vocab_start_index)
@property
def num_org_vocab_padding(self) -> int:
return self.num_org_elements_padded - self.num_org_elements
@property
def num_added_vocab_padding(self) -> int:
return self.num_added_elements_padded - self.num_added_elements
@property
def num_elements_padded(self) -> int:
return self.num_org_elements_padded + self.num_added_elements_padded
def __post_init__(self):
# sanity checks
assert (self.padded_org_vocab_start_index
<= self.padded_org_vocab_end_index)
assert (self.padded_added_vocab_start_index
<= self.padded_added_vocab_end_index)
assert self.org_vocab_start_index <= self.org_vocab_end_index
assert self.added_vocab_start_index <= self.added_vocab_end_index
assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
assert (self.added_vocab_start_index
<= self.padded_added_vocab_start_index)
assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
assert self.num_org_elements <= self.num_org_elements_padded
assert self.num_added_elements <= self.num_added_elements_padded
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
# torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_
< org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index)
added_offset = added_vocab_start_index - (
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
valid_offset = (org_vocab_start_index * org_vocab_mask) + (added_offset *
added_vocab_mask)
vocab_mask = org_vocab_mask | added_vocab_mask
input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask
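# Illustrative sketch (added for clarity; not part of the adapted vLLM code):
# a plain-eager restatement of the masking above for a shard that owns
# original-vocab ids [4, 8) and has no added vocab: in-range ids are shifted
# to local indices, everything else maps to 0 and is reported in the returned
# mask so the embedding output can be zeroed before the all-reduce.
# Uncalled helper for illustration only.
def _example_vocab_masking() -> None:
    input_ = torch.tensor([2, 4, 7, 9])
    in_range = (input_ >= 4) & (input_ < 8)
    local_ids = in_range * (input_ - 4)
    assert local_ids.tolist() == [0, 0, 3, 0]
    assert (~in_range).tolist() == [True, False, False, True]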
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
make sure it is divisible by the number of model parallel GPUs.
In order to support various loading methods, we ensure that LoRA-added
embeddings are always at the end of TP-sharded tensors. In other words,
we shard base embeddings and LoRA embeddings separately (both padded),
and place them in the same tensor.
In this example, we will have the original vocab size = 1010,
added vocab size = 16 and padding to 64. Therefore, the total
vocab size with padding will be 1088 (because we first pad 1010 to
1024, add 16, and then pad to 1088).
Therefore, the tensor format looks like the following:
TP1, rank 0 (no sharding):
|< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1015 | -1 | ... | -1 |
index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
TP2, rank 0:
|< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1000 | ... | 1015 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 520 | ... | 543 |
TP2, rank 1:
|< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
quant_config: quant config for the layer
prefix: full name of the layer in the state dict
""" # noqa: E501
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
# Keep the input dimensions.
tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_embeddings = num_embeddings
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
self.padding_size)
self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings,
self.padding_size)
assert self.org_vocab_size_padded <= self.num_embeddings_padded
self.shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size, tp_rank,
self.tp_size)
self.embedding_dim = embedding_dim
quant_method = None
if quant_config is not None:
quant_method = quant_config.get_quant_method(self, prefix=prefix)
if quant_method is None:
quant_method = UnquantizedEmbeddingMethod()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
        is_embedding_layer = type(self) is VocabParallelEmbedding
quant_method_implements_embedding = method_has_implemented_embedding(
type(quant_method))
if is_embedding_layer and not quant_method_implements_embedding:
raise NotImplementedError(
f"The class {type(quant_method).__name__} must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
self.quant_method: QuantizeMethodBase = quant_method
if params_dtype is None:
params_dtype = torch.get_default_dtype()
        # Divide the weight matrix along the vocabulary dimension.
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
self.tp_size)
assert (self.shard_indices.num_elements_padded ==
self.num_embeddings_per_partition)
self.num_org_embeddings_per_partition = (
self.shard_indices.org_vocab_end_index -
self.shard_indices.org_vocab_start_index)
self.num_added_embeddings_per_partition = (
self.shard_indices.added_vocab_end_index -
self.shard_indices.added_vocab_start_index)
self.quant_method.create_weights(self,
self.embedding_dim,
[self.num_embeddings_per_partition],
self.embedding_dim,
self.num_embeddings_padded,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
@classmethod
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
vocab_size: int, org_vocab_size: int, tp_rank: int,
tp_size: int) -> VocabParallelEmbeddingShardIndices:
"""Get start and end indices for vocab parallel embedding, following the
layout outlined in the class docstring, based on the given tp_rank and
tp_size."""
num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
padded_org_vocab_start_index, padded_org_vocab_end_index = (
vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
tp_size))
padded_added_vocab_start_index, padded_added_vocab_end_index = (
vocab_range_from_global_vocab_size(num_added_embeddings_padded,
tp_rank,
tp_size,
offset=org_vocab_size))
# remove padding
org_vocab_start_index = min(padded_org_vocab_start_index,
org_vocab_size)
org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
added_vocab_start_index = min(padded_added_vocab_start_index,
vocab_size)
added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
return VocabParallelEmbeddingShardIndices(
padded_org_vocab_start_index, padded_org_vocab_end_index,
padded_added_vocab_start_index, padded_added_vocab_end_index,
org_vocab_start_index, org_vocab_end_index, added_vocab_start_index,
added_vocab_end_index)
def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
"""Get a mapping that can be used to reindex the gathered
logits for sampling.
During sampling, we gather logits from all ranks. The relationship
of index->token_id will follow the same format as outlined in the class
docstring. However, after the gather, we want to reindex the final
logits tensor to map index->token_id one-to-one (the index is always
equal the token_id it corresponds to). The indices returned by this
method allow us to do that.
"""
if self.tp_size < 2:
return None
base_embeddings: List[int] = []
added_embeddings: List[int] = []
padding: List[int] = []
for tp_rank in range(self.tp_size):
shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size, tp_rank,
self.tp_size)
range_start = self.num_embeddings_per_partition * tp_rank
range_end = self.num_embeddings_per_partition * (tp_rank + 1)
base_embeddings.extend(
range(range_start,
range_start + shard_indices.num_org_elements))
padding.extend(
range(range_start + shard_indices.num_org_elements,
range_start + shard_indices.num_org_elements_padded))
added_embeddings.extend(
range(
range_start + shard_indices.num_org_elements_padded,
range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements))
padding.extend(
range(
range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements,
range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements_padded))
assert (range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements_padded == range_end)
ret = base_embeddings + added_embeddings + padding
assert len(ret) == self.num_embeddings_padded
return ret
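    # Worked toy example (illustrative only, not computed by the library):
    # with tp_size=2, org_vocab_size=4, two added tokens, and
    # num_embeddings_padded=8, rank 0 holds rows [tok0, tok1, tok4, tok5]
    # and rank 1 holds [tok2, tok3, pad, pad]. The gathered logits follow
    # that order, and get_sharded_to_full_mapping() returns
    # [0, 1, 4, 5, 2, 3, 6, 7], so gathered[mapping][i] is the logit of
    # token id i for every real token, with padding pushed to the tail.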
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
output_dim = getattr(param, "output_dim", None)
packed_dim = getattr(param, "packed_dim", None)
# If the parameter is a gguf weight, then load it directly.
if getattr(param, "is_gguf_weight_type", None):
param.data.copy_(loaded_weight)
param.weight_type = loaded_weight.item()
return
elif isinstance(param, UninitializedParameter):
shape = list(loaded_weight.shape)
if output_dim is not None:
shape[output_dim] = self.num_embeddings_per_partition
param.materialize(tuple(shape), dtype=loaded_weight.dtype)
# If parameter does not have output dim, then it should
# be copied onto all gpus (e.g. g_idx for act_order gptq).
if output_dim is None:
assert param.data.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
return
# Shard indexes for loading the weight
start_idx = self.shard_indices.org_vocab_start_index
shard_size = self.shard_indices.org_vocab_end_index - start_idx
# If param packed on the same dim we are sharding on, then
# need to adjust offsets of loaded weight by pack_factor.
if packed_dim is not None and packed_dim == output_dim:
packed_factor = param.packed_factor if isinstance(
param, BasevLLMParameter) else param.pack_factor
            assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
                                                       packed_factor)
start_idx = start_idx // packed_factor
shard_size = shard_size // packed_factor
else:
assert loaded_weight.shape[output_dim] == self.org_vocab_size
# Copy the data. Select chunk corresponding to current shard.
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
param[loaded_weight.shape[0]:].data.fill_(0)
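    # Worked example (illustrative only): with tp_size=2 and org_vocab_size=8,
    # rank 1 owns original ids [4, 8). It narrows the checkpoint weight to
    # rows 4..7, copies them into the first rows of its partition, and
    # zero-fills the remaining rows (added-vocab slots and padding).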
def forward(self, input_):
if self.tp_size > 1:
# Build the mask.
masked_input, input_mask = get_masked_input_and_mask(
input_, self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index)
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.quant_method.embedding(self, masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
def extra_repr(self) -> str:
s = f"num_embeddings={self.num_embeddings_per_partition}"
s += f", embedding_dim={self.embedding_dim}"
s += f", org_vocab_size={self.org_vocab_size}"
s += f', num_embeddings_padded={self.num_embeddings_padded}'
s += f', tp_size={self.tp_size}'
return s
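# A minimal toy sketch (illustrative only, not used by the library) of the
# masking done in forward() under tensor parallelism: out-of-shard ids are
# remapped to 0, looked up locally, zeroed out, and the cross-rank all-reduce
# (not shown) sums the per-rank partial results. The added-vocab handling of
# get_masked_input_and_mask is omitted for brevity; all values are assumed.
def _toy_masked_embedding_lookup() -> torch.Tensor:
    vocab_start, vocab_end = 4, 8                    # shard owned by rank 1
    local_weight = torch.randn(vocab_end - vocab_start, 16)
    input_ = torch.tensor([1, 5, 7])                 # ids in and out of shard
    mask = (input_ < vocab_start) | (input_ >= vocab_end)
    masked_input = (input_ - vocab_start).masked_fill_(mask, 0)
    out = local_weight[masked_input]                 # local lookup
    out.masked_fill_(mask.unsqueeze(-1), 0)          # zero out-of-shard rows
    # tensor_model_parallel_all_reduce(out) would yield the full embeddings
    return out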
# SPDX-License-Identifier: Apache-2.0
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/logger.py
"""Logging configuration for fastvideo.v1."""
import datetime
import json
import logging
import os
import sys
from functools import lru_cache, partial
from logging import Logger
from logging.config import dictConfig
from os import path
from types import MethodType
from typing import Any, Optional, cast
import fastvideo.v1.envs as envs
FASTVIDEO_CONFIGURE_LOGGING = envs.FASTVIDEO_CONFIGURE_LOGGING
FASTVIDEO_LOGGING_CONFIG_PATH = envs.FASTVIDEO_LOGGING_CONFIG_PATH
FASTVIDEO_LOGGING_LEVEL = envs.FASTVIDEO_LOGGING_LEVEL
FASTVIDEO_LOGGING_PREFIX = envs.FASTVIDEO_LOGGING_PREFIX
RED = '\033[91m'
GREEN = '\033[92m'
RESET = '\033[0;0m'
_warned_local_main_process = False
_warned_main_process = False
_FORMAT = (f"{FASTVIDEO_LOGGING_PREFIX}%(levelname)s %(asctime)s "
"[%(filename)s:%(lineno)d] %(message)s")
_DATE_FORMAT = "%m-%d %H:%M:%S"
DEFAULT_LOGGING_CONFIG = {
"formatters": {
"fastvideo": {
"class": "fastvideo.v1.logging_utils.NewLineFormatter",
"datefmt": _DATE_FORMAT,
"format": _FORMAT,
},
},
"handlers": {
"fastvideo": {
"class": "logging.StreamHandler",
"formatter": "fastvideo",
"level": FASTVIDEO_LOGGING_LEVEL,
"stream": "ext://sys.stdout",
},
},
"loggers": {
"fastvideo": {
"handlers": ["fastvideo"],
"level": "DEBUG",
"propagate": False,
},
},
"root": {
"handlers": ["fastvideo"],
"level": "DEBUG",
},
"version": 1,
"disable_existing_loggers": False
}
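# Illustrative sketch (assumed path and values, not part of this module): a
# custom config pointed to by FASTVIDEO_LOGGING_CONFIG_PATH must be a JSON
# dict following the same dictConfig schema, e.g. generated like this:
#
#     import json
#     custom = {
#         "version": 1,
#         "disable_existing_loggers": False,
#         "handlers": {
#             "console": {
#                 "class": "logging.StreamHandler",
#                 "level": "WARNING",
#                 "stream": "ext://sys.stderr",
#             },
#         },
#         "loggers": {
#             "fastvideo": {
#                 "handlers": ["console"],
#                 "level": "WARNING",
#                 "propagate": False,
#             },
#         },
#     }
#     with open("/tmp/fastvideo_logging.json", "w") as f:
#         json.dump(custom, f)
#
# and exported before launch:
#     FASTVIDEO_LOGGING_CONFIG_PATH=/tmp/fastvideo_logging.json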
@lru_cache
def _print_info_once(logger: Logger, msg: str) -> None:
# Set the stacklevel to 2 to print the original caller's line info
logger.info(msg, stacklevel=2)
@lru_cache
def _print_warning_once(logger: Logger, msg: str) -> None:
# Set the stacklevel to 2 to print the original caller's line info
logger.warning(msg, stacklevel=2)
# TODO(will): add env variable to control this process-aware logging behavior
def _info(logger: Logger,
msg: object,
*args: Any,
main_process_only: bool = False,
local_main_process_only: bool = True,
**kwargs: Any) -> None:
"""Process-aware INFO level logging function.
This function controls logging behavior based on the process rank, allowing for
selective logging from specific processes in a distributed environment.
Args:
logger: The logger instance to use for logging
msg: The message format string to log
*args: Format string arguments
main_process_only: If True, only log if this is the global main process (RANK=0)
local_main_process_only: If True, only log if this is the local main process (LOCAL_RANK=0)
**kwargs: Additional keyword arguments to pass to the logger.log method
- stacklevel: Defaults to 2 to show the original caller's location
Note:
- When both main_process_only and local_main_process_only are True,
the message will be logged only if both conditions are met
- When both are False, the message will be logged from all processes
- By default, only logs from processes with LOCAL_RANK=0
"""
try:
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
except Exception:
local_rank = 0
rank = 0
is_main_process = rank == 0
is_local_main_process = local_rank == 0
if (main_process_only and is_main_process) or (local_main_process_only
and is_local_main_process):
logger.log(logging.INFO, msg, *args, **kwargs)
global _warned_local_main_process, _warned_main_process
if not _warned_local_main_process and local_main_process_only:
logger.warning(
            '%s local_main_process_only is set to True, logging only from the local main process.%s',
GREEN,
RESET,
)
_warned_local_main_process = True
if not _warned_main_process and main_process_only:
logger.warning(
            '%s main_process_only is set to True, logging only from the main process.%s',
GREEN,
RESET,
)
_warned_main_process = True
if not main_process_only and not local_main_process_only:
logger.log(logging.INFO, msg, *args, **kwargs)
class _FastvideoLogger(Logger):
"""
Note:
This class is just to provide type information.
We actually patch the methods directly on the :class:`logging.Logger`
instance to avoid conflicting with other libraries such as
`intel_extension_for_pytorch.utils._logger`.
"""
def info_once(self, msg: str) -> None:
"""
As :meth:`info`, but subsequent calls with the same message
are silently dropped.
"""
_print_info_once(self, msg)
def warning_once(self, msg: str) -> None:
"""
As :meth:`warning`, but subsequent calls with the same message
are silently dropped.
"""
_print_warning_once(self, msg)
def info( # type: ignore[override]
self,
msg: object,
*args: Any,
main_process_only: bool = False,
local_main_process_only: bool = True,
**kwargs: Any) -> None:
_info(self,
msg,
*args,
main_process_only=main_process_only,
local_main_process_only=local_main_process_only,
**kwargs)
def _configure_fastvideo_root_logger() -> None:
logging_config = dict[str, Any]()
if not FASTVIDEO_CONFIGURE_LOGGING and FASTVIDEO_LOGGING_CONFIG_PATH:
raise RuntimeError(
"FASTVIDEO_CONFIGURE_LOGGING evaluated to false, but "
"FASTVIDEO_LOGGING_CONFIG_PATH was given. FASTVIDEO_LOGGING_CONFIG_PATH "
"implies FASTVIDEO_CONFIGURE_LOGGING. Please enable "
"FASTVIDEO_CONFIGURE_LOGGING or unset FASTVIDEO_LOGGING_CONFIG_PATH."
)
if FASTVIDEO_CONFIGURE_LOGGING:
logging_config = DEFAULT_LOGGING_CONFIG
if FASTVIDEO_LOGGING_CONFIG_PATH:
if not path.exists(FASTVIDEO_LOGGING_CONFIG_PATH):
            raise RuntimeError(
                "Could not load logging config. File does not exist: "
                f"{FASTVIDEO_LOGGING_CONFIG_PATH}")
with open(FASTVIDEO_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
custom_config = json.loads(file.read())
if not isinstance(custom_config, dict):
            raise ValueError(
                "Invalid logging config. Expected dict, got "
                f"{type(custom_config).__name__}.")
logging_config = custom_config
for formatter in logging_config.get("formatters", {}).values():
# This provides backwards compatibility after #10134.
if formatter.get("class") == "fastvideo.v1.logging.NewLineFormatter":
formatter["class"] = "fastvideo.v1.logging_utils.NewLineFormatter"
if logging_config:
dictConfig(logging_config)
def init_logger(name: str) -> _FastvideoLogger:
"""The main purpose of this function is to ensure that loggers are
retrieved in such a way that we can be sure the root fastvideo logger has
already been configured."""
logger = logging.getLogger(name)
methods_to_patch = {
"info_once": _print_info_once,
"warning_once": _print_warning_once,
"info": _info,
}
for method_name, method in methods_to_patch.items():
setattr(logger, method_name,
MethodType(method, logger)) # type: ignore[arg-type]
return cast(_FastvideoLogger, logger)
# The root logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by Python's module import lock.
_configure_fastvideo_root_logger()
logger = init_logger(__name__)
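# Illustrative usage (assumes this module is importable as
# fastvideo.v1.logger and that RANK/LOCAL_RANK are set by the launcher,
# e.g. torchrun):
#
#     from fastvideo.v1.logger import init_logger
#     logger = init_logger(__name__)
#     logger.info("shown only on LOCAL_RANK == 0 (the default)")
#     logger.info("shown on every rank", local_main_process_only=False)
#     logger.info("shown only on global RANK == 0",
#                 main_process_only=True, local_main_process_only=False)
#     logger.info_once("repeated identical messages are logged once")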
def _trace_calls(log_path, root_dir, frame, event, arg=None):
if event in ['call', 'return']:
# Extract the filename, line number, function name, and the code object
filename = frame.f_code.co_filename
lineno = frame.f_lineno
func_name = frame.f_code.co_name
if not filename.startswith(root_dir):
# only log the functions in the fastvideo root_dir
return
# Log every function call or return
try:
last_frame = frame.f_back
if last_frame is not None:
last_filename = last_frame.f_code.co_filename
last_lineno = last_frame.f_lineno
last_func_name = last_frame.f_code.co_name
else:
# initial frame
last_filename = ""
last_lineno = 0
last_func_name = ""
with open(log_path, 'a') as f:
ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
if event == 'call':
f.write(f"{ts} Call to"
f" {func_name} in {filename}:{lineno}"
f" from {last_func_name} in {last_filename}:"
f"{last_lineno}\n")
else:
f.write(f"{ts} Return from"
f" {func_name} in {filename}:{lineno}"
f" to {last_func_name} in {last_filename}:"
f"{last_lineno}\n")
except NameError:
# modules are deleted during shutdown
pass
return partial(_trace_calls, log_path, root_dir)
def enable_trace_function_call(log_file_path: str,
root_dir: Optional[str] = None):
"""
Enable tracing of every function call in code under `root_dir`.
This is useful for debugging hangs or crashes.
`log_file_path` is the path to the log file.
`root_dir` is the root directory of the code to trace. If None, it is the
fastvideo root directory.
    Note that this call is thread-level: only threads that call this function
    will have tracing enabled; other threads are not affected.
"""
logger.warning(
"FASTVIDEO_TRACE_FUNCTION is enabled. It will record every"
" function executed by Python. This will slow down the code. It "
"is suggested to be used for debugging hang or crashes only.")
logger.info("Trace frame log is saved to %s", log_file_path)
if root_dir is None:
# by default, this is the fastvideo root directory
root_dir = os.path.dirname(os.path.dirname(__file__))
sys.settrace(partial(_trace_calls, log_file_path, root_dir))
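# Illustrative usage (a sketch; the actual call site that honors
# FASTVIDEO_TRACE_FUNCTION lives elsewhere in the codebase and may differ):
#
#     import os
#     if int(os.getenv("FASTVIDEO_TRACE_FUNCTION", "0")):
#         enable_trace_function_call("/tmp/fastvideo_trace.log")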
# SPDX-License-Identifier: Apache-2.0
from fastvideo.v1.logging_utils.formatter import NewLineFormatter
__all__ = [
"NewLineFormatter",
]