Commit 2b3ebe0c authored by luopl

Initial commit
from setuptools import find_packages, setup
import subprocess
def get_cuda_version():
try:
nvcc_version = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
version_line = [line for line in nvcc_version.split("\n") if "release" in line][
0
]
cuda_version = version_line.split(" ")[-2].replace(",", "")
return "cu" + cuda_version.replace(".", "")
except Exception as e:
return "no_cuda"
if __name__ == "__main__":
with open("README.md", "r") as f:
long_description = f.read()
fp = open("stepvideo/__version__.py", "r").read()
version = eval(fp.strip().split()[-1])
setup(
name="stepvideo",
author="Step-Video Team",
packages=find_packages(),
install_requires=[
"torchvision==0.19.1+das.opt2.dtk2504",
"torch==2.4.1+das.opt2.dtk2504",
"accelerate==1.6.0",
"transformers==4.49.0",
"diffusers==0.31.0",
"sentencepiece==0.2.0",
"blinker==1.4",
"imageio>=2.37.0",
"optimus==2.1",
"numpy",
"einops",
"aiohttp",
"asyncio",
"flask",
"flask_restful",
"ffmpeg-python",
"requests",
"xfuser==0.4.2rc2"
],
url="",
description="A 30B DiT based text to video and image generation model",
long_description=long_description,
long_description_content_type="text/markdown",
version=version,
classifiers=[
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
],
include_package_data=True,
python_requires=">=3.10",
)
\ No newline at end of file
import os
os.environ["NCCL_DEBUG"] = "ERROR"
from .diffusion.scheduler import *
from .diffusion.video_pipeline import *
from .modules.model import *
\ No newline at end of file
__version__ = "0.1.0"
\ No newline at end of file
import argparse
def parse_args(namespace=None):
parser = argparse.ArgumentParser(description="StepVideo inference script")
parser = add_extra_models_args(parser)
parser = add_denoise_schedule_args(parser)
parser = add_inference_args(parser)
parser = add_parallel_args(parser)
args = parser.parse_args(namespace=namespace)
return args
def add_extra_models_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="Extra models args, including vae, text encoders and tokenizers)"
)
group.add_argument(
"--vae_url",
type=str,
default='127.0.0.1',
help="vae url.",
)
group.add_argument(
"--caption_url",
type=str,
default='127.0.0.1',
help="caption url.",
)
return parser
def add_denoise_schedule_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Denoise schedule args")
# Flow Matching
group.add_argument(
"--time_shift",
type=float,
default=7.0,
help="Shift factor for flow matching schedulers.",
)
group.add_argument(
"--flow_reverse",
action="store_true",
help="If reverse, learning/sampling from t=1 -> t=0.",
)
group.add_argument(
"--flow_solver",
type=str,
default="euler",
help="Solver for flow matching.",
)
return parser
def add_inference_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Inference args")
# ======================== Model loads ========================
group.add_argument(
"--model_dir",
type=str,
default="./ckpts",
help="Root path of all the models, including t2v models and extra models.",
)
group.add_argument(
"--model_resolution",
type=str,
default="540p",
choices=["540p"],
help="Root path of all the models, including t2v models and extra models.",
)
group.add_argument(
"--use-cpu-offload",
action="store_true",
help="Use CPU offload for the model load.",
)
# ======================== Inference general setting ========================
group.add_argument(
"--batch_size",
type=int,
default=1,
help="Batch size for inference and evaluation.",
)
group.add_argument(
"--infer_steps",
type=int,
default=50,
help="Number of denoising steps for inference.",
)
group.add_argument(
"--save_path",
type=str,
default="./results",
help="Path to save the generated samples.",
)
group.add_argument(
"--output_file_name",
type=str,
default="",
help="Name to save the generated samples.",
)
group.add_argument(
"--name_suffix",
type=str,
default="",
help="Suffix for the names of saved samples.",
)
group.add_argument(
"--num_videos",
type=int,
default=1,
help="Number of videos to generate for each prompt.",
)
# ---sample size---
group.add_argument(
"--num_frames",
type=int,
default=102,
help="How many frames to sample from a video. ",
)
group.add_argument(
"--height",
type=int,
default=544,
help="The height of video sample",
)
group.add_argument(
"--width",
type=int,
default=992,
help="The width of video sample",
)
# --- prompt ---
group.add_argument(
"--prompt",
type=str,
default=None,
help="Prompt for sampling during evaluation.",
)
group.add_argument(
"--first_image_path",
type=str,
default='./assets/demo.png',
help="The reference image path for image-to-video task.",
)
group.add_argument("--seed", type=int, default=1234, help="Seed for evaluation.")
# Classifier-Free Guidance
group.add_argument(
"--pos_magic", type=str, default="画面中的主体动作表现生动自然、画面流畅、生动细节、光线统一柔和、超真实动态捕捉、大师级运镜、整体不变形、超高清、画面稳定、逼真的细节、专业级构图、超细节、清晰。", help="Positive magic prompt for sampling."
)
group.add_argument(
"--neg_magic", type=str, default="动画、模糊、变形、毁容、低质量、拼贴、粒状、标志、抽象、插图、计算机生成、扭曲、动作不流畅、面部有褶皱、表情僵硬、畸形手指", help="Negative magic prompt for sampling."
)
group.add_argument(
"--cfg_scale", type=float, default=9.0, help="Classifier free guidance scale."
)
group.add_argument(
"--motion_score", type=float, default=5, help="Score to control the motion level of the video."
)
return parser
def add_parallel_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Parallel args")
# ======================== Model loads ========================
group.add_argument(
"--ulysses_degree",
type=int,
default=8,
help="Ulysses degree.",
)
group.add_argument(
"--ring_degree",
type=int,
default=1,
help="Ulysses degree.",
)
group.add_argument(
"--tensor_parallel_degree",
type=int,
default=1,
help="Tensor parallel degree.",
)
return parser
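if __name__ == "__main__":
    # Minimal sketch of how these arguments are consumed; the flags mirror the
    # definitions above, and the launch command itself is hypothetical:
    #
    #   python -m stepvideo.config --model_dir ./ckpts --prompt "a panda dancing in the snow"
    #
    args = parse_args()
    print(args.model_dir, args.num_frames, args.height, args.width)  # ./ckpts 102 544 992 by default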
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.schedulers.scheduling_utils import SchedulerMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Euler scheduler.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
solver (`str`, defaults to `"euler"`):
The ODE solver used at inference time. Only `"euler"` is currently supported.
reverse (`bool`, defaults to `False`):
Whether to reverse the timestep schedule.
"""
_compatibles = []
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
reverse: bool = False,
solver: str = "euler",
device: Union[str, torch.device] = None,
):
sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
if not reverse:
sigmas = sigmas.flip(0)
self.sigmas = sigmas
# the value fed to model
self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
self._step_index = None
self._begin_index = None
self.device = device
self.supported_solver = ["euler"]
if solver not in self.supported_solver:
raise ValueError(
f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
)
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def set_timesteps(
self,
num_inference_steps: int,
time_shift: float = 13.0,
device: Union[str, torch.device] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
time_shift (`float`, defaults to `13.0`):
Shift factor applied to the sigma schedule (see `sd3_time_shift`).
"""
device = device or self.device
self.num_inference_steps = num_inference_steps
sigmas = torch.linspace(1, 0, num_inference_steps + 1, device=device)
sigmas = self.sd3_time_shift(sigmas, time_shift)
if not self.config.reverse:
sigmas = 1 - sigmas
self.sigmas = sigmas
self.timesteps = sigmas[:-1]
# Reset step index
self._step_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def scale_model_input(
self, sample: torch.Tensor, timestep: Optional[int] = None
) -> torch.Tensor:
return sample
def sd3_time_shift(self, t: torch.Tensor, time_shift: float = 13.0):
return (time_shift * t) / (1 + (time_shift - 1) * t)
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
return_dict: bool = False,
) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`, defaults to `False`):
Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] instead of the bare sample tensor.
Returns:
[`FlowMatchDiscreteSchedulerOutput`] or `torch.FloatTensor`:
If `return_dict` is `True`, a [`FlowMatchDiscreteSchedulerOutput`] is returned; otherwise the
updated sample tensor is returned directly.
"""
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if self.step_index is None:
self._init_step_index(timestep)
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
if self.config.solver == "euler":
prev_sample = sample + model_output.to(torch.float32) * dt
else:
raise ValueError(
f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
)
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return prev_sample
return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps
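if __name__ == "__main__":
    # Minimal standalone sketch of the flow-matching denoising loop, using a dummy
    # "model" that predicts zeros; real usage feeds the transformer output instead.
    scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=1000, solver="euler")
    scheduler.set_timesteps(num_inference_steps=50, time_shift=13.0, device="cpu")
    sample = torch.randn(1, 18, 64, 34, 62)            # b, f, c, h, w latents
    for t in scheduler.timesteps:
        model_output = torch.zeros_like(sample)        # placeholder velocity prediction
        # with the default `return_dict=False`, `step` returns the updated sample tensor
        sample = scheduler.step(model_output=model_output, timestep=t, sample=sample)
    print(sample.shape)                                # torch.Size([1, 18, 64, 34, 62])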
# Copyright 2025 StepFun Inc. All Rights Reserved.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
import numpy as np
import pickle
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput
import asyncio
from stepvideo.modules.model import StepVideoModel
from stepvideo.diffusion.scheduler import FlowMatchDiscreteScheduler
from stepvideo.utils import VideoProcessor
from torchvision import transforms
from PIL import Image as PILImage
import os
def call_api_gen(url, api, port=8080):
url =f"http://{url}:{port}/{api}-api"
import aiohttp
async def _fn(samples, *args, **kwargs):
if api=='vae':
data = {
"samples": samples,
}
elif api=='vae-encode':
data = {
"videos": samples,
}
elif api == 'caption':
data = {
"prompts": samples,
}
else:
raise Exception(f"Not supported api: {api}...")
async with aiohttp.ClientSession() as sess:
data_bytes = pickle.dumps(data)
async with sess.get(url, data=data_bytes, timeout=12000) as response:
result = bytearray()
while not response.content.at_eof():
chunk = await response.content.read(1024)
result += chunk
response_data = pickle.loads(result)
return response_data
return _fn
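# Minimal usage sketch (assumes a caption server compatible with `encode_prompt`
# below is listening on 127.0.0.1:8080; the prompt text is illustrative):
#
#   caption_fn = call_api_gen("127.0.0.1", "caption")
#   data = asyncio.run(caption_fn(["a panda dancing in the snow"]))
#   y, y_mask, clip_emb = data["y"], data["y_mask"], data["clip_embedding"]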
@dataclass
class StepVideoPipelineOutput(BaseOutput):
video: Union[torch.Tensor, np.ndarray]
class StepVideoPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using StepVideo.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
transformer ([`StepVideoModel`]):
Conditional Transformer to denoise the encoded image latents.
scheduler ([`FlowMatchDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae_url:
remote vae server's url.
caption_url:
remote caption (stepllm and clip) server's url.
"""
def __init__(
self,
transformer: StepVideoModel,
scheduler: FlowMatchDiscreteScheduler,
vae_url: str = '127.0.0.1',
caption_url: str = '127.0.0.1',
save_path: str = './results',
name_suffix: str = '',
):
super().__init__()
self.register_modules(
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 8
self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16
self.video_processor = VideoProcessor(save_path, name_suffix)
self.vae_url = vae_url
self.caption_url = caption_url
self.setup_api(self.vae_url, self.caption_url)
def setup_pipeline(self, args):
self.args = args
self.video_processor = VideoProcessor(self.args.save_path, self.args.name_suffix)
self.setup_api(args.vae_url, args.caption_url)
return self
def setup_api(self, vae_url, caption_url):
self.vae_url = vae_url
self.caption_url = caption_url
self.caption = call_api_gen(caption_url, 'caption')
self.vae = call_api_gen(vae_url, 'vae')
self.vae_encode = call_api_gen(vae_url, 'vae-encode')
return self
def encode_prompt(
self,
prompt: str,
neg_magic: str = '',
pos_magic: str = '',
):
device = self._execution_device
prompts = [prompt+pos_magic]
bs = len(prompts)
prompts += [neg_magic]*bs
data = asyncio.run(self.caption(prompts))
prompt_embeds, prompt_attention_mask, clip_embedding = data['y'].to(device), data['y_mask'].to(device), data['clip_embedding'].to(device)
return prompt_embeds, clip_embedding, prompt_attention_mask
def decode_vae(self, samples):
samples = asyncio.run(self.vae(samples.cpu()))
return samples
def encode_vae(self, img):
latents = asyncio.run(self.vae_encode(img))
return latents
def check_inputs(self, num_frames, width, height):
num_frames = max(num_frames//17*17, 1)
width = max(width//16*16, 16)
height = max(height//16*16, 16)
return num_frames, width, height
def prepare_latents(
self,
batch_size: int,
num_channels_latents: 64,
height: int = 544,
width: int = 992,
num_frames: int = 204,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
num_frames, width, height = self.check_inputs(num_frames, width, height)
shape = (
batch_size,
max(num_frames//17*3, 1),
num_channels_latents,
int(height) // self.vae_scale_factor_spatial,
int(width) // self.vae_scale_factor_spatial,
) # b,f,c,h,w
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if generator is None:
generator = torch.Generator(device=self._execution_device)
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
return latents
def resize_to_desired_aspect_ratio(self, video, aspect_size):
## video is in shape [f, c, h, w]
height, width = video.shape[-2:]
aspect_ratio = [w/h for h, w in aspect_size]
# # resize
aspect_ratio_fact = width / height
bucket_idx = np.argmin(np.abs(aspect_ratio_fact - np.array(aspect_ratio)))
aspect_ratio = aspect_ratio[bucket_idx]
target_size_height, target_size_width = aspect_size[bucket_idx]
if aspect_ratio_fact < aspect_ratio:
scale = target_size_width / width
else:
scale = target_size_height / height
width_scale = int(round(width * scale))
height_scale = int(round(height * scale))
# # crop
delta_h = height_scale - target_size_height
delta_w = width_scale - target_size_width
assert delta_w>=0
assert delta_h>=0
assert not all(
[delta_h, delta_w]
)
top = delta_h//2
left = delta_w//2
## resize image and crop
resize_crop_transform = transforms.Compose([
transforms.Resize((height_scale, width_scale)),
lambda x: transforms.functional.crop(x, top, left, target_size_height, target_size_width),
])
video = torch.stack([resize_crop_transform(frame.contiguous()) for frame in video], dim=0)
return video
def prepare_condition_hidden_states(
self,
img: Union[str, PILImage.Image, torch.Tensor]=None,
batch_size: int = 1,
num_channels_latents: int = 64,
height: int = 544,
width: int = 992,
num_frames: int = 204,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None
):
if isinstance(img, str):
assert os.path.exists(img)
img = PILImage.open(img)
if isinstance(img, PILImage.Image):
img_tensor = transforms.ToTensor()(img.convert('RGB'))*2-1
else:
img_tensor = img
num_frames, width, height = self.check_inputs(num_frames, width, height)
img_tensor = self.resize_to_desired_aspect_ratio(img_tensor[None], aspect_size=[(height, width)])[None]
img_emb = self.encode_vae(img_tensor).repeat(batch_size, 1,1,1,1).to(device)
padding_tensor = torch.zeros((batch_size, max(num_frames//17*3, 1)-1, num_channels_latents, int(height) // self.vae_scale_factor_spatial, int(width) // self.vae_scale_factor_spatial,), device=device)
condition_hidden_states = torch.cat([img_emb, padding_tensor], dim=1)
condition_hidden_states = condition_hidden_states.repeat(2, 1,1,1,1) ## for CFG
return condition_hidden_states.to(dtype)
@torch.inference_mode()
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: int = 544,
width: int = 992,
num_frames: int = 102,
num_inference_steps: int = 50,
guidance_scale: float = 9.0,
time_shift: float = 13.0,
neg_magic: str = "",
pos_magic: str = "",
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
first_image: Union[str, PILImage.Image, torch.Tensor] = None,
motion_score: float = 2.0,
output_type: Optional[str] = "mp4",
output_file_name: Optional[str] = "",
return_dict: bool = True,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the video generation.
height (`int`, defaults to `544`):
The height in pixels of the generated video.
width (`int`, defaults to `992`):
The width in pixels of the generated video.
num_frames (`int`, defaults to `102`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `9.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
first_image (`str`, `PIL.Image`, or `torch.Tensor`):
The reference image (or a path to one) used to condition the image-to-video generation.
output_type (`str`, *optional*, defaults to `"mp4"`):
The output format of the generated video. Use `"latent"` to return the undecoded latents.
output_file_name (`str`, *optional*):
The file name of the saved mp4.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`StepVideoPipelineOutput`] instead of a plain tuple.
Examples:
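A minimal sketch (assumes a checkpoint directory laid out for `DiffusionPipeline.from_pretrained`
and running VAE/caption servers reachable at the given URLs):

pipeline = StepVideoPipeline.from_pretrained("./ckpts").to(dtype=torch.bfloat16, device="cuda")
pipeline.setup_api(vae_url="127.0.0.1", caption_url="127.0.0.1")
output = pipeline(
    prompt="a panda dancing in the snow",
    first_image="./assets/demo.png",
    num_frames=102, height=544, width=992,
    num_inference_steps=50, guidance_scale=9.0, time_shift=13.0,
)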
Returns:
[`~StepVideoPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`StepVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
where the first element is the generated video.
"""
# 1. Check inputs. Raise error if not correct
device = self._execution_device
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError("`prompt` must be provided as a `str` or a list of `str`.")
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, prompt_embeds_2, prompt_attention_mask = self.encode_prompt(
prompt=prompt,
neg_magic=neg_magic,
pos_magic=pos_magic,
)
transformer_dtype = self.transformer.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
prompt_embeds_2 = prompt_embeds_2.to(transformer_dtype)
# 4. Prepare timesteps
self.scheduler.set_timesteps(
num_inference_steps=num_inference_steps,
time_shift=time_shift,
device=device
)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.bfloat16,
device,
generator,
latents,
)
condition_hidden_states = self.prepare_condition_hidden_states(
first_image,
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
dtype=torch.bfloat16,
device=device)
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(self.scheduler.timesteps):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = latent_model_input.to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
encoder_hidden_states_2=prompt_embeds_2,
condition_hidden_states=condition_hidden_states,
motion_score=motion_score,
return_dict=False,
)
# perform guidance
if do_classifier_free_guidance:
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
model_output=noise_pred,
timestep=t,
sample=latents
)
progress_bar.update()
if not torch.distributed.is_initialized() or int(torch.distributed.get_rank())==0:
if not output_type == "latent":
video = self.decode_vae(latents)
video = self.video_processor.postprocess_video(video, output_file_name=output_file_name, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video, )
return StepVideoPipelineOutput(video=video)
\ No newline at end of file
import torch
import torch.nn as nn
from einops import rearrange
try:
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
except ImportError:
xFuserLongContextAttention = None
class Attention(nn.Module):
def __init__(self):
super().__init__()
def attn_processor(self, attn_type):
if attn_type == 'torch':
return self.torch_attn_func
elif attn_type == 'parallel':
return self.parallel_attn_func
else:
raise Exception('Not supported attention type...')
def torch_attn_func(
self,
q,
k,
v,
attn_mask=None,
causal=False,
drop_rate=0.0,
**kwargs
):
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
if attn_mask is not None and attn_mask.ndim == 3: ## no head
n_heads = q.shape[2]
attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
q, k, v = map(lambda x: rearrange(x, 'b s h d -> b h s d'), (q, k, v))
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
)
x = rearrange(x, 'b h s d -> b s h d')
return x
def parallel_attn_func(
self,
q,
k,
v,
causal=False,
**kwargs
):
assert xFuserLongContextAttention is not None, 'xFuserLongContextAttention could not be imported; it is required for sequence parallel attention'
hybrid_seq_parallel_attn = xFuserLongContextAttention()
x = hybrid_seq_parallel_attn(
None, q,k,v, causal=causal
)
return x
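if __name__ == "__main__":
    # Quick sanity sketch of the torch attention path; tensors follow the
    # `b s h d` layout used by the callers in this repo.
    attn = Attention()
    q = k = v = torch.randn(1, 16, 4, 64)              # batch, seq, heads, head_dim
    out = attn.attn_processor("torch")(q, k, v, causal=False)
    print(out.shape)                                   # torch.Size([1, 16, 4, 64])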
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import torch
import torch.nn as nn
from typing import Optional
from einops import rearrange
from stepvideo.modules.rope import RoPE3D
from stepvideo.modules.attentions import Attention
from stepvideo.modules.normalization import RMSNorm
class SelfAttention(Attention):
def __init__(self, hidden_dim, head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type='torch'):
super().__init__()
self.head_dim = head_dim
self.n_heads = hidden_dim // head_dim
self.wqkv = nn.Linear(hidden_dim, hidden_dim*3, bias=bias)
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
self.with_rope = with_rope
self.with_qk_norm = with_qk_norm
if self.with_qk_norm:
self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
if self.with_rope:
self.rope_3d = RoPE3D(freq=1e4, F0=1.0, scaling_factor=1.0)
self.rope_ch_split = [64, 32, 32]
self.core_attention = self.attn_processor(attn_type=attn_type)
self.parallel = attn_type=='parallel'
def apply_rope3d(self, x, fhw_positions, rope_ch_split, parallel=True):
x = self.rope_3d(x, fhw_positions, rope_ch_split, parallel)
return x
def forward(
self,
x,
cu_seqlens=None,
max_seqlen=None,
rope_positions=None,
attn_mask=None
):
xqkv = self.wqkv(x)
xqkv = xqkv.view(*x.shape[:-1], self.n_heads, 3*self.head_dim)
xq, xk, xv = torch.split(xqkv, [self.head_dim]*3, dim=-1) ## seq_len, n, dim
if self.with_qk_norm:
xq = self.q_norm(xq)
xk = self.k_norm(xk)
if self.with_rope:
xq = self.apply_rope3d(xq, rope_positions, self.rope_ch_split, parallel=self.parallel)
xk = self.apply_rope3d(xk, rope_positions, self.rope_ch_split, parallel=self.parallel)
output = self.core_attention(
xq,
xk,
xv,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
attn_mask=attn_mask
)
output = rearrange(output, 'b s h d -> b s (h d)')
output = self.wo(output)
return output
class CrossAttention(Attention):
def __init__(self, hidden_dim, head_dim, bias=False, with_qk_norm=True, attn_type='torch'):
super().__init__()
self.head_dim = head_dim
self.n_heads = hidden_dim // head_dim
self.wq = nn.Linear(hidden_dim, hidden_dim, bias=bias)
self.wkv = nn.Linear(hidden_dim, hidden_dim*2, bias=bias)
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
self.with_qk_norm = with_qk_norm
if self.with_qk_norm:
self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
self.core_attention = self.attn_processor(attn_type=attn_type)
def forward(
self,
x: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attn_mask=None
):
xq = self.wq(x)
xq = xq.view(*xq.shape[:-1], self.n_heads, self.head_dim)
xkv = self.wkv(encoder_hidden_states)
xkv = xkv.view(*xkv.shape[:-1], self.n_heads, 2*self.head_dim)
xk, xv = torch.split(xkv, [self.head_dim]*2, dim=-1) ## seq_len, n, dim
if self.with_qk_norm:
xq = self.q_norm(xq)
xk = self.k_norm(xk)
output = self.core_attention(
xq,
xk,
xv,
attn_mask=attn_mask
)
output = rearrange(output, 'b s h d -> b s (h d)')
output = self.wo(output)
return output
class GELU(nn.Module):
r"""
GELU activation function with tanh approximation support with `approximate="tanh"`.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
self.approximate = approximate
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.gelu(gate, approximate=self.approximate)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.gelu(hidden_states)
return hidden_states
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
inner_dim: Optional[int] = None,
dim_out: Optional[int] = None,
mult: int = 4,
bias: bool = False,
):
super().__init__()
inner_dim = dim*mult if inner_dim is None else inner_dim
dim_out = dim if dim_out is None else dim_out
self.net = nn.ModuleList([
GELU(dim, inner_dim, approximate="tanh", bias=bias),
nn.Identity(),
nn.Linear(inner_dim, dim_out, bias=bias)
])
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
def modulate(x, scale, shift):
x = x * (1 + scale) + shift
return x
def gate(x, gate):
x = gate * x
return x
class StepVideoTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
attention_head_dim (`int`): The number of channels in each attention head.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon used by the layer norms.
ff_inner_dim (`int`, *optional*): Inner dimension of the feed-forward network; defaults to `4 * dim`.
ff_bias (`bool`, *optional*, defaults to `False`): Whether the feed-forward linear layers use a bias.
attention_type (`str`, *optional*, defaults to `"parallel"`):
The self-attention backend, either `"parallel"` (sequence-parallel via xfuser) or `"torch"`.
"""
def __init__(
self,
dim: int,
attention_head_dim: int,
norm_eps: float = 1e-5,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = False,
attention_type: str = 'parallel'
):
super().__init__()
self.dim = dim
self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
self.attn1 = SelfAttention(dim, attention_head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type=attention_type)
self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
self.attn2 = CrossAttention(dim, attention_head_dim, bias=False, with_qk_norm=True, attn_type='torch')
self.ff = FeedForward(dim=dim, inner_dim=ff_inner_dim, dim_out=dim, bias=ff_bias)
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) /dim**0.5)
@torch.no_grad()
def forward(
self,
q: torch.Tensor,
kv: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
attn_mask = None,
rope_positions: list = None,
) -> torch.Tensor:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
torch.clone(chunk) for chunk in (self.scale_shift_table[None] + timestep.reshape(-1, 6, self.dim)).chunk(6, dim=1)
)
scale_shift_q = modulate(self.norm1(q), scale_msa, shift_msa)
attn_q = self.attn1(
scale_shift_q,
rope_positions=rope_positions
)
q = gate(attn_q, gate_msa) + q
attn_q = self.attn2(
q,
kv,
attn_mask
)
q = attn_q + q
scale_shift_q = modulate(self.norm2(q), scale_mlp, shift_mlp)
ff_output = self.ff(scale_shift_q)
q = gate(ff_output, gate_mlp) + q
return q
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
def __init__(
self,
patch_size=64,
in_channels=3,
embed_dim=768,
layer_norm=False,
flatten=True,
bias=True,
):
super().__init__()
self.flatten = flatten
self.layer_norm = layer_norm
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
# `forward` references `self.norm` when `layer_norm` is enabled, so define it here
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) if layer_norm else None
def forward(self, latent):
latent = self.proj(latent).to(latent.dtype)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
return latent
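if __name__ == "__main__":
    # Sketch: with the model defaults (patch_size=1 over 64 latent channels),
    # PatchEmbed reduces to a per-position linear projection that flattens each
    # (h, w) grid into a token sequence; embed_dim=256 here is illustrative.
    pe = PatchEmbed(patch_size=1, in_channels=64, embed_dim=256)
    latent = torch.randn(2, 64, 34, 62)                # (b * f), c, h, w
    tokens = pe(latent)
    print(tokens.shape)                                # torch.Size([2, 2108, 256])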
\ No newline at end of file
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
from typing import Any, Dict, Optional, Union
import torch
from torch import nn
import os
from einops import rearrange, repeat
from stepvideo.modules.blocks import StepVideoTransformerBlock, PatchEmbed
from stepvideo.utils import with_empty_init
from stepvideo.parallel import parallel_forward
from stepvideo.modules.normalization import (
PixArtAlphaTextProjection,
AdaLayerNormSingle
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
class StepVideoModel(ModelMixin, ConfigMixin):
_no_split_modules = ["StepVideoTransformerBlock", "PatchEmbed"]
@with_empty_init
@register_to_config
def __init__(
self,
num_attention_heads: int = 48,
attention_head_dim: int = 128,
in_channels: int = 64,
out_channels: Optional[int] = 64,
num_layers: int = 48,
dropout: float = 0.0,
patch_size: int = 1,
norm_type: str = "ada_norm_single",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
use_additional_conditions: Optional[bool] = False,
caption_channels: Union[int, list, tuple] = [6144, 1024],
attention_type: Optional[str] = "parallel",
):
super().__init__()
# Set some common variables used across the board.
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.out_channels = in_channels if out_channels is None else out_channels
self.use_additional_conditions = use_additional_conditions
self.pos_embed = PatchEmbed(
patch_size=patch_size,
in_channels=self.config.in_channels if not use_additional_conditions else self.config.in_channels*2,
embed_dim=self.inner_dim,
)
self.transformer_blocks = nn.ModuleList(
[
StepVideoTransformerBlock(
dim=self.inner_dim,
attention_head_dim=self.config.attention_head_dim,
attention_type=attention_type
)
for _ in range(self.config.num_layers)
]
)
# 3. Output blocks.
self.norm_out = nn.LayerNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels)
self.patch_size = patch_size
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=self.use_additional_conditions
)
if isinstance(self.config.caption_channels, int):
caption_channel = self.config.caption_channels
else:
caption_channel, clip_channel = self.config.caption_channels
self.clip_projection = nn.Linear(clip_channel, self.inner_dim)
self.caption_norm = nn.LayerNorm(caption_channel, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.caption_projection = PixArtAlphaTextProjection(
in_features=caption_channel, hidden_size=self.inner_dim
)
self.parallel = attention_type=='parallel'
def patchfy(self, hidden_states, condition_hidden_states=None):
if condition_hidden_states is not None:
hidden_states = torch.cat([hidden_states, condition_hidden_states], dim=2)
hidden_states = rearrange(hidden_states, 'b f c h w -> (b f) c h w')
hidden_states = self.pos_embed(hidden_states)
return hidden_states
def prepare_attn_mask(self, encoder_attention_mask, encoder_hidden_states, q_seqlen):
kv_seqlens = encoder_attention_mask.sum(dim=1).int()
mask = torch.zeros([len(kv_seqlens), q_seqlen, max(kv_seqlens)], dtype=torch.bool, device=encoder_attention_mask.device)
encoder_hidden_states = encoder_hidden_states[:,: max(kv_seqlens)]
for i, kv_len in enumerate(kv_seqlens):
mask[i, :, :kv_len] = 1
return encoder_hidden_states, mask
@parallel_forward
def block_forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
rope_positions=None,
attn_mask=None,
parallel=True
):
for i, block in enumerate(self.transformer_blocks):
hidden_states = block(
hidden_states,
encoder_hidden_states,
timestep=timestep,
attn_mask=attn_mask,
rope_positions=rope_positions
)
return hidden_states
@torch.inference_mode()
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_hidden_states_2: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
fps: torch.Tensor=None,
condition_hidden_states: torch.Tensor = None,
motion_score: torch.Tensor=None,
return_dict: bool = True,
):
assert hidden_states.ndim == 5, "hidden_states should have shape (bsz, f, ch, h, w)"
bsz, frame, _, height, width = hidden_states.shape
height, width = height // self.patch_size, width // self.patch_size
hidden_states = self.patchfy(hidden_states, condition_hidden_states)
len_frame = hidden_states.shape[1]
if self.use_additional_conditions:
added_cond_kwargs = {
"motion_score": torch.tensor([motion_score], device=hidden_states.device, dtype=hidden_states.dtype).repeat(bsz),
}
else:
added_cond_kwargs = {}
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs=added_cond_kwargs
)
encoder_hidden_states = self.caption_projection(self.caption_norm(encoder_hidden_states))
if encoder_hidden_states_2 is not None and hasattr(self, 'clip_projection'):
clip_embedding = self.clip_projection(encoder_hidden_states_2)
encoder_hidden_states = torch.cat([clip_embedding, encoder_hidden_states], dim=1)
hidden_states = rearrange(hidden_states, '(b f) l d-> b (f l) d', b=bsz, f=frame, l=len_frame).contiguous()
encoder_hidden_states, attn_mask = self.prepare_attn_mask(encoder_attention_mask, encoder_hidden_states, q_seqlen=frame*len_frame)
hidden_states = self.block_forward(
hidden_states,
encoder_hidden_states,
timestep=timestep,
rope_positions=[frame, height, width],
attn_mask=attn_mask,
parallel=self.parallel
)
hidden_states = rearrange(hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame)
embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame).contiguous()
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
# unpatchify
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = rearrange(hidden_states, 'n h w p q c -> n c h p w q')
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
output = rearrange(output, '(b f) c h w -> b f c h w', f=frame)
if return_dict:
return {'x': output}
return output
\ No newline at end of file
from typing import Any, Dict, Optional, Union, Tuple
import torch
import torch.nn as nn
import math
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
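# Usage sketch: RMSNorm rescales each position by the root-mean-square of its
# last dimension (the learnable weight is all ones at initialisation).
#
#   norm = RMSNorm(dim=64)
#   x = torch.randn(2, 16, 64)
#   y = norm(x)
#   print(y.pow(2).mean(-1))   # close to 1.0 everywhere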
ACTIVATION_FUNCTIONS = {
"swish": nn.SiLU(),
"silu": nn.SiLU(),
"mish": nn.Mish(),
"gelu": nn.GELU(),
"relu": nn.ReLU(),
}
def get_activation(act_fn: str) -> nn.Module:
"""Helper function to get activation function from string.
Args:
act_fn (str): Name of activation function.
Returns:
nn.Module: Activation function.
"""
act_fn = act_fn.lower()
if act_fn in ACTIVATION_FUNCTIONS:
return ACTIVATION_FUNCTIONS[act_fn]
else:
raise ValueError(f"Unsupported activation function: {act_fn}")
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
:param embedding_dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
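# Usage sketch: embed a batch of (possibly fractional) timesteps into 256-dim
# sinusoidal features, matching how `Timesteps` below calls this function.
#
#   t = torch.tensor([0.0, 250.0, 999.0])
#   emb = get_timestep_embedding(t, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0)
#   print(emb.shape)   # torch.Size([3, 256])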
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True
):
super().__init__()
linear_cls = nn.Linear
self.linear_1 = linear_cls(
in_channels,
time_embed_dim,
bias=sample_proj_bias,
)
if cond_proj_dim is not None:
self.cond_proj = linear_cls(
cond_proj_dim,
in_channels,
bias=False,
)
else:
self.cond_proj = None
self.act = get_activation(act_fn)
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = linear_cls(
time_embed_dim,
time_embed_dim_out,
bias=sample_proj_bias,
)
if post_act_fn is None:
self.post_act = None
else:
self.post_act = get_activation(post_act_fn)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
super().__init__()
self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.use_additional_conditions = use_additional_conditions
if self.use_additional_conditions:
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.motion_score_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, motion_score=None):
hidden_dtype = next(self.timestep_embedder.parameters()).dtype
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
if self.use_additional_conditions:
batch_size = timestep.shape[0]
motion_score_emb = self.additional_condition_proj(motion_score.flatten()).to(hidden_dtype)
motion_score_emb = self.motion_score_embedder(motion_score_emb).reshape(batch_size, -1)
conditioning = timesteps_emb + motion_score_emb
else:
conditioning = timesteps_emb
return conditioning
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, time_step_rescale=1000):
super().__init__()
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 2, use_additional_conditions=use_additional_conditions
)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
self.time_step_rescale = time_step_rescale ## timestep usually in [0, 1], we rescale it to [0,1000] for stability
def forward(
self,
timestep: torch.Tensor,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
):
embedded_timestep = self.emb(timestep*self.time_step_rescale, **added_cond_kwargs)
out = self.linear(self.silu(embedded_timestep))
return out, embedded_timestep
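# Usage sketch: adaLN-single maps a scalar timestep in [0, 1] to one 6 * embedding_dim
# modulation vector per sample (motion_score conditioning is only added when
# `use_additional_conditions=True`); embedding_dim=128 here is illustrative.
#
#   adaln = AdaLayerNormSingle(embedding_dim=128, use_additional_conditions=False)
#   t = torch.rand(2)                       # timesteps in [0, 1]
#   mod, emb = adaln(t, added_cond_kwargs={})
#   print(mod.shape, emb.shape)             # torch.Size([2, 768]) torch.Size([2, 128])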
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size):
super().__init__()
self.linear_1 = nn.Linear(
in_features,
hidden_size,
bias=True,
)
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = nn.Linear(
hidden_size,
hidden_size,
bias=True,
)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
import torch
from stepvideo.parallel import get_sequence_parallel_world_size, get_sequence_parallel_rank
class RoPE1D:
def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
self.base = freq
self.F0 = F0
self.scaling_factor = scaling_factor
self.cache = {}
def get_cos_sin(self, D, seq_len, device, dtype):
if (D, seq_len, device, dtype) not in self.cache:
inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
freqs = torch.cat((freqs, freqs), dim=-1)
cos = freqs.cos() # (Seq, Dim)
sin = freqs.sin()
self.cache[D, seq_len, device, dtype] = (cos, sin)
return self.cache[D, seq_len, device, dtype]
@staticmethod
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rope1d(self, tokens, pos1d, cos, sin):
assert pos1d.ndim == 2
cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :]
sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :]
return (tokens * cos) + (self.rotate_half(tokens) * sin)
def __call__(self, tokens, positions):
"""
input:
* tokens: batch_size x ntokens x nheads x dim
* positions: batch_size x ntokens (t position of each token)
output:
* tokens after applying RoPE1D (batch_size x ntokens x nheads x dim)
"""
D = tokens.size(3)
assert positions.ndim == 2 # Batch, Seq
cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
tokens = self.apply_rope1d(tokens, positions, cos, sin)
return tokens
class RoPE3D(RoPE1D):
def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
super(RoPE3D, self).__init__(freq, F0, scaling_factor)
self.position_cache = {}
def get_mesh_3d(self, rope_positions, bsz):
f, h, w = rope_positions
if f"{f}-{h}-{w}" not in self.position_cache:
x = torch.arange(f, device='cpu')
y = torch.arange(h, device='cpu')
z = torch.arange(w, device='cpu')
self.position_cache[f"{f}-{h}-{w}"] = torch.cartesian_prod(x, y, z).view(1, f*h*w, 3).expand(bsz, -1, 3)
return self.position_cache[f"{f}-{h}-{w}"]
def __call__(self, tokens, rope_positions, ch_split, parallel=False):
"""
input:
* tokens: batch_size x ntokens x nheads x dim
* rope_positions: list of (f, h, w)
output:
* tokens after applying RoPE3D (batch_size x ntokens x nheads x dim)
"""
assert sum(ch_split) == tokens.size(-1)
mesh_grid = self.get_mesh_3d(rope_positions, bsz=tokens.shape[0])
out = []
for i, (D, x) in enumerate(zip(ch_split, torch.split(tokens, ch_split, dim=-1))):
cos, sin = self.get_cos_sin(D, int(mesh_grid.max()) + 1, tokens.device, tokens.dtype)
if parallel:
mesh = torch.chunk(mesh_grid[:, :, i], get_sequence_parallel_world_size(),dim=1)[get_sequence_parallel_rank()].clone()
else:
mesh = mesh_grid[:, :, i].clone()
x = self.apply_rope1d(x, mesh.to(tokens.device), cos, sin)
out.append(x)
tokens = torch.cat(out, dim=-1)
return tokens
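if __name__ == "__main__":
    # Sketch: apply the factorised 3-D RoPE to a 4 x 8 x 8 token grid using the
    # [64, 32, 32] (frame, height, width) channel split from SelfAttention.
    rope = RoPE3D(freq=1e4, F0=1.0, scaling_factor=1.0)
    tokens = torch.randn(1, 4 * 8 * 8, 2, 128)         # b, f*h*w, heads, head_dim
    tokens = rope(tokens, rope_positions=[4, 8, 8], ch_split=[64, 32, 32], parallel=False)
    print(tokens.shape)                                # torch.Size([1, 256, 2, 128])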
import torch.distributed as dist
import xfuser
import torch
def initialize_parall_group(ring_degree, ulysses_degree):
dist.init_process_group("nccl")
xfuser.core.distributed.init_distributed_environment(
rank=dist.get_rank(),
world_size=dist.get_world_size()
)
xfuser.core.distributed.initialize_model_parallel(
sequence_parallel_degree=ulysses_degree,
ring_degree=ring_degree,
ulysses_degree=ulysses_degree,
)
torch.cuda.set_device(dist.get_rank())
def get_parallel_group():
return xfuser.core.distributed.get_world_group()
def get_sequence_parallel_world_size():
return xfuser.core.distributed.parallel_state.get_sequence_parallel_world_size()
def get_sequence_parallel_rank():
return xfuser.core.distributed.parallel_state.get_sequence_parallel_rank()
def get_sp_group():
return xfuser.core.distributed.parallel_state.get_sp_group()
def parallel_forward(fn_):
def wrapTheFunction(_, hidden_states, *args, **kwargs):
if kwargs['parallel']:
hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
kwargs['attn_mask'] = torch.chunk(kwargs['attn_mask'], get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
output = fn_(_, hidden_states, *args, **kwargs)
if kwargs['parallel']:
output = get_sp_group().all_gather(output.contiguous(), dim=-2)
return output
return wrapTheFunction
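# Usage sketch (must run under a distributed launcher so NCCL can be initialised;
# the launch command and `run.py` entry script below are placeholders, and
# ring_degree * ulysses_degree must equal the number of launched processes):
#
#   torchrun --nproc_per_node 8 run.py --ulysses_degree 8 --ring_degree 1
#
#   initialize_parall_group(ring_degree=1, ulysses_degree=8)
#   print(get_sequence_parallel_world_size(), get_sequence_parallel_rank())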
\ No newline at end of file
import torch
from stepvideo.config import parse_args
import os
accepted_version = {
'2.2': 'liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so',
'2.3': 'liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so',
'2.5': 'liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so',
}
try:
args = parse_args()
version = '.'.join(torch.__version__.split('.')[:2])
if version in accepted_version:
torch.ops.load_library(os.path.join(args.model_dir, f'lib/{accepted_version[version]}'))
else:
raise ValueError("Not supported torch version for liboptimus")
except Exception as err:
print(err)
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import os
class HunyuanClip(nn.Module):
"""
HunyuanDiT's CLIP text encoder, adapted from
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py.
HunyuanDiT uses BertModel and BertTokenizer for its CLIP branch, so those classes are reused here.
def __init__(self, model_dir, max_length=77):
super(HunyuanClip, self).__init__()
self.max_length = max_length
self.tokenizer = BertTokenizer.from_pretrained(os.path.join(model_dir, 'tokenizer'))
self.text_encoder = BertModel.from_pretrained(os.path.join(model_dir, 'clip_text_encoder'))
@torch.no_grad
def forward(self, prompts, with_mask=True):
self.device = next(self.text_encoder.parameters()).device
text_inputs = self.tokenizer(
prompts,
padding="max_length",
max_length=self.max_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
prompt_embeds = self.text_encoder(
text_inputs.input_ids.to(self.device),
attention_mask=text_inputs.attention_mask.to(self.device) if with_mask else None,
)
return prompt_embeds.last_hidden_state, prompt_embeds.pooler_output
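# Illustrative usage sketch (the checkpoint path is hypothetical; the directory is
# expected to contain `tokenizer/` and `clip_text_encoder/` subfolders, matching
# the from_pretrained calls above):
#
#   clip = HunyuanClip("./ckpts/hunyuan_clip").cuda().eval()
#   last_hidden, pooled = clip(["a panda dancing in the snow"])
#   # last_hidden: [batch, 77, hidden]; pooled: [batch, hidden]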
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import torch
# def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
# return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
# softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
# return torch.ops.Optimus._fwd(q, k, v, None, dropout_p, softmax_scale, causal, return_attn_probs, None, tp_group_rank, tp_group_size)[0]
from flash_attn import flash_attn_func as fa_func
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
return fa_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=causal,
return_attn_probs=return_attn_probs)
class FlashSelfAttention(torch.nn.Module):
def __init__(
self,
attention_dropout=0.0,
):
super().__init__()
self.dropout_p = attention_dropout
def forward(self, q, k, v, cu_seqlens=None, max_seq_len=None):
if cu_seqlens is None:
output = flash_attn_func(q, k, v, dropout_p=self.dropout_p)
else:
raise ValueError('cu_seqlens is not supported!')
return output
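# Illustrative shape check (a sketch assuming flash-attn is installed and a CUDA
# device is available; the tensor sizes are arbitrary):
#
#   attn = FlashSelfAttention(attention_dropout=0.0)
#   q = k = v = torch.randn(2, 320, 16, 64, device="cuda", dtype=torch.bfloat16)
#   out = attn(q, k, v)   # -> [batch, seq, heads, head_dim] = [2, 320, 16, 64]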
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import os
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from stepvideo.text_encoder.flashattention import FlashSelfAttention
from stepvideo.modules.normalization import RMSNorm
from stepvideo.text_encoder.tokenizer import LLaMaEmbedding, Wrapped_StepChatTokenizer
from stepvideo.utils import with_empty_init
from safetensors.torch import load_file
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from einops import rearrange
import json
def safediv(n, d):
q, r = divmod(n, d)
assert r == 0
return q
class MultiQueryAttention(nn.Module):
def __init__(self, cfg, layer_id=None):
super().__init__()
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
self.max_seq_len = cfg.seq_length
self.use_flash_attention = cfg.use_flash_attn
assert self.use_flash_attention, 'FlashAttention is required!'
self.n_groups = cfg.num_attention_groups
self.tp_size = 1
self.n_local_heads = cfg.num_attention_heads
self.n_local_groups = self.n_groups
self.wqkv = nn.Linear(
cfg.hidden_size,
cfg.hidden_size + self.head_dim * 2 * self.n_groups,
bias=False,
)
self.wo = nn.Linear(
cfg.hidden_size,
cfg.hidden_size,
bias=False,
)
assert self.use_flash_attention, 'non-Flash attention not supported yet.'
self.core_attention = FlashSelfAttention(attention_dropout=cfg.attention_dropout)
self.layer_id = layer_id
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor],
cu_seqlens: Optional[torch.Tensor],
max_seq_len: Optional[torch.Tensor],
):
seqlen, bsz, dim = x.shape
xqkv = self.wqkv(x)
xq, xkv = torch.split(
xqkv,
(dim // self.tp_size,
self.head_dim*2*self.n_groups // self.tp_size
),
dim=-1,
)
# gather on 1st dimension
xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim)
xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim)
xk, xv = xkv.chunk(2, -1)
# rotary embedding + flash attn
xq = rearrange(xq, "s b h d -> b s h d")
xk = rearrange(xk, "s b h d -> b s h d")
xv = rearrange(xv, "s b h d -> b s h d")
q_per_kv = self.n_local_heads // self.n_local_groups
if q_per_kv > 1:
b, s, h, d = xk.size()
if h == 1:
xk = xk.expand(b, s, q_per_kv, d)
xv = xv.expand(b, s, q_per_kv, d)
else:
''' To cover the cases where h > 1, we have
the following implementation, which is equivalent to:
xk = xk.repeat_interleave(q_per_kv, dim=-2)
xv = xv.repeat_interleave(q_per_kv, dim=-2)
but can avoid calling aten::item() that involves cpu.
'''
idx = torch.arange(q_per_kv * h, device=xk.device).reshape(q_per_kv, -1).permute(1, 0).flatten()
xk = torch.index_select(xk.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
xv = torch.index_select(xv.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
if self.use_flash_attention:
output = self.core_attention(xq, xk, xv,
cu_seqlens=cu_seqlens,
max_seq_len=max_seq_len)
# reduce-scatter only supports the first dimension for now
output = rearrange(output, "b s h d -> s b (h d)").contiguous()
else:
xq, xk, xv = [
rearrange(x, "b s ... -> s b ...").contiguous()
for x in (xq, xk, xv)
]
output = self.core_attention(xq, xk, xv, mask)
output = self.wo(output)
return output
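# Quick sanity sketch for the grouped-KV expansion above (illustrative only): for
# h > 1 the index_select path reproduces repeat_interleave along the head dim,
# e.g. with q_per_kv = 3 and h = 2 KV heads:
#
#   xk = torch.randn(1, 8, 2, 4)                     # b, s, h, d
#   q_per_kv, h = 3, xk.size(2)
#   idx = torch.arange(q_per_kv * h).reshape(q_per_kv, -1).permute(1, 0).flatten()
#   a = torch.index_select(xk.repeat(1, 1, q_per_kv, 1), 2, idx)
#   b = xk.repeat_interleave(q_per_kv, dim=-2)
#   assert torch.equal(a, b)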
class FeedForward(nn.Module):
def __init__(
self,
cfg,
dim: int,
hidden_dim: int,
layer_id: int,
multiple_of: int=256,
):
super().__init__()
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
def swiglu(x):
x = torch.chunk(x, 2, dim=-1)
return F.silu(x[0]) * x[1]
self.swiglu = swiglu
self.w1 = nn.Linear(
dim,
2 * hidden_dim,
bias=False,
)
self.w2 = nn.Linear(
hidden_dim,
dim,
bias=False,
)
def forward(self, x):
x = self.swiglu(self.w1(x))
output = self.w2(x)
return output
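# Note (added comment): FeedForward is a SwiGLU MLP. w1 projects dim -> 2*hidden_dim,
# the two halves are combined as silu(a) * b, and w2 projects hidden_dim -> dim,
# i.e. out = w2(silu(x @ W1_a) * (x @ W1_b)).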
class TransformerBlock(nn.Module):
def __init__(
self, cfg, layer_id: int
):
super().__init__()
self.n_heads = cfg.num_attention_heads
self.dim = cfg.hidden_size
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
self.attention = MultiQueryAttention(
cfg,
layer_id=layer_id,
)
self.feed_forward = FeedForward(
cfg,
dim=cfg.hidden_size,
hidden_dim=cfg.ffn_hidden_size,
layer_id=layer_id,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(
cfg.hidden_size,
eps=cfg.layernorm_epsilon,
)
self.ffn_norm = RMSNorm(
cfg.hidden_size,
eps=cfg.layernorm_epsilon,
)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor],
cu_seqlens: Optional[torch.Tensor],
max_seq_len: Optional[torch.Tensor],
):
residual = self.attention.forward(
self.attention_norm(x), mask,
cu_seqlens, max_seq_len
)
h = x + residual
ffn_res = self.feed_forward.forward(self.ffn_norm(h))
out = h + ffn_res
return out
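# Note (added comment): TransformerBlock uses the pre-norm residual pattern:
#   h   = x + Attention(RMSNorm(x))
#   out = h + FeedForward(RMSNorm(h))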
class Transformer(nn.Module):
def __init__(
self,
config,
max_seq_size=8192,
):
super().__init__()
self.num_layers = config.num_layers
self.layers = self._build_layers(config)
def _build_layers(self, config):
layers = torch.nn.ModuleList()
for layer_id in range(self.num_layers):
layers.append(
TransformerBlock(
config,
layer_id=layer_id + 1,
)
)
return layers
def forward(
self,
hidden_states,
attention_mask,
cu_seqlens=None,
max_seq_len=None,
):
if max_seq_len is not None and not isinstance(max_seq_len, torch.Tensor):
max_seq_len = torch.tensor(max_seq_len, dtype=torch.int32, device="cpu")
for lid, layer in enumerate(self.layers):
hidden_states = layer(
hidden_states,
attention_mask,
cu_seqlens,
max_seq_len,
)
return hidden_states
class Step1Model(PreTrainedModel):
config_class=PretrainedConfig
@with_empty_init
def __init__(
self,
config,
):
super().__init__(config)
self.tok_embeddings = LLaMaEmbedding(config)
self.transformer = Transformer(config)
def forward(
self,
input_ids=None,
attention_mask=None,
):
hidden_states = self.tok_embeddings(input_ids)
hidden_states = self.transformer(
hidden_states,
attention_mask,
)
return hidden_states
class STEP1TextEncoder(torch.nn.Module):
def __init__(self, model_dir, max_length=320):
super(STEP1TextEncoder, self).__init__()
self.max_length = max_length
self.text_tokenizer = Wrapped_StepChatTokenizer(os.path.join(model_dir, 'step1_chat_tokenizer.model'))
text_encoder = Step1Model.from_pretrained(model_dir)
self.text_encoder = text_encoder.eval().to(torch.bfloat16)
@torch.no_grad
def forward(self, prompts, with_mask=True, max_length=None):
self.device = next(self.text_encoder.parameters()).device
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
if type(prompts) is str:
prompts = [prompts]
txt_tokens = self.text_tokenizer(
prompts, max_length=max_length or self.max_length, padding="max_length", truncation=True, return_tensors="pt"
)
y = self.text_encoder(
txt_tokens.input_ids.to(self.device),
attention_mask=txt_tokens.attention_mask.to(self.device) if with_mask else None
)
y_mask = txt_tokens.attention_mask
return y.transpose(0,1), y_mask
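# Illustrative usage sketch (the model directory is hypothetical; it is expected to
# contain step1_chat_tokenizer.model plus the Step1Model weights/config):
#
#   encoder = STEP1TextEncoder("./ckpts/step_llm", max_length=320).cuda()
#   y, y_mask = encoder("a cinematic shot of a fox running through snow")
#   # y:      [batch, max_length, hidden]  (transposed back from [s, b, h])
#   # y_mask: [batch, max_length]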
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import torch.nn as nn
import torch
from typing import List
class LLaMaEmbedding(nn.Module):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
cfg,
):
super().__init__()
self.hidden_size = cfg.hidden_size
self.params_dtype = cfg.params_dtype
self.fp32_residual_connection = cfg.fp32_residual_connection
self.embedding_weights_in_fp32 = cfg.embedding_weights_in_fp32
self.word_embeddings = torch.nn.Embedding(
cfg.padded_vocab_size, self.hidden_size,
)
self.embedding_dropout = torch.nn.Dropout(cfg.hidden_dropout)
def forward(self, input_ids):
# Embeddings.
if self.embedding_weights_in_fp32:
self.word_embeddings = self.word_embeddings.to(torch.float32)
embeddings = self.word_embeddings(input_ids)
if self.embedding_weights_in_fp32:
embeddings = embeddings.to(self.params_dtype)
self.word_embeddings = self.word_embeddings.to(self.params_dtype)
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
# If the fp32 residual connection flag is set, convert the embeddings to float.
if self.fp32_residual_connection:
embeddings = embeddings.float()
# Dropout.
embeddings = self.embedding_dropout(embeddings)
return embeddings
class StepChatTokenizer:
"""Step Chat Tokenizer"""
def __init__(
self, model_file, name="StepChatTokenizer",
bot_token="<|BOT|>", # Begin of Turn
eot_token="<|EOT|>", # End of Turn
call_start_token="<|CALL_START|>", # Call Start
call_end_token="<|CALL_END|>", # Call End
think_start_token="<|THINK_START|>", # Think Start
think_end_token="<|THINK_END|>", # Think End
mask_start_token="<|MASK_1e69f|>", # Mask start
mask_end_token="<|UNMASK_1e69f|>", # Mask end
):
import sentencepiece
self._tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
self._vocab = {}
self._inv_vocab = {}
self._special_tokens = {}
self._inv_special_tokens = {}
self._t5_tokens = []
for idx in range(self._tokenizer.get_piece_size()):
text = self._tokenizer.id_to_piece(idx)
self._inv_vocab[idx] = text
self._vocab[text] = idx
if self._tokenizer.is_control(idx) or self._tokenizer.is_unknown(idx):
self._special_tokens[text] = idx
self._inv_special_tokens[idx] = text
self._unk_id = self._tokenizer.unk_id()
self._bos_id = self._tokenizer.bos_id()
self._eos_id = self._tokenizer.eos_id()
for token in [
bot_token, eot_token, call_start_token, call_end_token,
think_start_token, think_end_token
]:
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
assert token in self._special_tokens, f"Token '{token}' is not a special token"
for token in [mask_start_token, mask_end_token]:
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
self._bot_id = self._tokenizer.piece_to_id(bot_token)
self._eot_id = self._tokenizer.piece_to_id(eot_token)
self._call_start_id = self._tokenizer.piece_to_id(call_start_token)
self._call_end_id = self._tokenizer.piece_to_id(call_end_token)
self._think_start_id = self._tokenizer.piece_to_id(think_start_token)
self._think_end_id = self._tokenizer.piece_to_id(think_end_token)
self._mask_start_id = self._tokenizer.piece_to_id(mask_start_token)
self._mask_end_id = self._tokenizer.piece_to_id(mask_end_token)
self._underline_id = self._tokenizer.piece_to_id("\u2581")
@property
def vocab(self):
return self._vocab
@property
def inv_vocab(self):
return self._inv_vocab
@property
def vocab_size(self):
return self._tokenizer.vocab_size()
def tokenize(self, text: str) -> List[int]:
return self._tokenizer.encode_as_ids(text)
def detokenize(self, token_ids: List[int]) -> str:
return self._tokenizer.decode_ids(token_ids)
class Tokens:
def __init__(self, input_ids, cu_input_ids, attention_mask, cu_seqlens, max_seq_len) -> None:
self.input_ids = input_ids
self.attention_mask = attention_mask
self.cu_input_ids = cu_input_ids
self.cu_seqlens = cu_seqlens
self.max_seq_len = max_seq_len
def to(self, device):
self.input_ids = self.input_ids.to(device)
self.attention_mask = self.attention_mask.to(device)
self.cu_input_ids = self.cu_input_ids.to(device)
self.cu_seqlens = self.cu_seqlens.to(device)
return self
class Wrapped_StepChatTokenizer(StepChatTokenizer):
def __call__(self, text, max_length=320, padding="max_length", truncation=True, return_tensors="pt"):
# [bos, ..., eos, pad, pad, ..., pad]
self.BOS = 1
self.EOS = 2
self.PAD = 2
out_tokens = []
attn_mask = []
if len(text) == 0:
part_tokens = [self.BOS] + [self.EOS]
valid_size = len(part_tokens)
if len(part_tokens) < max_length:
part_tokens += [self.PAD] * (max_length - valid_size)
out_tokens.append(part_tokens)
attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
else:
for part in text:
part_tokens = self.tokenize(part)
part_tokens = part_tokens[:(max_length - 2)] # leave 2 slots for bos and eos
part_tokens = [self.BOS] + part_tokens + [self.EOS]
valid_size = len(part_tokens)
if len(part_tokens) < max_length:
part_tokens += [self.PAD] * (max_length - valid_size)
out_tokens.append(part_tokens)
attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
out_tokens = torch.tensor(out_tokens, dtype=torch.long)
attn_mask = torch.tensor(attn_mask, dtype=torch.long)
# padding y based on tp size
padded_len = 0
padded_flag = padded_len > 0
if padded_flag:
pad_tokens = torch.tensor([[self.PAD] * max_length], device=out_tokens.device)
pad_attn_mask = torch.tensor([[1]*padded_len+[0]*(max_length-padded_len)], device=attn_mask.device)
out_tokens = torch.cat([out_tokens, pad_tokens], dim=0)
attn_mask = torch.cat([attn_mask, pad_attn_mask], dim=0)
# cu_seqlens
cu_out_tokens = out_tokens.masked_select(attn_mask != 0).unsqueeze(0)
seqlen = attn_mask.sum(dim=1).tolist()
cu_seqlens = torch.cumsum(torch.tensor([0]+seqlen), 0).to(device=out_tokens.device,dtype=torch.int32)
max_seq_len = max(seqlen)
return Tokens(out_tokens, cu_out_tokens, attn_mask, cu_seqlens, max_seq_len)
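# Illustrative usage sketch (the tokenizer model path is hypothetical):
#
#   tok = Wrapped_StepChatTokenizer("./ckpts/step_llm/step1_chat_tokenizer.model")
#   tokens = tok(["a red panda", "a rainy street at night"], max_length=320)
#   tokens.input_ids.shape       # [2, 320], padded with PAD after BOS ... EOS
#   tokens.attention_mask.shape  # [2, 320]
#   tokens.cu_seqlens            # e.g. tensor([0, n1, n1 + n2]) with valid lengths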
from .utils import *
from .video_process import *
import numpy as np
import random
import torch
from functools import wraps
import torch.utils._device
def setup_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
def __init__(self, device=None):
self.device = device
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if getattr(func, '__module__', None) == 'torch.nn.init':
if 'tensor' in kwargs:
return kwargs['tensor']
else:
return args[0]
if self.device is not None and func in torch.utils._device._device_constructors() and kwargs.get('device') is None:
kwargs['device'] = self.device
return func(*args, **kwargs)
def with_empty_init(func):
@wraps(func)
def wrapper(*args, **kwargs):
with EmptyInitOnDevice('cpu'):
return func(*args, **kwargs)
return wrapper
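# Illustrative sketch (comment only): with_empty_init wraps a constructor so that
# torch.nn.init functions become no-ops and factory calls default to CPU, making
# large modules cheap to build before real weights are loaded, e.g.:
#
#   @with_empty_init
#   def build_linear():
#       return torch.nn.Linear(4096, 4096)   # allocated, but init math is skipped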
def culens2mask(
cu_seqlens=None,
cu_seqlens_kv=None,
max_seqlen=None,
max_seqlen_kv=None,
is_causal=False
):
assert len(cu_seqlens) == len(cu_seqlens_kv), "q and kv must have the same batch size"
bsz = len(cu_seqlens) - 1
seqlens = cu_seqlens[1:]-cu_seqlens[:-1]
seqlens_kv = cu_seqlens_kv[1:]-cu_seqlens_kv[:-1]
attn_mask = torch.zeros(bsz, max_seqlen, max_seqlen_kv, dtype=torch.bool)
for i, (seq_len, seq_len_kv) in enumerate(zip(seqlens, seqlens_kv)):
if is_causal:
attn_mask[i, :seq_len, :seq_len_kv] = torch.triu(torch.ones(seq_len, seq_len_kv), diagonal=1).bool()
else:
attn_mask[i, :seq_len, :seq_len_kv] = torch.ones([seq_len, seq_len_kv], dtype=torch.bool)
return attn_mask
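# Illustrative example: two packed sequences of lengths 3 and 5 (non-causal):
#
#   cu = torch.tensor([0, 3, 8])
#   mask = culens2mask(cu, cu, max_seqlen=5, max_seqlen_kv=5)
#   # mask.shape == (2, 5, 5); mask[0, :3, :3] and mask[1] are True, the rest False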
import numpy as np
import datetime
import torch
import os
import imageio
class VideoProcessor:
def __init__(self, save_path: str='./results', name_suffix: str=''):
self.save_path = save_path
os.makedirs(self.save_path, exist_ok=True)
self.name_suffix = name_suffix
def crop2standard540p(self, vid_array):
_, height, width, _ = vid_array.shape
height_center = height//2
width_center = width//2
if width_center>height_center: ## landscape mode
return vid_array[:, height_center-270:height_center+270, width_center-480:width_center+480]
elif width_center<height_center: ## portrait mode
return vid_array[:, height_center-480:height_center+480, width_center-270:width_center+270]
else:
return vid_array
def save_imageio_video(self, video_array: np.ndarray, output_filename: str, fps=25, codec='libx264'):
ffmpeg_params = [
"-vf", "atadenoise=0a=0.1:0b=0.1:1a=0.1:1b=0.1", # denoise
]
with imageio.get_writer(output_filename, fps=fps, codec=codec, ffmpeg_params=ffmpeg_params) as vid_writer:
for img_array in video_array:
vid_writer.append_data(img_array)
def postprocess_video(self, video_tensor, output_file_name='', output_type="mp4", crop2standard540p=True):
if len(self.name_suffix) == 0:
video_path = os.path.join(self.save_path, f"{output_file_name}-{str(datetime.datetime.now())}.{output_type}")
else:
video_path = os.path.join(self.save_path, f"{output_file_name}-{self.name_suffix}.{output_type}")
video_tensor = (video_tensor.cpu().clamp(-1, 1)+1)*127.5
video_tensor = torch.cat([t for t in video_tensor], dim=-2)
video_array = video_tensor.clamp(0, 255).to(torch.uint8).numpy().transpose(0,2,3,1)
if crop2standard540p:
video_array = self.crop2standard540p(video_array)
self.save_imageio_video(video_array, video_path)
print(f"Saved the generated video in {video_path}")