import os
import torch
import torch.distributed as dist
from packaging import version
from dataclasses import dataclass, fields
from typing import Union, Optional, List
from xfuser.logger import init_logger
import xfuser.envs as envs
# from xfuser.envs import CUDA_VERSION, TORCH_VERSION, PACKAGES_CHECKER
from xfuser.envs import TORCH_VERSION, PACKAGES_CHECKER
logger = init_logger(__name__)
env_info = PACKAGES_CHECKER.get_packages_info()
HAS_LONG_CTX_ATTN = env_info["has_long_ctx_attn"]
HAS_FLASH_ATTN = env_info["has_flash_attn"]
def check_packages():
import diffusers
if not version.parse(diffusers.__version__) > version.parse("0.30.2"):
raise RuntimeError(
"This project requires diffusers version > 0.30.2. Currently, you can not install a correct version of diffusers by pip install."
"Please install it from source code!"
)
def check_env():
# https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/cudagraph.html
#if CUDA_VERSION < version.parse("11.3"):
# raise RuntimeError("NCCL CUDA Graph support requires CUDA 11.3 or above")
if TORCH_VERSION < version.parse("2.2.0"):
# https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/
raise RuntimeError(
"CUDAGraph with NCCL support requires PyTorch 2.2.0 or above. "
"If it is not released yet, please install nightly built PyTorch "
"with `pip3 install --pre torch torchvision torchaudio --index-url "
"https://download.pytorch.org/whl/nightly/cu121`"
)
@dataclass
class ModelConfig:
model: str
download_dir: Optional[str] = None
trust_remote_code: bool = False
@dataclass
class RuntimeConfig:
warmup_steps: int = 1
dtype: torch.dtype = torch.float16
use_cuda_graph: bool = False
use_parallel_vae: bool = False
use_profiler: bool = False
use_torch_compile: bool = False
use_onediff: bool = False
use_fp8_t5_encoder: bool = False
def __post_init__(self):
check_packages()
if self.use_cuda_graph:
check_env()
@dataclass
class FastAttnConfig:
use_fast_attn: bool = False
n_step: int = 20
n_calib: int = 8
threshold: float = 0.5
window_size: int = 64
coco_path: Optional[str] = None
use_cache: bool = False
def __post_init__(self):
assert self.n_calib > 0, "n_calib must be greater than 0"
assert self.threshold > 0.0, "threshold must be greater than 0"
@dataclass
class DataParallelConfig:
dp_degree: int = 1
use_cfg_parallel: bool = False
world_size: int = 1
def __post_init__(self):
assert self.dp_degree >= 1, "dp_degree must be greater than or equal to 1"
# set classifier_free_guidance_degree parallel for split batch
if self.use_cfg_parallel:
self.cfg_degree = 2
else:
self.cfg_degree = 1
assert self.dp_degree * self.cfg_degree <= self.world_size, (
"dp_degree * cfg_degree must be less than or equal to "
"world_size because of classifier free guidance"
)
assert (
self.world_size % (self.dp_degree * self.cfg_degree) == 0
), "world_size must be divisible by dp_degree * cfg_degree"
@dataclass
class SequenceParallelConfig:
ulysses_degree: Optional[int] = None
ring_degree: Optional[int] = None
world_size: int = 1
def __post_init__(self):
if self.ulysses_degree is None:
self.ulysses_degree = 1
logger.info(
f"Ulysses degree not set, " f"using default value {self.ulysses_degree}"
)
if self.ring_degree is None:
self.ring_degree = 1
logger.info(
f"Ring degree not set, " f"using default value {self.ring_degree}"
)
self.sp_degree = self.ulysses_degree * self.ring_degree
if not HAS_LONG_CTX_ATTN and self.sp_degree > 1:
raise ImportError(
f"Sequence Parallel kit 'yunchang' not found but "
f"sp_degree is {self.sp_degree}, please set it "
f"to 1 or install 'yunchang' to use it"
)
@dataclass
class TensorParallelConfig:
tp_degree: int = 1
split_scheme: Optional[str] = "row"
world_size: int = 1
def __post_init__(self):
assert self.tp_degree >= 1, "tp_degree must be greater than or equal to 1"
assert (
self.tp_degree <= self.world_size
), "tp_degree must be less than or equal to world_size"
@dataclass
class PipeFusionParallelConfig:
pp_degree: int = 1
num_pipeline_patch: Optional[int] = None
attn_layer_num_for_pp: Optional[List[int]] = None
world_size: int = 1
def __post_init__(self):
assert (
self.pp_degree is not None and self.pp_degree >= 1
), "pipefusion_degree must be set and greater than 1 to use pipefusion"
assert (
self.pp_degree <= self.world_size
), "pipefusion_degree must be less than or equal to world_size"
if self.num_pipeline_patch is None:
self.num_pipeline_patch = self.pp_degree
logger.info(
f"Pipeline patch number not set, "
f"using default value {self.pp_degree}"
)
if self.attn_layer_num_for_pp is not None:
logger.info(
f"attn_layer_num_for_pp set, splitting attention layers"
f"to {self.attn_layer_num_for_pp}"
)
assert len(self.attn_layer_num_for_pp) == self.pp_degree, (
"attn_layer_num_for_pp must have the same "
"length as pp_degree if not None"
)
if self.pp_degree == 1 and self.num_pipeline_patch > 1:
logger.warning(
f"Pipefusion degree is 1, pipeline will not be used,"
f"num_pipeline_patch will be ignored"
)
self.num_pipeline_patch = 1
@dataclass
class ParallelConfig:
dp_config: DataParallelConfig
sp_config: SequenceParallelConfig
pp_config: PipeFusionParallelConfig
tp_config: TensorParallelConfig
world_size: int = 1 # FIXME: remove this
worker_cls: str = "xfuser.ray.worker.worker.Worker"
def __post_init__(self):
assert self.tp_config is not None, "tp_config must be set"
assert self.dp_config is not None, "dp_config must be set"
assert self.sp_config is not None, "sp_config must be set"
assert self.pp_config is not None, "pp_config must be set"
parallel_world_size = (
self.dp_config.dp_degree
* self.dp_config.cfg_degree
* self.sp_config.sp_degree
* self.tp_config.tp_degree
* self.pp_config.pp_degree
)
world_size = self.world_size
assert parallel_world_size == world_size, (
f"parallel_world_size {parallel_world_size} "
f"must be equal to world_size {self.world_size}"
)
assert (
world_size % (self.dp_config.dp_degree * self.dp_config.cfg_degree) == 0
), "world_size must be divisible by dp_degree * cfg_degree"
assert (
world_size % self.pp_config.pp_degree == 0
), "world_size must be divisible by pp_degree"
assert (
world_size % self.sp_config.sp_degree == 0
), "world_size must be divisible by sp_degree"
assert (
world_size % self.tp_config.tp_degree == 0
), "world_size must be divisible by tp_degree"
self.dp_degree = self.dp_config.dp_degree
self.cfg_degree = self.dp_config.cfg_degree
self.sp_degree = self.sp_config.sp_degree
self.pp_degree = self.pp_config.pp_degree
self.tp_degree = self.tp_config.tp_degree
self.ulysses_degree = self.sp_config.ulysses_degree
self.ring_degree = self.sp_config.ring_degree
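# Illustrative example (not from the original source): with world_size=8, one valid
# configuration is dp_degree=1, use_cfg_parallel=True (cfg_degree=2), ulysses_degree=2,
# ring_degree=1, tp_degree=1, pp_degree=2, since 1 * 2 * (2 * 1) * 1 * 2 == 8 and every
# divisibility assert above is satisfied.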
@dataclass(frozen=True)
class EngineConfig:
model_config: ModelConfig
runtime_config: RuntimeConfig
parallel_config: ParallelConfig
fast_attn_config: FastAttnConfig
def __post_init__(self):
world_size = self.parallel_config.world_size
if self.fast_attn_config.use_fast_attn:
assert self.parallel_config.dp_degree == world_size, f"world_size must be equal to dp_degree when using DiTFastAttn"
def to_dict(self):
"""Return the configs as a dictionary, for use in **kwargs."""
return dict((field.name, getattr(self, field.name)) for field in fields(self))
@dataclass
class InputConfig:
height: int = 1024
width: int = 1024
num_frames: int = 49
use_resolution_binning: bool = True
batch_size: Optional[int] = None
img_file_path: Optional[str] = None
prompt: Union[str, List[str]] = ""
negative_prompt: Union[str, List[str]] = ""
num_inference_steps: int = 20
max_sequence_length: int = 256
seed: int = 42
output_type: str = "pil"
def __post_init__(self):
if isinstance(self.prompt, list):
assert (
len(self.prompt) == len(self.negative_prompt)
or len(self.negative_prompt) == 0
), "prompts and negative_prompts must have the same quantities"
self.batch_size = self.batch_size or len(self.prompt)
else:
self.batch_size = self.batch_size or 1
assert self.output_type in [
"pil",
"latent",
"pt",
], "output_pil must be either 'pil' or 'latent'"
import os
import torch
import diffusers
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from packaging import version
from xfuser.logger import init_logger
logger = init_logger(__name__)
if TYPE_CHECKING:
MASTER_ADDR: str = ""
MASTER_PORT: Optional[int] = None
CUDA_HOME: Optional[str] = None
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
XDIT_LOGGING_LEVEL: str = "INFO"
CUDA_VERSION: version.Version
TORCH_VERSION: version.Version
environment_variables: Dict[str, Callable[[], Any]] = {
# ================== Runtime Env Vars ==================
# used in distributed environment to determine the master address
"MASTER_ADDR": lambda: os.getenv("MASTER_ADDR", ""),
# used in distributed environment to manually set the communication port
"MASTER_PORT": lambda: (
int(os.getenv("MASTER_PORT", "0")) if "MASTER_PORT" in os.environ else None
),
# path to cudatoolkit home directory, under which should be bin, include,
# and lib directories.
"CUDA_HOME": lambda: os.environ.get("CUDA_HOME", None),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),
# used to control the visible devices in the distributed setting
"CUDA_VISIBLE_DEVICES": lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None),
# this is used for configuring the default logging level
"XDIT_LOGGING_LEVEL": lambda: os.getenv("XDIT_LOGGING_LEVEL", "INFO"),
}
variables: Dict[str, Callable[[], Any]] = {
# ================== Other Vars ==================
# used in version checking
# "CUDA_VERSION": lambda: version.parse(torch.version.cuda),
"CUDA_VERSION": "gfx928",
"TORCH_VERSION": lambda: version.parse(
version.parse(torch.__version__).base_version
),
}
class PackagesEnvChecker:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super(PackagesEnvChecker, cls).__new__(cls)
cls._instance.initialize()
return cls._instance
def initialize(self):
self.packages_info = {
"has_flash_attn": self.check_flash_attn(),
"has_long_ctx_attn": self.check_long_ctx_attn(),
"diffusers_version": self.check_diffusers_version(),
}
def check_flash_attn(self):
try:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpu_name = torch.cuda.get_device_name(device)
if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name:
return False
else:
from flash_attn import flash_attn_func
from flash_attn import __version__
if __version__ < "2.6.0":
raise ImportError(f"install flash_attn >= 2.6.0")
return True
except ImportError:
logger.warning(
f'Flash Attention library "flash_attn" not found, '
f"using pytorch attention implementation"
)
return False
def check_long_ctx_attn(self):
try:
from yunchang import (
set_seq_parallel_pg,
ring_flash_attn_func,
UlyssesAttention,
LongContextAttention,
LongContextAttentionQKVPacked,
)
return True
except ImportError:
logger.warning(
f'Ring Flash Attention library "yunchang" not found, '
f"using pytorch attention implementation"
)
return False
def check_diffusers_version(self):
if version.parse(
version.parse(diffusers.__version__).base_version
) < version.parse("0.30.0"):
raise RuntimeError(
f"Diffusers version: {version.parse(version.parse(diffusers.__version__).base_version)} is not supported,"
f"please upgrade to version > 0.30.0"
)
return version.parse(version.parse(diffusers.__version__).base_version)
def get_packages_info(self):
return self.packages_info
PACKAGES_CHECKER = PackagesEnvChecker()
def __getattr__(name):
# lazy evaluation of environment variables
if name in environment_variables:
return environment_variables[name]()
if name in variables:
return variables[name]()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__():
return list(environment_variables.keys())
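# Usage sketch (illustrative, assumed): module attributes are resolved lazily through
# __getattr__ above, so e.g.
#   >>> import xfuser.envs as envs
#   >>> envs.LOCAL_RANK        # re-reads os.environ["LOCAL_RANK"] on every access
#   0
#   >>> envs.TORCH_VERSION     # parsed from torch.__version__ via `variables`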
parallel=2 # or parallel=8
url='127.0.0.1'
model_dir=/home/luopl1/stepfun-ai/stepvideo-t2v
tp_degree=2
ulysses_degree=1
# make sure tp_degree x ulysses_degree = parallel
torchrun --nproc_per_node $parallel run_parallel.py --model_dir $model_dir --vae_url $url --caption_url $url --ulysses_degree $ulysses_degree --tensor_parallel_degree $tp_degree --prompt "一名宇航员在月球上发现一块石碑,上面印有“stepfun”字样,闪闪发光" --height 200 --width 200 --infer_steps 20 --cfg_scale 9.0 --time_shift 13.0
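# Illustrative alternatives (assumption, not part of the original script): with parallel=8
# one could use tp_degree=4 and ulysses_degree=2, or tp_degree=2 and ulysses_degree=4,
# as long as tp_degree * ulysses_degree == parallel.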
from stepvideo.diffusion.video_pipeline import StepVideoPipeline
import torch.distributed as dist
import torch
from stepvideo.config import parse_args
from stepvideo.parallel import initialize_parall_group, get_parallel_group
from stepvideo.utils import setup_seed
from xfuser.model_executor.models.customized.step_video_t2v.tp_applicator import TensorParallelApplicator
from xfuser.core.distributed.parallel_state import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
if __name__ == "__main__":
args = parse_args()
initialize_parall_group(ring_degree=args.ring_degree, ulysses_degree=args.ulysses_degree, tensor_parallel_degree=args.tensor_parallel_degree)
local_rank = get_parallel_group().local_rank
device = torch.device(f"cuda:{local_rank}")
setup_seed(args.seed)
pipeline = StepVideoPipeline.from_pretrained(args.model_dir).to(dtype=torch.bfloat16, device="cpu")
if args.tensor_parallel_degree > 1:
tp_applicator = TensorParallelApplicator(get_tensor_model_parallel_world_size(), get_tensor_model_parallel_rank())
tp_applicator.apply_to_model(pipeline.transformer)
pipeline.transformer = pipeline.transformer.to(device)
pipeline.setup_api(
vae_url = args.vae_url,
caption_url = args.caption_url,
)
prompt = args.prompt
videos = pipeline(
prompt=prompt,
num_frames=args.num_frames,
height=args.height,
width=args.width,
num_inference_steps = args.infer_steps,
guidance_scale=args.cfg_scale,
time_shift=args.time_shift,
pos_magic=args.pos_magic,
neg_magic=args.neg_magic,
output_file_name=prompt[:50]
)
dist.destroy_process_group()
from setuptools import find_packages, setup
import subprocess
def get_cuda_version():
try:
nvcc_version = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
version_line = [line for line in nvcc_version.split("\n") if "release" in line][
0
]
cuda_version = version_line.split(" ")[-2].replace(",", "")
return "cu" + cuda_version.replace(".", "")
except Exception as e:
return "no_cuda"
if __name__ == "__main__":
with open("README.md", "r") as f:
long_description = f.read()
fp = open("stepvideo/__version__.py", "r").read()
version = eval(fp.strip().split()[-1])
setup(
name="stepvideo",
author="Step-Video Team",
packages=find_packages(),
install_requires=[
"torchvision==0.18.1+das.opt1.dtk24042",
"torch==2.3.0+das.opt2.dtk24043",
"accelerate==1.0.0",
"transformers==4.39.1",
"diffusers==0.31.0",
"sentencepiece==0.1.99",
"imageio==2.37.0",
"optimus==2.1",
"numpy",
"einops",
"aiohttp",
"asyncio",
"flask",
"flask_restful",
"ffmpeg-python",
"requests",
"xfuser==0.4.2rc2"
],
url="",
description="A 30B DiT based text to video and image generation model",
long_description=long_description,
long_description_content_type="text/markdown",
version=version,
classifiers=[
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
],
include_package_data=True,
python_requires=">=3.10",
)
import os
os.environ["NCCL_DEBUG"] = "ERROR"
from .diffusion.scheduler import *
from .diffusion.video_pipeline import *
from .modules.model import *
__version__ = "0.1.0"
import argparse
def parse_args(namespace=None):
parser = argparse.ArgumentParser(description="StepVideo inference script")
parser = add_extra_models_args(parser)
parser = add_denoise_schedule_args(parser)
parser = add_inference_args(parser)
parser = add_parallel_args(parser)
args = parser.parse_args(namespace=namespace)
return args
def add_extra_models_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="Extra models args, including vae, text encoders and tokenizers)"
)
group.add_argument(
"--vae_url",
type=str,
default='127.0.0.1',
help="vae url.",
)
group.add_argument(
"--caption_url",
type=str,
default='127.0.0.1',
help="caption url.",
)
return parser
def add_denoise_schedule_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Denoise schedule args")
# Flow Matching
group.add_argument(
"--time_shift",
type=float,
default=7.0,
help="Shift factor for flow matching schedulers.",
)
group.add_argument(
"--flow_reverse",
action="store_true",
help="If reverse, learning/sampling from t=1 -> t=0.",
)
group.add_argument(
"--flow_solver",
type=str,
default="euler",
help="Solver for flow matching.",
)
return parser
def add_inference_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Inference args")
# ======================== Model loads ========================
group.add_argument(
"--model_dir",
type=str,
default="./ckpts",
help="Root path of all the models, including t2v models and extra models.",
)
group.add_argument(
"--model_resolution",
type=str,
default="540p",
choices=["540p"],
help="Root path of all the models, including t2v models and extra models.",
)
group.add_argument(
"--use-cpu-offload",
action="store_true",
help="Use CPU offload for the model load.",
)
# ======================== Inference general setting ========================
group.add_argument(
"--batch_size",
type=int,
default=1,
help="Batch size for inference and evaluation.",
)
group.add_argument(
"--infer_steps",
type=int,
default=50,
help="Number of denoising steps for inference.",
)
group.add_argument(
"--save_path",
type=str,
default="./results",
help="Path to save the generated samples.",
)
group.add_argument(
"--name_suffix",
type=str,
default="",
help="Suffix for the names of saved samples.",
)
group.add_argument(
"--num_videos",
type=int,
default=1,
help="Number of videos to generate for each prompt.",
)
# ---sample size---
group.add_argument(
"--num_frames",
type=int,
default=204,
help="How many frames to sample from a video. ",
)
group.add_argument(
"--height",
type=int,
default=544,
help="The height of video sample",
)
group.add_argument(
"--width",
type=int,
default=992,
help="The width of video sample",
)
# --- prompt ---
group.add_argument(
"--prompt",
type=str,
default=None,
help="Prompt for sampling during evaluation.",
)
group.add_argument("--seed", type=int, default=1234, help="Seed for evaluation.")
# Classifier-Free Guidance
group.add_argument(
"--pos_magic", type=str, default="超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。", help="Positive magic prompt for sampling."
)
group.add_argument(
"--neg_magic", type=str, default="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。", help="Negative magic prompt for sampling."
)
group.add_argument(
"--cfg_scale", type=float, default=9.0, help="Classifier free guidance scale."
)
return parser
def add_parallel_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Parallel args")
# ======================== Model loads ========================
group.add_argument(
"--ulysses_degree",
type=int,
default=8,
help="Ulysses degree.",
)
group.add_argument(
"--ring_degree",
type=int,
default=1,
help="Ulysses degree.",
)
group.add_argument(
"--tensor_parallel_degree",
type=int,
default=1,
help="Tensor parallel degree.",
)
return parser
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.schedulers.scheduling_utils import SchedulerMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Euler scheduler.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
solver (`str`, defaults to `"euler"`):
The solver used for the flow-matching update; only `"euler"` is supported.
reverse (`bool`, defaults to `False`):
Whether to reverse the sigma/timestep schedule.
"""
_compatibles = []
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
reverse: bool = False,
solver: str = "euler",
device: Union[str, torch.device] = None,
):
sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
if not reverse:
sigmas = sigmas.flip(0)
self.sigmas = sigmas
# the value fed to model
self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
self._step_index = None
self._begin_index = None
self.device = device
self.supported_solver = ["euler"]
if solver not in self.supported_solver:
raise ValueError(
f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
)
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def set_timesteps(
self,
num_inference_steps: int,
time_shift: float = 13.0,
device: Union[str, torch.device] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
time_shift (`float`, *optional*, defaults to `13.0`):
Shift factor applied to the sigma schedule (see `sd3_time_shift`).
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
"""
device = device or self.device
self.num_inference_steps = num_inference_steps
sigmas = torch.linspace(1, 0, num_inference_steps + 1, device=device)
sigmas = self.sd3_time_shift(sigmas, time_shift)
if not self.config.reverse:
sigmas = 1 - sigmas
self.sigmas = sigmas
self.timesteps = sigmas[:-1]
# Reset step index
self._step_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def scale_model_input(
self, sample: torch.Tensor, timestep: Optional[int] = None
) -> torch.Tensor:
return sample
def sd3_time_shift(self, t: torch.Tensor, time_shift: float = 13.0):
return (time_shift * t) / (1 + (time_shift - 1) * t)
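# Worked example (illustrative): sd3_time_shift rescales a sigma t in [0, 1] as
# time_shift * t / (1 + (time_shift - 1) * t). With the default time_shift=13.0 and
# t=0.5 this gives 13 * 0.5 / (1 + 12 * 0.5) = 6.5 / 7 ≈ 0.929, i.e. intermediate
# sigmas are pushed toward the high-noise end of the schedule.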
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
return_dict: bool = False,
) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or a plain tensor.
Returns:
[`FlowMatchDiscreteSchedulerOutput`] or `torch.FloatTensor`:
If `return_dict` is `True`, a [`FlowMatchDiscreteSchedulerOutput`] is returned; otherwise
the previous sample tensor is returned directly.
"""
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if self.step_index is None:
self._init_step_index(timestep)
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
if self.config.solver == "euler":
prev_sample = sample + model_output.to(torch.float32) * dt
else:
raise ValueError(
f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
)
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return prev_sample
return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps
# Copyright 2025 StepFun Inc. All Rights Reserved.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
import numpy as np
import pickle
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput
import asyncio
from stepvideo.modules.model import StepVideoModel
from stepvideo.diffusion.scheduler import FlowMatchDiscreteScheduler
from stepvideo.utils import VideoProcessor
def call_api_gen(url, api, port=8080):
url =f"http://{url}:{port}/{api}-api"
import aiohttp
async def _fn(samples, *args, **kwargs):
if api=='vae':
data = {
"samples": samples,
}
elif api == 'caption':
data = {
"prompts": samples,
}
else:
raise Exception(f"Not supported api: {api}...")
async with aiohttp.ClientSession() as sess:
data_bytes = pickle.dumps(data)
async with sess.get(url, data=data_bytes, timeout=12000) as response:
result = bytearray()
while not response.content.at_eof():
chunk = await response.content.read(1024)
result += chunk
response_data = pickle.loads(result)
return response_data
return _fn
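# Usage sketch (assumed, for illustration only):
#   caption = call_api_gen("127.0.0.1", "caption")
#   data = asyncio.run(caption(["a cat playing piano"]))
# sends a pickled {"prompts": [...]} payload to http://127.0.0.1:8080/caption-api and
# unpickles the server's reply.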
@dataclass
class StepVideoPipelineOutput(BaseOutput):
video: Union[torch.Tensor, np.ndarray]
class StepVideoPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using StepVideo.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
transformer ([`StepVideoModel`]):
Conditional Transformer to denoise the encoded image latents.
scheduler ([`FlowMatchDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae_url:
remote vae server's url.
caption_url:
remote caption (stepllm and clip) server's url.
"""
def __init__(
self,
transformer: StepVideoModel,
scheduler: FlowMatchDiscreteScheduler,
vae_url: str = '127.0.0.1',
caption_url: str = '127.0.0.1',
save_path: str = './results',
name_suffix: str = '',
):
super().__init__()
self.register_modules(
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 8
self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16
self.video_processor = VideoProcessor(save_path, name_suffix)
self.vae_url = vae_url
self.caption_url = caption_url
self.setup_api(self.vae_url, self.caption_url)
def setup_api(self, vae_url, caption_url):
self.vae_url = vae_url
self.caption_url = caption_url
self.caption = call_api_gen(caption_url, 'caption')
self.vae = call_api_gen(vae_url, 'vae')
return self
def encode_prompt(
self,
prompt: str,
neg_magic: str = '',
pos_magic: str = '',
):
device = self._execution_device
prompts = [prompt+pos_magic]
bs = len(prompts)
prompts += [neg_magic]*bs
data = asyncio.run(self.caption(prompts))
prompt_embeds, prompt_attention_mask, clip_embedding = data['y'].to(device), data['y_mask'].to(device), data['clip_embedding'].to(device)
return prompt_embeds, clip_embedding, prompt_attention_mask
def decode_vae(self, samples):
samples = asyncio.run(self.vae(samples.cpu()))
return samples
def check_inputs(self, num_frames, width, height):
num_frames = max(num_frames//17*17, 1)
width = max(width//16*16, 16)
height = max(height//16*16, 16)
return num_frames, width, height
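# Worked example (illustrative): requested sizes are snapped down to supported multiples,
# e.g. check_inputs(200, 1000, 541) -> (187, 992, 528), since num_frames is rounded to a
# multiple of 17 and width/height to multiples of 16.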
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int = 64,
height: int = 544,
width: int = 992,
num_frames: int = 204,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
num_frames, width, height = self.check_inputs(num_frames, width, height)
shape = (
batch_size,
max(num_frames//17*3, 1),
num_channels_latents,
int(height) // self.vae_scale_factor_spatial,
int(width) // self.vae_scale_factor_spatial,
) # b,f,c,h,w
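# Shape example (illustrative): with the defaults num_frames=204, height=544, width=992,
# batch_size=1 and num_channels_latents=64 this is
# (1, 204 // 17 * 3, 64, 544 // 16, 992 // 16) == (1, 36, 64, 34, 62).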
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if generator is None:
generator = torch.Generator(device=self._execution_device)
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
return latents
@torch.inference_mode()
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: int = 544,
width: int = 992,
num_frames: int = 204,
num_inference_steps: int = 50,
guidance_scale: float = 9.0,
time_shift: float = 13.0,
neg_magic: str = "",
pos_magic: str = "",
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "mp4",
output_file_name: Optional[str] = "",
return_dict: bool = True,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the video generation.
height (`int`, defaults to `544`):
The height in pixels of the generated video.
width (`int`, defaults to `992`):
The width in pixels of the generated video.
num_frames (`int`, defaults to `204`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `9.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
output_file_name(`str`, *optional*`):
The output mp4 file name.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`StepVideoPipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~StepVideoPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`StepVideoPipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is the generated video.
"""
# 1. Check inputs. Raise error if not correct
device = self._execution_device
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError("`prompt` must be provided as a `str` or a `List[str]`.")
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, prompt_embeds_2, prompt_attention_mask = self.encode_prompt(
prompt=prompt,
neg_magic=neg_magic,
pos_magic=pos_magic,
)
transformer_dtype = self.transformer.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
prompt_embeds_2 = prompt_embeds_2.to(transformer_dtype)
# 4. Prepare timesteps
self.scheduler.set_timesteps(
num_inference_steps=num_inference_steps,
time_shift=time_shift,
device=device
)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.bfloat16,
device,
generator,
latents,
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(self.scheduler.timesteps):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = latent_model_input.to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
encoder_hidden_states_2=prompt_embeds_2,
return_dict=False,
)
# perform guidance
if do_classifier_free_guidance:
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
model_output=noise_pred,
timestep=t,
sample=latents
)
progress_bar.update()
if not torch.distributed.is_initialized() or int(torch.distributed.get_rank())==0:
if not output_type == "latent":
video = self.decode_vae(latents)
video = self.video_processor.postprocess_video(video, output_file_name=output_file_name, output_type=output_type)
else:
video = latents
else:
# non-zero ranks skip decoding and return None
video = None
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video, )
return StepVideoPipelineOutput(video=video)
import torch
import torch.nn as nn
from einops import rearrange
try:
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
except ImportError:
xFuserLongContextAttention = None
class Attention(nn.Module):
def __init__(self):
super().__init__()
def attn_processor(self, attn_type):
if attn_type == 'torch':
return self.torch_attn_func
elif attn_type == 'parallel':
return self.parallel_attn_func
else:
raise Exception(f'Unsupported attention type: {attn_type}')
def torch_attn_func(
self,
q,
k,
v,
attn_mask=None,
causal=False,
drop_rate=0.0,
**kwargs
):
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
if attn_mask is not None and attn_mask.ndim == 3: ## no head
n_heads = q.shape[2]
attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
q, k, v = map(lambda x: rearrange(x, 'b s h d -> b h s d'), (q, k, v))
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
)
x = rearrange(x, 'b h s d -> b s h d')
return x
def parallel_attn_func(
self,
q,
k,
v,
causal=False,
**kwargs
):
assert xFuserLongContextAttention is not None, 'to use parallel attention, xFuserLongContextAttention must be importable from xfuser'
hybrid_seq_parallel_attn = xFuserLongContextAttention()
x = hybrid_seq_parallel_attn(
None, q,k,v, causal=causal
)
return x
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import torch
import torch.nn as nn
from typing import Optional
from einops import rearrange
from stepvideo.modules.rope import RoPE3D
from stepvideo.modules.attentions import Attention
from stepvideo.modules.normalization import RMSNorm
class SelfAttention(Attention):
def __init__(self, hidden_dim, head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type='torch'):
super().__init__()
self.head_dim = head_dim
self.n_heads = hidden_dim // head_dim
self.wqkv = nn.Linear(hidden_dim, hidden_dim*3, bias=bias)
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
self.with_rope = with_rope
self.with_qk_norm = with_qk_norm
if self.with_qk_norm:
self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
if self.with_rope:
self.rope_3d = RoPE3D(freq=1e4, F0=1.0, scaling_factor=1.0)
self.rope_ch_split = [64, 32, 32]
self.core_attention = self.attn_processor(attn_type=attn_type)
self.parallel = attn_type=='parallel'
def apply_rope3d(self, x, fhw_positions, rope_ch_split, parallel=True):
x = self.rope_3d(x, fhw_positions, rope_ch_split, parallel)
return x
def forward(
self,
x,
cu_seqlens=None,
max_seqlen=None,
rope_positions=None,
attn_mask=None
):
xqkv = self.wqkv(x)
xqkv = xqkv.view(*x.shape[:-1], self.n_heads, 3*self.head_dim)
xq, xk, xv = torch.split(xqkv, [self.head_dim]*3, dim=-1) ## seq_len, n, dim
if self.with_qk_norm:
xq = self.q_norm(xq)
xk = self.k_norm(xk)
if self.with_rope:
xq = self.apply_rope3d(xq, rope_positions, self.rope_ch_split, parallel=self.parallel)
xk = self.apply_rope3d(xk, rope_positions, self.rope_ch_split, parallel=self.parallel)
output = self.core_attention(
xq,
xk,
xv,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
attn_mask=attn_mask
)
output = rearrange(output, 'b s h d -> b s (h d)')
output = self.wo(output)
return output
class CrossAttention(Attention):
def __init__(self, hidden_dim, head_dim, bias=False, with_qk_norm=True, attn_type='torch'):
super().__init__()
self.head_dim = head_dim
self.n_heads = hidden_dim // head_dim
self.wq = nn.Linear(hidden_dim, hidden_dim, bias=bias)
self.wkv = nn.Linear(hidden_dim, hidden_dim*2, bias=bias)
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
self.with_qk_norm = with_qk_norm
if self.with_qk_norm:
self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
self.core_attention = self.attn_processor(attn_type=attn_type)
def forward(
self,
x: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attn_mask=None
):
xq = self.wq(x)
xq = xq.view(*xq.shape[:-1], self.n_heads, self.head_dim)
xkv = self.wkv(encoder_hidden_states)
xkv = xkv.view(*xkv.shape[:-1], self.n_heads, 2*self.head_dim)
xk, xv = torch.split(xkv, [self.head_dim]*2, dim=-1) ## seq_len, n, dim
if self.with_qk_norm:
xq = self.q_norm(xq)
xk = self.k_norm(xk)
output = self.core_attention(
xq,
xk,
xv,
attn_mask=attn_mask
)
output = rearrange(output, 'b s h d -> b s (h d)')
output = self.wo(output)
return output
class GELU(nn.Module):
r"""
GELU activation function with tanh approximation support with `approximate="tanh"`.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
self.approximate = approximate
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.gelu(gate, approximate=self.approximate)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.gelu(hidden_states)
return hidden_states
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
inner_dim: Optional[int] = None,
dim_out: Optional[int] = None,
mult: int = 4,
bias: bool = False,
):
super().__init__()
inner_dim = dim*mult if inner_dim is None else inner_dim
dim_out = dim if dim_out is None else dim_out
self.net = nn.ModuleList([
GELU(dim, inner_dim, approximate="tanh", bias=bias),
nn.Identity(),
nn.Linear(inner_dim, dim_out, bias=bias)
])
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
def modulate(x, scale, shift):
x = x * (1 + scale) + shift
return x
def gate(x, gate):
x = gate * x
return x
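# Illustrative note (not from the original source): modulate() and gate() implement the
# scale/shift/gate conditioning used by StepVideoTransformerBlock below. With x of shape
# (b, seq, dim) and per-sample scale/shift/gate tensors of shape (b, 1, dim):
#   h = modulate(norm(x), scale, shift)   # x * (1 + scale) + shift
#   x = gate(attn(h), g) + x              # gated residual update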
class StepVideoTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
attention_head_dim: int,
norm_eps: float = 1e-5,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = False,
attention_type: str = 'parallel'
):
super().__init__()
self.dim = dim
self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
self.attn1 = SelfAttention(dim, attention_head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type=attention_type)
self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
self.attn2 = CrossAttention(dim, attention_head_dim, bias=False, with_qk_norm=True, attn_type='torch')
self.ff = FeedForward(dim=dim, inner_dim=ff_inner_dim, dim_out=dim, bias=ff_bias)
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) /dim**0.5)
@torch.no_grad()
def forward(
self,
q: torch.Tensor,
kv: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
attn_mask = None,
rope_positions: list = None,
) -> torch.Tensor:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
torch.clone(chunk) for chunk in (self.scale_shift_table[None] + timestep.reshape(-1, 6, self.dim)).chunk(6, dim=1)
)
scale_shift_q = modulate(self.norm1(q), scale_msa, shift_msa)
attn_q = self.attn1(
scale_shift_q,
rope_positions=rope_positions
)
q = gate(attn_q, gate_msa) + q
attn_q = self.attn2(
q,
kv,
attn_mask
)
q = attn_q + q
scale_shift_q = modulate(self.norm2(q), scale_mlp, shift_mlp)
ff_output = self.ff(scale_shift_q)
q = gate(ff_output, gate_mlp) + q
return q
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
def __init__(
self,
patch_size=64,
in_channels=3,
embed_dim=768,
layer_norm=False,
flatten=True,
bias=True,
):
super().__init__()
self.flatten = flatten
self.layer_norm = layer_norm
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
def forward(self, latent):
latent = self.proj(latent).to(latent.dtype)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
return latent
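# Shape example (illustrative): with the StepVideoModel defaults patch_size=1,
# in_channels=64 and embed_dim = 48 * 128 = 6144, a latent batch of shape
# (b*f, 64, 34, 62) is projected and flattened to (b*f, 34 * 62, 6144) = (b*f, 2108, 6144).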
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
from typing import Any, Dict, Optional, Union
import torch
from torch import nn
import os
from einops import rearrange, repeat
from xfuser.model_executor.models.customized.step_video_t2v.blocks import StepVideoTransformerBlock, PatchEmbed
from stepvideo.utils import with_empty_init
from stepvideo.parallel import parallel_forward
from xfuser.model_executor.models.customized.step_video_t2v.normalization import (
PixArtAlphaTextProjection,
AdaLayerNormSingle
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
class StepVideoModel(ModelMixin, ConfigMixin):
_no_split_modules = ["StepVideoTransformerBlock", "PatchEmbed"]
@with_empty_init
@register_to_config
def __init__(
self,
num_attention_heads: int = 48,
attention_head_dim: int = 128,
in_channels: int = 64,
out_channels: Optional[int] = 64,
num_layers: int = 48,
dropout: float = 0.0,
patch_size: int = 1,
norm_type: str = "ada_norm_single",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
use_additional_conditions: Optional[bool] = False,
caption_channels: Union[int, list, tuple] = [6144, 1024],
attention_type: Optional[str] = "parallel",
):
super().__init__()
# Set some common variables used across the board.
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.out_channels = in_channels if out_channels is None else out_channels
self.use_additional_conditions = use_additional_conditions
self.pos_embed = PatchEmbed(
patch_size=patch_size,
in_channels=self.config.in_channels,
embed_dim=self.inner_dim,
)
self.transformer_blocks = nn.ModuleList(
[
StepVideoTransformerBlock(
dim=self.inner_dim,
attention_head_dim=self.config.attention_head_dim,
attention_type=attention_type
)
for _ in range(self.config.num_layers)
]
)
# 3. Output blocks.
self.norm_out = nn.LayerNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels)
self.patch_size = patch_size
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=self.use_additional_conditions
)
if isinstance(self.config.caption_channels, int):
caption_channel = self.config.caption_channels
else:
caption_channel, clip_channel = self.config.caption_channels
self.clip_projection = nn.Linear(clip_channel, self.inner_dim)
self.caption_norm = nn.LayerNorm(caption_channel, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.caption_projection = PixArtAlphaTextProjection(
in_features=caption_channel, hidden_size=self.inner_dim
)
self.parallel = attention_type=='parallel'
def patchfy(self, hidden_states):
hidden_states = rearrange(hidden_states, 'b f c h w -> (b f) c h w')
hidden_states = self.pos_embed(hidden_states)
return hidden_states
def prepare_attn_mask(self, encoder_attention_mask, encoder_hidden_states, q_seqlen):
kv_seqlens = encoder_attention_mask.sum(dim=1).int()
mask = torch.zeros([len(kv_seqlens), q_seqlen, max(kv_seqlens)], dtype=torch.bool, device=encoder_attention_mask.device)
encoder_hidden_states = encoder_hidden_states[:,: max(kv_seqlens)]
for i, kv_len in enumerate(kv_seqlens):
mask[i, :, :kv_len] = 1
return encoder_hidden_states, mask
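# Illustrative example (assumed values): with two captions whose valid token counts are
# kv_seqlens == [5, 3] and q_seqlen == 4, the mask has shape (2, 4, 5); sample 0 attends
# to all 5 tokens, sample 1 only to its first 3, and encoder_hidden_states is truncated
# to 5 tokens.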
@parallel_forward
def block_forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
rope_positions=None,
attn_mask=None,
parallel=True
):
for i, block in enumerate(self.transformer_blocks):
hidden_states = block(
hidden_states,
encoder_hidden_states,
timestep=timestep,
attn_mask=attn_mask,
rope_positions=rope_positions
)
return hidden_states
@torch.inference_mode()
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_hidden_states_2: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
fps: torch.Tensor=None,
return_dict: bool = True,
):
assert hidden_states.ndim == 5, "hidden_states shape should be (bsz, f, ch, h, w)"
bsz, frame, _, height, width = hidden_states.shape
height, width = height // self.patch_size, width // self.patch_size
hidden_states = self.patchfy(hidden_states)
len_frame = hidden_states.shape[1]
if self.use_additional_conditions:
added_cond_kwargs = {
"resolution": torch.tensor([(height, width)]*bsz, device=hidden_states.device, dtype=hidden_states.dtype),
"nframe": torch.tensor([frame]*bsz, device=hidden_states.device, dtype=hidden_states.dtype),
"fps": fps
}
else:
added_cond_kwargs = {}
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs=added_cond_kwargs
)
encoder_hidden_states = self.caption_projection(self.caption_norm(encoder_hidden_states))
if encoder_hidden_states_2 is not None and hasattr(self, 'clip_projection'):
clip_embedding = self.clip_projection(encoder_hidden_states_2)
encoder_hidden_states = torch.cat([clip_embedding, encoder_hidden_states], dim=1)
hidden_states = rearrange(hidden_states, '(b f) l d-> b (f l) d', b=bsz, f=frame, l=len_frame).contiguous()
encoder_hidden_states, attn_mask = self.prepare_attn_mask(encoder_attention_mask, encoder_hidden_states, q_seqlen=frame*len_frame)
hidden_states = self.block_forward(
hidden_states,
encoder_hidden_states,
timestep=timestep,
rope_positions=[frame, height, width],
attn_mask=attn_mask,
parallel=self.parallel
)
hidden_states = rearrange(hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame)
embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame).contiguous()
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
# unpatchify
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = rearrange(hidden_states, 'n h w p q c -> n c h p w q')
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
output = rearrange(output, '(b f) c h w -> b f c h w', f=frame)
if return_dict:
return {'x': output}
return output
from typing import Any, Dict, Optional, Union, Tuple
import torch
import torch.nn as nn
import math
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
ACTIVATION_FUNCTIONS = {
"swish": nn.SiLU(),
"silu": nn.SiLU(),
"mish": nn.Mish(),
"gelu": nn.GELU(),
"relu": nn.ReLU(),
}
def get_activation(act_fn: str) -> nn.Module:
"""Helper function to get activation function from string.
Args:
act_fn (str): Name of activation function.
Returns:
nn.Module: Activation function.
"""
act_fn = act_fn.lower()
if act_fn in ACTIVATION_FUNCTIONS:
return ACTIVATION_FUNCTIONS[act_fn]
else:
raise ValueError(f"Unsupported activation function: {act_fn}")
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
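# Minimal usage sketch (illustrative only): sinusoidal embeddings for a batch of
# timesteps. The output is (N, embedding_dim); with flip_sin_to_cos=False the
# first half of the channels holds sines and the second half cosines.
def _timestep_embedding_example():
    timesteps = torch.tensor([0, 250, 999])
    emb = get_timestep_embedding(timesteps, embedding_dim=128)
    assert emb.shape == (3, 128)
    return emb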
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
        out_dim: Optional[int] = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True
):
super().__init__()
linear_cls = nn.Linear
self.linear_1 = linear_cls(
in_channels,
time_embed_dim,
bias=sample_proj_bias,
)
if cond_proj_dim is not None:
self.cond_proj = linear_cls(
cond_proj_dim,
in_channels,
bias=False,
)
else:
self.cond_proj = None
self.act = get_activation(act_fn)
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = linear_cls(
time_embed_dim,
time_embed_dim_out,
bias=sample_proj_bias,
)
if post_act_fn is None:
self.post_act = None
else:
self.post_act = get_activation(post_act_fn)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
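# Minimal usage sketch (illustrative only): the usual pairing of Timesteps (fixed
# sinusoidal projection) followed by TimestepEmbedding (learned MLP), as wired up
# in PixArtAlphaCombinedTimestepSizeEmbeddings below. The widths are arbitrary
# example values.
def _timestep_mlp_example():
    time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
    time_mlp = TimestepEmbedding(in_channels=256, time_embed_dim=1152)
    t = torch.tensor([10, 500])
    emb = time_mlp(time_proj(t))   # (2, 1152)
    assert emb.shape == (2, 1152)
    return emb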
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
super().__init__()
self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.use_additional_conditions = use_additional_conditions
if self.use_additional_conditions:
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
self.nframe_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.fps_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, resolution=None, nframe=None, fps=None):
hidden_dtype = next(self.timestep_embedder.parameters()).dtype
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
if self.use_additional_conditions:
batch_size = timestep.shape[0]
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
nframe_emb = self.additional_condition_proj(nframe.flatten()).to(hidden_dtype)
nframe_emb = self.nframe_embedder(nframe_emb).reshape(batch_size, -1)
conditioning = timesteps_emb + resolution_emb + nframe_emb
if fps is not None:
fps_emb = self.additional_condition_proj(fps.flatten()).to(hidden_dtype)
fps_emb = self.fps_embedder(fps_emb).reshape(batch_size, -1)
conditioning = conditioning + fps_emb
else:
conditioning = timesteps_emb
return conditioning
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, time_step_rescale=1000):
super().__init__()
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 2, use_additional_conditions=use_additional_conditions
)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.time_step_rescale = time_step_rescale  # timesteps are usually in [0, 1]; rescale to [0, 1000] for numerical stability
    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        added_cond_kwargs = added_cond_kwargs or {}
        embedded_timestep = self.emb(timestep * self.time_step_rescale, **added_cond_kwargs)
        out = self.linear(self.silu(embedded_timestep))
        return out, embedded_timestep
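# Minimal usage sketch (illustrative only): adaLN-single emits 6 * embedding_dim
# values per sample, which transformer blocks conventionally chunk into
# (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp); those names
# and the width below are illustrative, not taken from this file.
def _ada_layer_norm_single_example():
    dim = 1152
    ada = AdaLayerNormSingle(dim, use_additional_conditions=False)
    t = torch.rand(2)                      # timesteps in [0, 1], rescaled internally
    out, embedded_t = ada(t, added_cond_kwargs={})
    assert out.shape == (2, 6 * dim) and embedded_t.shape == (2, dim)
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = out.chunk(6, dim=1)
    return shift_msa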
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size):
super().__init__()
self.linear_1 = nn.Linear(
in_features,
hidden_size,
bias=True,
)
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = nn.Linear(
hidden_size,
hidden_size,
bias=True,
)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
import torch
from stepvideo.parallel import get_sequence_parallel_world_size, get_sequence_parallel_rank
class RoPE1D:
def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
self.base = freq
self.F0 = F0
self.scaling_factor = scaling_factor
self.cache = {}
def get_cos_sin(self, D, seq_len, device, dtype):
if (D, seq_len, device, dtype) not in self.cache:
inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
freqs = torch.cat((freqs, freqs), dim=-1)
cos = freqs.cos() # (Seq, Dim)
sin = freqs.sin()
self.cache[D, seq_len, device, dtype] = (cos, sin)
return self.cache[D, seq_len, device, dtype]
@staticmethod
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rope1d(self, tokens, pos1d, cos, sin):
assert pos1d.ndim == 2
cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :]
sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :]
return (tokens * cos) + (self.rotate_half(tokens) * sin)
def __call__(self, tokens, positions):
"""
input:
* tokens: batch_size x ntokens x nheads x dim
* positions: batch_size x ntokens (t position of each token)
output:
* tokens after appplying RoPE2D (batch_size x ntokens x nheads x dim)
"""
D = tokens.size(3)
assert positions.ndim == 2 # Batch, Seq
cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
tokens = self.apply_rope1d(tokens, positions, cos, sin)
return tokens
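# Minimal usage sketch (illustrative only): 1D rotary embeddings applied to dummy
# attention tokens. Shapes follow the docstring: tokens are
# (batch, ntokens, nheads, dim) with an even dim, positions are (batch, ntokens).
def _rope1d_example():
    rope = RoPE1D()
    tokens = torch.randn(2, 8, 4, 64)
    positions = torch.arange(8).unsqueeze(0).expand(2, 8)
    out = rope(tokens, positions)
    assert out.shape == tokens.shape
    return out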
class RoPE3D(RoPE1D):
def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
super(RoPE3D, self).__init__(freq, F0, scaling_factor)
self.position_cache = {}
def get_mesh_3d(self, rope_positions, bsz):
f, h, w = rope_positions
if f"{f}-{h}-{w}" not in self.position_cache:
x = torch.arange(f, device='cpu')
y = torch.arange(h, device='cpu')
z = torch.arange(w, device='cpu')
self.position_cache[f"{f}-{h}-{w}"] = torch.cartesian_prod(x, y, z).view(1, f*h*w, 3).expand(bsz, -1, 3)
return self.position_cache[f"{f}-{h}-{w}"]
def __call__(self, tokens, rope_positions, ch_split, parallel=False):
"""
input:
* tokens: batch_size x ntokens x nheads x dim
* rope_positions: list of (f, h, w)
output:
* tokens after appplying RoPE2D (batch_size x ntokens x nheads x dim)
"""
assert sum(ch_split) == tokens.size(-1);
mesh_grid = self.get_mesh_3d(rope_positions, bsz=tokens.shape[0])
out = []
for i, (D, x) in enumerate(zip(ch_split, torch.split(tokens, ch_split, dim=-1))):
cos, sin = self.get_cos_sin(D, int(mesh_grid.max()) + 1, tokens.device, tokens.dtype)
if parallel:
                mesh = torch.chunk(mesh_grid[:, :, i], get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()].clone()
else:
mesh = mesh_grid[:, :, i].clone()
x = self.apply_rope1d(x, mesh.to(tokens.device), cos, sin)
out.append(x)
tokens = torch.cat(out, dim=-1)
return tokens
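# Minimal usage sketch (illustrative only): RoPE3D splits the head dimension into
# per-axis chunks for (frame, height, width) positions; ch_split must sum to the
# head dim and ntokens must equal f * h * w. The sizes below are arbitrary example
# values, and parallel=False avoids needing a distributed setup.
def _rope3d_example():
    rope = RoPE3D()
    f, h, w = 4, 6, 8
    tokens = torch.randn(1, f * h * w, 4, 128)        # (B, f*h*w, nheads, dim)
    out = rope(tokens, rope_positions=(f, h, w), ch_split=[64, 32, 32], parallel=False)
    assert out.shape == tokens.shape
    return out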
import torch.distributed as dist
import xfuser
import torch
def initialize_parall_group(ring_degree, ulysses_degree, tensor_parallel_degree):
dist.init_process_group("nccl")
xfuser.core.distributed.init_distributed_environment(
rank=dist.get_rank(),
world_size=dist.get_world_size()
)
xfuser.core.distributed.initialize_model_parallel(
sequence_parallel_degree=ulysses_degree,
ring_degree=ring_degree,
ulysses_degree=ulysses_degree,
tensor_parallel_degree=tensor_parallel_degree,
)
torch.cuda.set_device(dist.get_rank())
def get_parallel_group():
return xfuser.core.distributed.get_world_group()
def get_sequence_parallel_world_size():
return xfuser.core.distributed.parallel_state.get_sequence_parallel_world_size()
def get_sequence_parallel_rank():
return xfuser.core.distributed.parallel_state.get_sequence_parallel_rank()
def get_sp_group():
return xfuser.core.distributed.parallel_state.get_sp_group()
def parallel_forward(fn_):
    def wrapped_forward(self, hidden_states, *args, **kwargs):
        parallel = kwargs.get('parallel', False)
        if parallel:
            # scatter the sequence dimension across the sequence-parallel group
            hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
            if kwargs.get('attn_mask') is not None:
                kwargs['attn_mask'] = torch.chunk(kwargs['attn_mask'], get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
        output = fn_(self, hidden_states, *args, **kwargs)
        if parallel:
            # gather the per-rank outputs back along the sequence dimension
            output = get_sp_group().all_gather(output.contiguous(), dim=-2)
        return output
    return wrapped_forward
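# Minimal usage sketch (illustrative only): parallel_forward decorates a module's
# forward that takes `attn_mask` and `parallel` keyword arguments. With
# parallel=True the sequence dimension (dim=-2) is scattered across the
# sequence-parallel group before the wrapped call and all-gathered afterwards,
# which requires the distributed environment set up by initialize_parall_group.
# The module below is a hypothetical stand-in, not part of the real model.
class _ParallelForwardDemo(torch.nn.Module):
    @parallel_forward
    def forward(self, hidden_states, attn_mask=None, parallel=False):
        # with parallel=True, each rank only sees its local sequence chunk here
        return hidden_states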
import torch
from stepvideo.config import parse_args
import os
accepted_version = {
'2.2': 'liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so',
'2.3': 'liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so',
'2.5': 'liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so',
}
try:
args = parse_args()
version = '.'.join(torch.__version__.split('.')[:2])
if version in accepted_version:
torch.ops.load_library(os.path.join(args.model_dir, f'lib/{accepted_version[version]}'))
else:
raise ValueError("Not supported torch version for liboptimus")
except Exception as err:
print(err)
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import os
class HunyuanClip(nn.Module):
"""
Hunyuan clip code copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
hunyuan's clip used BertModel and BertTokenizer, so we copy it.
"""
def __init__(self, model_dir, max_length=77):
super(HunyuanClip, self).__init__()
self.max_length = max_length
self.tokenizer = BertTokenizer.from_pretrained(os.path.join(model_dir, 'tokenizer'))
self.text_encoder = BertModel.from_pretrained(os.path.join(model_dir, 'clip_text_encoder'))
@torch.no_grad
def forward(self, prompts, with_mask=True):
self.device = next(self.text_encoder.parameters()).device
text_inputs = self.tokenizer(
prompts,
padding="max_length",
max_length=self.max_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
prompt_embeds = self.text_encoder(
text_inputs.input_ids.to(self.device),
attention_mask=text_inputs.attention_mask.to(self.device) if with_mask else None,
)
return prompt_embeds.last_hidden_state, prompt_embeds.pooler_output
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import torch
# def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
# return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
# softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
# return torch.ops.Optimus._fwd(q, k, v, None, dropout_p, softmax_scale, causal, return_attn_probs, None, tp_group_rank, tp_group_size)[0]
from flash_attn import flash_attn_func as fa_func
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
return fa_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=causal,
return_attn_probs=return_attn_probs)
class FlashSelfAttention(torch.nn.Module):
def __init__(
self,
attention_dropout=0.0,
):
super().__init__()
self.dropout_p = attention_dropout
def forward(self, q, k, v, cu_seqlens=None, max_seq_len=None):
if cu_seqlens is None:
output = flash_attn_func(q, k, v, dropout_p=self.dropout_p)
else:
raise ValueError('cu_seqlens is not supported!')
return output
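# Minimal usage sketch (illustrative only): FlashSelfAttention expects
# (batch, seqlen, nheads, head_dim) tensors in fp16/bf16 on a CUDA device, and the
# flash_attn_func wrapper above defaults to causal attention. The sizes below are
# arbitrary example values; this helper is only defined, never called at import.
def _flash_self_attention_example():
    attn = FlashSelfAttention(attention_dropout=0.0)
    q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
    out = attn(q, k, v)   # same shape as q: (2, 128, 8, 64)
    return out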
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import os
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from stepvideo.text_encoder.flashattention import FlashSelfAttention
from stepvideo.modules.normalization import RMSNorm
from stepvideo.text_encoder.tokenizer import LLaMaEmbedding, Wrapped_StepChatTokenizer
from stepvideo.utils import with_empty_init
from safetensors.torch import load_file
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from einops import rearrange
import json
def safediv(n, d):
q, r = divmod(n, d)
assert r == 0
return q
class MultiQueryAttention(nn.Module):
def __init__(self, cfg, layer_id=None):
super().__init__()
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
self.max_seq_len = cfg.seq_length
self.use_flash_attention = cfg.use_flash_attn
assert self.use_flash_attention, 'FlashAttention is required!'
self.n_groups = cfg.num_attention_groups
self.tp_size = 1
self.n_local_heads = cfg.num_attention_heads
self.n_local_groups = self.n_groups
self.wqkv = nn.Linear(
cfg.hidden_size,
cfg.hidden_size + self.head_dim * 2 * self.n_groups,
bias=False,
)
self.wo = nn.Linear(
cfg.hidden_size,
cfg.hidden_size,
bias=False,
)
assert self.use_flash_attention, 'non-Flash attention not supported yet.'
self.core_attention = FlashSelfAttention(attention_dropout=cfg.attention_dropout)
self.layer_id = layer_id
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor],
cu_seqlens: Optional[torch.Tensor],
max_seq_len: Optional[torch.Tensor],
):
seqlen, bsz, dim = x.shape
xqkv = self.wqkv(x)
xq, xkv = torch.split(
xqkv,
(dim // self.tp_size,
self.head_dim*2*self.n_groups // self.tp_size
),
dim=-1,
)
# gather on 1st dimension
xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim)
xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim)
xk, xv = xkv.chunk(2, -1)
# rotary embedding + flash attn
xq = rearrange(xq, "s b h d -> b s h d")
xk = rearrange(xk, "s b h d -> b s h d")
xv = rearrange(xv, "s b h d -> b s h d")
q_per_kv = self.n_local_heads // self.n_local_groups
if q_per_kv > 1:
b, s, h, d = xk.size()
if h == 1:
xk = xk.expand(b, s, q_per_kv, d)
xv = xv.expand(b, s, q_per_kv, d)
else:
                ''' To cover the case where h > 1, the implementation below is
                equivalent to:
                    xk = xk.repeat_interleave(q_per_kv, dim=-2)
                    xv = xv.repeat_interleave(q_per_kv, dim=-2)
                but avoids calling aten::item(), which involves the CPU.
                '''
idx = torch.arange(q_per_kv * h, device=xk.device).reshape(q_per_kv, -1).permute(1, 0).flatten()
xk = torch.index_select(xk.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
xv = torch.index_select(xv.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
if self.use_flash_attention:
output = self.core_attention(xq, xk, xv,
cu_seqlens=cu_seqlens,
max_seq_len=max_seq_len)
            # reduce-scatter only supports the first dimension for now
output = rearrange(output, "b s h d -> s b (h d)").contiguous()
else:
xq, xk, xv = [
rearrange(x, "b s ... -> s b ...").contiguous()
for x in (xq, xk, xv)
]
output = self.core_attention(xq, xk, xv, mask)
output = self.wo(output)
return output
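# Minimal check (illustrative only): the index_select expansion used in
# MultiQueryAttention.forward is equivalent to repeat_interleave on the head
# dimension. Runs on CPU with small arbitrary sizes.
def _gqa_expand_equivalence_check():
    b, s, h, d, q_per_kv = 1, 3, 2, 4, 3
    xk = torch.randn(b, s, h, d)
    idx = torch.arange(q_per_kv * h).reshape(q_per_kv, -1).permute(1, 0).flatten()
    expanded = torch.index_select(xk.repeat(1, 1, q_per_kv, 1), 2, idx)
    assert torch.equal(expanded, xk.repeat_interleave(q_per_kv, dim=-2))
    return expanded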
class FeedForward(nn.Module):
def __init__(
self,
cfg,
dim: int,
hidden_dim: int,
layer_id: int,
        multiple_of: int = 256,
):
super().__init__()
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
def swiglu(x):
x = torch.chunk(x, 2, dim=-1)
return F.silu(x[0]) * x[1]
self.swiglu = swiglu
self.w1 = nn.Linear(
dim,
2 * hidden_dim,
bias=False,
)
self.w2 = nn.Linear(
hidden_dim,
dim,
bias=False,
)
def forward(self, x):
x = self.swiglu(self.w1(x))
output = self.w2(x)
return output
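# Minimal usage sketch (illustrative only): FeedForward is a SwiGLU MLP; w1
# projects to 2 * hidden_dim, the halves are gated with SiLU, and w2 projects
# back to dim. The config stand-in and sizes below are arbitrary example values
# (hidden_dim gets rounded up to a multiple of `multiple_of` internally).
def _swiglu_ffn_example():
    class _DummyCfg:   # FeedForward does not read anything from cfg
        pass
    ffn = FeedForward(_DummyCfg(), dim=64, hidden_dim=100, layer_id=0, multiple_of=32)
    x = torch.randn(2, 5, 64)
    out = ffn(x)                 # (2, 5, 64); internal hidden_dim is 128
    assert out.shape == (2, 5, 64)
    return out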
class TransformerBlock(nn.Module):
def __init__(
self, cfg, layer_id: int
):
super().__init__()
self.n_heads = cfg.num_attention_heads
self.dim = cfg.hidden_size
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
self.attention = MultiQueryAttention(
cfg,
layer_id=layer_id,
)
self.feed_forward = FeedForward(
cfg,
dim=cfg.hidden_size,
hidden_dim=cfg.ffn_hidden_size,
layer_id=layer_id,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(
cfg.hidden_size,
eps=cfg.layernorm_epsilon,
)
self.ffn_norm = RMSNorm(
cfg.hidden_size,
eps=cfg.layernorm_epsilon,
)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor],
cu_seqlens: Optional[torch.Tensor],
max_seq_len: Optional[torch.Tensor],
):
residual = self.attention.forward(
self.attention_norm(x), mask,
cu_seqlens, max_seq_len
)
h = x + residual
ffn_res = self.feed_forward.forward(self.ffn_norm(h))
out = h + ffn_res
return out
class Transformer(nn.Module):
def __init__(
self,
config,
max_seq_size=8192,
):
super().__init__()
self.num_layers = config.num_layers
self.layers = self._build_layers(config)
def _build_layers(self, config):
layers = torch.nn.ModuleList()
for layer_id in range(self.num_layers):
layers.append(
TransformerBlock(
config,
                    layer_id=layer_id + 1,
)
)
return layers
def forward(
self,
hidden_states,
attention_mask,
cu_seqlens=None,
max_seq_len=None,
):
if max_seq_len is not None and not isinstance(max_seq_len, torch.Tensor):
max_seq_len = torch.tensor(max_seq_len, dtype=torch.int32, device="cpu")
for lid, layer in enumerate(self.layers):
hidden_states = layer(
hidden_states,
attention_mask,
cu_seqlens,
max_seq_len,
)
return hidden_states
class Step1Model(PreTrainedModel):
config_class=PretrainedConfig
@with_empty_init
def __init__(
self,
config,
):
super().__init__(config)
self.tok_embeddings = LLaMaEmbedding(config)
self.transformer = Transformer(config)
def forward(
self,
input_ids=None,
attention_mask=None,
):
hidden_states = self.tok_embeddings(input_ids)
hidden_states = self.transformer(
hidden_states,
attention_mask,
)
return hidden_states
class STEP1TextEncoder(torch.nn.Module):
def __init__(self, model_dir, max_length=320):
super(STEP1TextEncoder, self).__init__()
self.max_length = max_length
self.text_tokenizer = Wrapped_StepChatTokenizer(os.path.join(model_dir, 'step1_chat_tokenizer.model'))
text_encoder = Step1Model.from_pretrained(model_dir)
self.text_encoder = text_encoder.eval().to(torch.bfloat16)
@torch.no_grad
def forward(self, prompts, with_mask=True, max_length=None):
self.device = next(self.text_encoder.parameters()).device
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
            if isinstance(prompts, str):
prompts = [prompts]
txt_tokens = self.text_tokenizer(
prompts, max_length=max_length or self.max_length, padding="max_length", truncation=True, return_tensors="pt"
)
y = self.text_encoder(
txt_tokens.input_ids.to(self.device),
attention_mask=txt_tokens.attention_mask.to(self.device) if with_mask else None
)
y_mask = txt_tokens.attention_mask
return y.transpose(0,1), y_mask