Unverified Commit 7bc1dae0 authored by Mick, committed by GitHub

WIP: initial multimodal-gen support (#12484)


Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: JiLi <leege233@gmail.com>
Co-authored-by: CHEN Xi <78632976+RubiaCx@users.noreply.github.com>
Co-authored-by: laixin <xielx@shanghaitech.edu.cn>
Co-authored-by: SolitaryThinker <wlsaidhi@gmail.com>
Co-authored-by: jzhang38 <a1286225768@gmail.com>
Co-authored-by: BrianChen1129 <yongqichcd@gmail.com>
Co-authored-by: Kevin Lin <42618777+kevin314@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: rlsu9 <r3su@ucsd.edu>
Co-authored-by: Jinzhe Pan <48981407+eigensystem@users.noreply.github.com>
Co-authored-by: foreverpiano <pianoqwz@qq.com>
Co-authored-by: RandNMR73 <notomatthew31@gmail.com>
Co-authored-by: PorridgeSwim <yz3883@columbia.edu>
Co-authored-by: Jiali Chen <90408393+gary-chenjl@users.noreply.github.com>
parent 4fe53e58
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Wan video diffusion pipeline implementation.
This module contains an implementation of the Wan video diffusion pipeline
using the modular pipeline architecture.
"""
from sglang.multimodal_gen.runtime.models.schedulers.scheduling_flow_unipc_multistep import (
FlowUniPCMultistepScheduler,
)
from sglang.multimodal_gen.runtime.pipelines import ComposedPipelineBase, LoRAPipeline
from sglang.multimodal_gen.runtime.pipelines.stages import (
ConditioningStage,
DecodingStage,
DenoisingStage,
InputValidationStage,
LatentPreparationStage,
TextEncodingStage,
TimestepPreparationStage,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class WanPipeline(LoRAPipeline, ComposedPipelineBase):
"""
Wan video diffusion pipeline with LoRA support.
"""
pipeline_name = "WanImageToVideoPipeline"
_required_config_modules = [
"text_encoder",
"tokenizer",
"vae",
"transformer",
"scheduler",
]
def initialize_pipeline(self, server_args: ServerArgs):
# We use the UniPC multistep scheduler from the official Wan2.1 repo, not the diffusers one.
self.modules["scheduler"] = FlowUniPCMultistepScheduler(
shift=server_args.pipeline_config.flow_shift
)
def create_pipeline_stages(self, server_args: ServerArgs) -> None:
"""Set up pipeline stages with proper dependency injection."""
self.add_stage(
stage_name="input_validation_stage", stage=InputValidationStage()
)
self.add_stage(
stage_name="prompt_encoding_stage",
stage=TextEncodingStage(
text_encoders=[self.get_module("text_encoder")],
tokenizers=[self.get_module("tokenizer")],
),
)
self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())
self.add_stage(
stage_name="timestep_preparation_stage",
stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")),
)
self.add_stage(
stage_name="latent_preparation_stage",
stage=LatentPreparationStage(
scheduler=self.get_module("scheduler"),
transformer=self.get_module("transformer", None),
),
)
self.add_stage(
stage_name="denoising_stage",
stage=DenoisingStage(
transformer=self.get_module("transformer"),
transformer_2=self.get_module("transformer_2", None),
scheduler=self.get_module("scheduler"),
vae=self.get_module("vae"),
pipeline=self,
),
)
self.add_stage(
stage_name="decoding_stage",
stage=DecodingStage(vae=self.get_module("vae"), pipeline=self),
)
EntryClass = WanPipeline
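# For reference, the stages registered above execute in the order they are added:
# input_validation_stage -> prompt_encoding_stage -> conditioning_stage ->
# timestep_preparation_stage -> latent_preparation_stage -> denoising_stage ->
# decoding_stage (assumption: ComposedPipelineBase invokes stages sequentially,
# which matches how the preprocessing pipelines below call them one by one).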
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Any
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from sglang.multimodal_gen.dataset import getdataset
from sglang.multimodal_gen.dataset.dataloader.parquet_io import (
ParquetDatasetWriter,
records_to_table,
)
from sglang.multimodal_gen.dataset.preprocessing_datasets import PreprocessBatch
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import (
ComposedPipelineBase,
)
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages import TextEncodingStage
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class BasePreprocessPipeline(ComposedPipelineBase):
"""Base class for preprocessing pipelines that handles common functionality."""
def create_pipeline_stages(self, server_args: ServerArgs):
"""Set up pipeline stages with proper dependency injection."""
self.add_stage(
stage_name="prompt_encoding_stage",
stage=TextEncodingStage(
text_encoders=[self.get_module("text_encoder")],
tokenizers=[self.get_module("tokenizer")],
),
)
@torch.no_grad()
def forward(
self,
batch: Req,
server_args: ServerArgs,
args,
):
if not self.post_init_called:
self.post_init()
# Initialize class variables for data sharing
self.video_data: dict[str, Any] = {} # Store video metadata and paths
self.latent_data: dict[str, Any] = {} # Store latent tensors
self.preprocess_video_and_text(server_args, args)
def get_extra_features(
self, valid_data: dict[str, Any], server_args: ServerArgs
) -> dict[str, Any]:
"""Get additional features specific to the pipeline type. Override in subclasses."""
return {}
def get_pyarrow_schema(self) -> pa.Schema:
"""Return the PyArrow schema for this pipeline. Must be overridden."""
raise NotImplementedError
def get_schema_fields(self) -> list[str]:
"""Get the schema fields for the pipeline type."""
return [f.name for f in self.get_pyarrow_schema()]
def create_record_for_schema(
self, preprocess_batch: PreprocessBatch, schema: pa.Schema, strict: bool = False
) -> dict[str, Any]:
"""Create a record for the Parquet dataset using a generic schema-based approach.
Args:
preprocess_batch: The batch containing the data to extract
schema: PyArrow schema defining the expected fields
strict: If True, raises an exception when required fields are missing or unfilled
Returns:
Dictionary record matching the schema
Raises:
ValueError: If strict=True and required fields are missing or unfilled
"""
record = {}
unfilled_fields = []
for field in schema.names:
field_filled = False
if field.endswith("_bytes"):
# Handle binary tensor data - convert numpy array or tensor to bytes
tensor_name = field.replace("_bytes", "")
tensor_data = getattr(preprocess_batch, tensor_name, None)
if tensor_data is not None:
try:
if hasattr(tensor_data, "numpy"): # torch tensor
record[field] = tensor_data.cpu().numpy().tobytes()
field_filled = True
elif hasattr(tensor_data, "tobytes"): # numpy array
record[field] = tensor_data.tobytes()
field_filled = True
else:
raise ValueError(
f"Unsupported tensor type for field {field}: {type(tensor_data)}"
)
except Exception as e:
if strict:
raise ValueError(
f"Failed to convert tensor {tensor_name} to bytes: {e}"
) from e
record[field] = b"" # Empty bytes for missing data
else:
record[field] = b"" # Empty bytes for missing data
elif field.endswith("_shape"):
# Handle tensor shape info
tensor_name = field.replace("_shape", "")
tensor_data = getattr(preprocess_batch, tensor_name, None)
if tensor_data is not None and hasattr(tensor_data, "shape"):
record[field] = list(tensor_data.shape)
field_filled = True
else:
record[field] = []
elif field.endswith("_dtype"):
# Handle tensor dtype info
tensor_name = field.replace("_dtype", "")
tensor_data = getattr(preprocess_batch, tensor_name, None)
if tensor_data is not None and hasattr(tensor_data, "dtype"):
record[field] = str(tensor_data.dtype)
field_filled = True
else:
record[field] = "unknown"
elif field in ["width", "height", "num_frames"]:
# Handle integer metadata fields
value = getattr(preprocess_batch, field, None)
if value is not None:
try:
record[field] = int(value)
field_filled = True
except (ValueError, TypeError) as e:
if strict:
raise ValueError(
f"Failed to convert field {field} to int: {e}"
) from e
record[field] = 0
else:
record[field] = 0
elif field in ["duration_sec", "fps"]:
# Handle float metadata fields
# Map schema field names to batch attribute names
attr_name = "duration" if field == "duration_sec" else field
value = getattr(preprocess_batch, attr_name, None)
if value is not None:
try:
record[field] = float(value)
field_filled = True
except (ValueError, TypeError) as e:
if strict:
raise ValueError(
f"Failed to convert field {field} to float: {e}"
) from e
record[field] = 0.0
else:
record[field] = 0.0
else:
# Handle string fields (id, file_name, caption, media_type, etc.)
# Map common schema field names to batch attribute names
attr_name = field
if field == "caption":
attr_name = "text"
elif field == "file_name":
attr_name = "path"
elif field == "id":
# Generate ID from path if available
path_value = getattr(preprocess_batch, "path", None)
if path_value:
record[field] = os.path.basename(path_value).split(".")[0]
field_filled = True
else:
record[field] = ""
continue
elif field == "media_type":
# Determine media type from path
path_value = getattr(preprocess_batch, "path", None)
if path_value:
record[field] = (
"video" if path_value.endswith(".mp4") else "image"
)
field_filled = True
else:
record[field] = ""
continue
value = getattr(preprocess_batch, attr_name, None)
if value is not None:
record[field] = str(value)
field_filled = True
else:
record[field] = ""
# Track unfilled fields
if not field_filled:
unfilled_fields.append(field)
# Handle strict mode
if strict and unfilled_fields:
raise ValueError(f"Required fields were not filled: {unfilled_fields}")
# Log unfilled fields as warning if not in strict mode
if unfilled_fields:
logger.warning(
"Some fields were not filled and got default values: %s",
unfilled_fields,
)
return record
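# Hedged illustration of the suffix convention handled above (field names are
# examples, not an exhaustive schema): a schema field "vae_latent_bytes" pulls
# `preprocess_batch.vae_latent` and stores tensor.tobytes(); "vae_latent_shape"
# stores list(tensor.shape); "vae_latent_dtype" stores str(tensor.dtype);
# "caption" maps to `preprocess_batch.text`, "file_name" to `preprocess_batch.path`,
# and "id" / "media_type" are derived from the basename / extension of the path.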
def create_record(
self,
video_name: str,
vae_latent: np.ndarray,
text_embedding: np.ndarray,
valid_data: dict[str, Any],
idx: int,
extra_features: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Create a record for the Parquet dataset."""
record = {
"id": video_name,
"vae_latent_bytes": vae_latent.tobytes(),
"vae_latent_shape": list(vae_latent.shape),
"vae_latent_dtype": str(vae_latent.dtype),
"text_embedding_bytes": text_embedding.tobytes(),
"text_embedding_shape": list(text_embedding.shape),
"text_embedding_dtype": str(text_embedding.dtype),
"file_name": video_name,
"caption": valid_data["text"][idx] if len(valid_data["text"]) > 0 else "",
"media_type": "video",
"width": (
valid_data["pixel_values"][idx].shape[-2]
if len(valid_data["pixel_values"]) > 0
else 0
),
"height": (
valid_data["pixel_values"][idx].shape[-1]
if len(valid_data["pixel_values"]) > 0
else 0
),
"num_frames": vae_latent.shape[1] if len(vae_latent.shape) > 1 else 0,
"duration_sec": (
float(valid_data["duration"][idx])
if len(valid_data["duration"]) > 0
else 0.0
),
"fps": float(valid_data["fps"][idx]) if len(valid_data["fps"]) > 0 else 0.0,
}
if extra_features:
record.update(extra_features)
return record
def preprocess_video_and_text(self, server_args: ServerArgs, args):
os.makedirs(args.output_dir, exist_ok=True)
# Create directory for combined data
combined_parquet_dir = os.path.join(args.output_dir, "combined_parquet_dataset")
os.makedirs(combined_parquet_dir, exist_ok=True)
local_rank = int(os.getenv("RANK", 0))
# Get how many samples have already been processed
start_idx = 0
for root, _, files in os.walk(combined_parquet_dir):
for file in files:
if file.endswith(".parquet"):
table = pq.read_table(os.path.join(root, file))
start_idx += table.num_rows
# Loading dataset
train_dataset = getdataset(args)
train_dataloader = DataLoader(
train_dataset,
batch_size=args.preprocess_video_batch_size,
num_workers=args.dataloader_num_workers,
)
num_processed_samples = 0
# Add progress bar for video preprocessing
pbar = tqdm(
train_dataloader,
desc="Processing videos",
unit="batch",
disable=local_rank != 0,
)
for batch_idx, data in enumerate(pbar):
if data is None:
continue
with torch.inference_mode():
# Filter out invalid samples (those with all zeros)
valid_indices = []
for i, pixel_values in enumerate(data["pixel_values"]):
if not torch.all(pixel_values == 0): # Check if all values are zero
valid_indices.append(i)
num_processed_samples += len(valid_indices)
if not valid_indices:
continue
# Create new batch with only valid samples
valid_data = {
"pixel_values": torch.stack(
[data["pixel_values"][i] for i in valid_indices]
),
"text": [data["text"][i] for i in valid_indices],
"path": [data["path"][i] for i in valid_indices],
"fps": [data["fps"][i] for i in valid_indices],
"duration": [data["duration"][i] for i in valid_indices],
}
# VAE
with torch.autocast("cuda", dtype=torch.float32):
latents = (
self.get_module("vae")
.encode(valid_data["pixel_values"].to(get_local_torch_device()))
.mean
)
# Get extra features if needed
extra_features = self.get_extra_features(valid_data, server_args)
batch_captions = valid_data["text"]
batch = Req(
data_type="video",
prompt=batch_captions,
prompt_embeds=[],
prompt_attention_mask=[],
)
assert hasattr(self, "prompt_encoding_stage")
result_batch = self.prompt_encoding_stage(batch, server_args)
prompt_embeds, prompt_attention_mask = (
result_batch.prompt_embeds[0],
result_batch.prompt_attention_mask[0],
)
assert prompt_embeds.shape[0] == prompt_attention_mask.shape[0]
# Get sequence lengths from attention masks (number of 1s)
seq_lens = prompt_attention_mask.sum(dim=1)
non_padded_embeds = []
non_padded_masks = []
# Process each item in the batch
for i in range(prompt_embeds.size(0)):
seq_len = seq_lens[i].item()
# Slice the embeddings and masks to keep only non-padding parts
non_padded_embeds.append(prompt_embeds[i, :seq_len])
non_padded_masks.append(prompt_attention_mask[i, :seq_len])
# Update the tensors with non-padded versions
prompt_embeds = non_padded_embeds
prompt_attention_mask = non_padded_masks
# Prepare batch data for Parquet dataset
batch_data = []
# Add progress bar for saving outputs
save_pbar = tqdm(
enumerate(valid_data["path"]),
desc="Saving outputs",
unit="item",
leave=False,
)
for idx, video_path in save_pbar:
# Get the corresponding latent and info using video name
latent = latents[idx].cpu()
video_name = os.path.basename(video_path).split(".")[0]
# Convert tensors to numpy arrays
vae_latent = latent.cpu().numpy()
text_embedding = prompt_embeds[idx].cpu().numpy()
# Get extra features for this sample if needed
sample_extra_features = {}
if extra_features:
for key, value in extra_features.items():
if isinstance(value, torch.Tensor):
sample_extra_features[key] = value[idx].cpu().numpy()
else:
sample_extra_features[key] = value[idx]
# Create record for Parquet dataset
record = self.create_record(
video_name=video_name,
vae_latent=vae_latent,
text_embedding=text_embedding,
valid_data=valid_data,
idx=idx,
extra_features=sample_extra_features,
)
batch_data.append(record)
if batch_data:
write_pbar = tqdm(
total=1, desc="Writing to Parquet dataset", unit="batch"
)
table = records_to_table(batch_data, self.get_pyarrow_schema())
write_pbar.update(1)
write_pbar.close()
if not hasattr(self, "dataset_writer"):
self.dataset_writer = ParquetDatasetWriter(
out_dir=combined_parquet_dir,
samples_per_file=args.samples_per_file,
)
self.dataset_writer.append_table(table)
logger.info("Collected batch with %s samples", len(table))
if num_processed_samples >= args.flush_frequency:
written = self.dataset_writer.flush()
logger.info("Flushed %s samples to parquet", written)
num_processed_samples = 0
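# Minimal round-trip sketch (illustrative only, never invoked by the pipeline):
# shows how a consumer can rebuild a tensor from the *_bytes / *_shape / *_dtype
# fields written by create_record() above. The variable names are hypothetical.
if __name__ == "__main__":
    example_latent = np.random.rand(16, 4, 60, 104).astype(np.float32)
    example_record = {
        "vae_latent_bytes": example_latent.tobytes(),
        "vae_latent_shape": list(example_latent.shape),
        "vae_latent_dtype": str(example_latent.dtype),
    }
    restored = np.frombuffer(
        example_record["vae_latent_bytes"], dtype=example_record["vae_latent_dtype"]
    ).reshape(example_record["vae_latent_shape"])
    assert np.array_equal(example_latent, restored)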
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
I2V Data Preprocessing pipeline implementation.
This module contains an implementation of the I2V Data Preprocessing pipeline
using the modular pipeline architecture.
"""
from typing import Any
import numpy as np
import torch
from PIL import Image
from sglang.multimodal_gen.dataset.dataloader.schema import pyarrow_schema_i2v
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context
from sglang.multimodal_gen.runtime.pipelines.preprocess.preprocess_pipeline_base import (
BasePreprocessPipeline,
)
from sglang.multimodal_gen.runtime.pipelines.stages import (
ImageEncodingStage,
TextEncodingStage,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
class PreprocessPipeline_I2V(BasePreprocessPipeline):
"""I2V preprocessing pipeline implementation."""
_required_config_modules = [
"text_encoder",
"tokenizer",
"vae",
"image_encoder",
"image_processor",
]
def create_pipeline_stages(self, server_args: ServerArgs):
self.add_stage(
stage_name="prompt_encoding_stage",
stage=TextEncodingStage(
text_encoders=[self.get_module("text_encoder")],
tokenizers=[self.get_module("tokenizer")],
),
)
self.add_stage(
stage_name="image_encoding_stage",
stage=ImageEncodingStage(
image_encoder=self.get_module("image_encoder"),
image_processor=self.get_module("image_processor"),
),
)
def get_pyarrow_schema(self):
"""Return the PyArrow schema for I2V pipeline."""
return pyarrow_schema_i2v
def get_extra_features(
self, valid_data: dict[str, Any], server_args: ServerArgs
) -> dict[str, Any]:
# TODO(will): move these to cpu at some point
self.get_module("image_encoder").to(get_local_torch_device())
self.get_module("vae").to(get_local_torch_device())
features = {}
"""Get CLIP features from the first frame of each video."""
first_frame = valid_data["pixel_values"][:, :, 0, :, :].permute(
0, 2, 3, 1
)  # select first frame: (B, C, T, H, W) -> (B, C, H, W), then permute -> (B, H, W, C)
_, _, num_frames, height, width = valid_data["pixel_values"].shape
# latent_height = height // self.get_module(
# "vae").spatial_compression_ratio
# latent_width = width // self.get_module("vae").spatial_compression_ratio
processed_images = []
# Frame has values between -1 and 1
for frame in first_frame:
frame = (frame + 1) * 127.5
frame_pil = Image.fromarray(frame.cpu().numpy().astype(np.uint8))
processed_img = self.get_module("image_processor")(
images=frame_pil, return_tensors="pt"
)
processed_images.append(processed_img)
# Get CLIP features
pixel_values = torch.cat(
[img["pixel_values"] for img in processed_images], dim=0
).to(get_local_torch_device())
with torch.no_grad():
image_inputs = {"pixel_values": pixel_values}
with set_forward_context(current_timestep=0, attn_metadata=None):
clip_features = self.get_module("image_encoder")(**image_inputs)
clip_features = clip_features.last_hidden_state
features["clip_feature"] = clip_features
"""Get VAE features from the first frame of each video"""
video_conditions = []
for frame in first_frame:
processed_img = frame.to(device="cpu", dtype=torch.float32)
processed_img = processed_img.unsqueeze(0).permute(0, 3, 1, 2).unsqueeze(2)
# (B, H, W, C) -> (B, C, 1, H, W)
video_condition = torch.cat(
[
processed_img,
processed_img.new_zeros(
processed_img.shape[0],
processed_img.shape[1],
num_frames - 1,
height,
width,
),
],
dim=2,
)
video_condition = video_condition.to(
device=get_local_torch_device(), dtype=torch.float32
)
video_conditions.append(video_condition)
video_conditions = torch.cat(video_conditions, dim=0)
with torch.autocast(device_type="cuda", dtype=torch.float32, enabled=True):
encoder_outputs = self.get_module("vae").encode(video_conditions)
latent_condition = encoder_outputs.mean
if (
hasattr(self.get_module("vae"), "shift_factor")
and self.get_module("vae").shift_factor is not None
):
if isinstance(self.get_module("vae").shift_factor, torch.Tensor):
latent_condition -= self.get_module("vae").shift_factor.to(
latent_condition.device, latent_condition.dtype
)
else:
latent_condition -= self.get_module("vae").shift_factor
if isinstance(self.get_module("vae").scaling_factor, torch.Tensor):
latent_condition = latent_condition * self.get_module(
"vae"
).scaling_factor.to(latent_condition.device, latent_condition.dtype)
else:
latent_condition = latent_condition * self.get_module("vae").scaling_factor
# mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height,
# latent_width)
# mask_lat_size[:, :, list(range(1, num_frames))] = 0
# first_frame_mask = mask_lat_size[:, :, 0:1]
# first_frame_mask = torch.repeat_interleave(
# first_frame_mask,
# dim=2,
# repeats=self.get_module("vae").temporal_compression_ratio)
# mask_lat_size = torch.concat(
# [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
# mask_lat_size = mask_lat_size.view(
# batch_size, -1,
# self.get_module("vae").temporal_compression_ratio, latent_height,
# latent_width)
# mask_lat_size = mask_lat_size.transpose(1, 2)
# mask_lat_size = mask_lat_size.to(latent_condition.device)
# image_latent = torch.concat([mask_lat_size, latent_condition], dim=1)
features["first_frame_latent"] = latent_condition
return features
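# Note: the shift/scale handling above follows the usual VAE latent normalization
# convention, latent = (latent - shift_factor) * scaling_factor, applied elementwise;
# both tensor-valued and scalar factors are supported.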
def create_record(
self,
video_name: str,
vae_latent: np.ndarray,
text_embedding: np.ndarray,
valid_data: dict[str, Any],
idx: int,
extra_features: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Create a record for the Parquet dataset with CLIP features."""
record = super().create_record(
video_name=video_name,
vae_latent=vae_latent,
text_embedding=text_embedding,
valid_data=valid_data,
idx=idx,
extra_features=extra_features,
)
if extra_features and "clip_feature" in extra_features:
clip_feature = extra_features["clip_feature"]
record.update(
{
"clip_feature_bytes": clip_feature.tobytes(),
"clip_feature_shape": list(clip_feature.shape),
"clip_feature_dtype": str(clip_feature.dtype),
}
)
else:
record.update(
{
"clip_feature_bytes": b"",
"clip_feature_shape": [],
"clip_feature_dtype": "",
}
)
if extra_features and "first_frame_latent" in extra_features:
first_frame_latent = extra_features["first_frame_latent"]
record.update(
{
"first_frame_latent_bytes": first_frame_latent.tobytes(),
"first_frame_latent_shape": list(first_frame_latent.shape),
"first_frame_latent_dtype": str(first_frame_latent.dtype),
}
)
else:
record.update(
{
"first_frame_latent_bytes": b"",
"first_frame_latent_shape": [],
"first_frame_latent_dtype": "",
}
)
if extra_features and "pil_image" in extra_features:
pil_image = extra_features["pil_image"]
record.update(
{
"pil_image_bytes": pil_image.tobytes(),
"pil_image_shape": list(pil_image.shape),
"pil_image_dtype": str(pil_image.dtype),
}
)
else:
record.update(
{
"pil_image_bytes": b"",
"pil_image_shape": [],
"pil_image_dtype": "",
}
)
return record
EntryClass = PreprocessPipeline_I2V
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
ODE Trajectory Data Preprocessing pipeline implementation.
This module contains an implementation of the ODE Trajectory Data Preprocessing pipeline
using the modular pipeline architecture.
Sec 4.3 of CausVid paper: https://arxiv.org/pdf/2412.07772
"""
import os
from collections.abc import Iterator
from typing import Any
import pyarrow as pa
import torch
from torch.utils.data import DataLoader
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm import tqdm
from sglang.multimodal_gen.configs.sample import SamplingParams
from sglang.multimodal_gen.dataset import gettextdataset
from sglang.multimodal_gen.dataset.dataloader.parquet_io import (
ParquetDatasetWriter,
records_to_table,
)
from sglang.multimodal_gen.dataset.dataloader.record_schema import (
ode_text_only_record_creator,
)
from sglang.multimodal_gen.dataset.dataloader.schema import (
pyarrow_schema_ode_trajectory_text_only,
)
from sglang.multimodal_gen.runtime.models.schedulers.scheduling_self_forcing_flow_match import (
SelfForcingFlowMatchScheduler,
)
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.preprocess.preprocess_pipeline_base import (
BasePreprocessPipeline,
)
from sglang.multimodal_gen.runtime.pipelines.stages import (
DecodingStage,
DenoisingStage,
InputValidationStage,
LatentPreparationStage,
TextEncodingStage,
TimestepPreparationStage,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import save_decoded_latents_as_video, shallow_asdict
logger = init_logger(__name__)
class PreprocessPipeline_ODE_Trajectory(BasePreprocessPipeline):
"""ODE Trajectory preprocessing pipeline implementation."""
_required_config_modules = [
"text_encoder",
"tokenizer",
"vae",
"transformer",
"scheduler",
]
preprocess_dataloader: StatefulDataLoader
preprocess_loader_iter: Iterator[dict[str, Any]]
pbar: Any
num_processed_samples: int
def get_pyarrow_schema(self) -> pa.Schema:
"""Return the PyArrow schema for ODE Trajectory pipeline."""
return pyarrow_schema_ode_trajectory_text_only
def create_pipeline_stages(self, server_args: ServerArgs):
"""Set up pipeline stages with proper dependency injection."""
assert server_args.pipeline_config.flow_shift == 5
self.modules["scheduler"] = SelfForcingFlowMatchScheduler(
shift=server_args.pipeline_config.flow_shift,
sigma_min=0.0,
extra_one_step=True,
)
self.modules["scheduler"].set_timesteps(
num_inference_steps=48, denoising_strength=1.0
)
self.add_stage(
stage_name="input_validation_stage", stage=InputValidationStage()
)
self.add_stage(
stage_name="prompt_encoding_stage",
stage=TextEncodingStage(
text_encoders=[self.get_module("text_encoder")],
tokenizers=[self.get_module("tokenizer")],
),
)
self.add_stage(
stage_name="timestep_preparation_stage",
stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")),
)
self.add_stage(
stage_name="latent_preparation_stage",
stage=LatentPreparationStage(
scheduler=self.get_module("scheduler"),
transformer=self.get_module("transformer", None),
),
)
self.add_stage(
stage_name="denoising_stage",
stage=DenoisingStage(
transformer=self.get_module("transformer"),
scheduler=self.get_module("scheduler"),
pipeline=self,
),
)
self.add_stage(
stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"))
)
def preprocess_text_and_trajectory(self, server_args: ServerArgs, args):
"""Preprocess text-only data and generate trajectory information."""
for batch_idx, data in enumerate(self.pbar):
if data is None:
continue
with torch.inference_mode():
# For text-only processing, we only need text data
# Filter out samples without text
valid_indices = []
for i, text in enumerate(data["text"]):
if text and text.strip(): # Check if text is not empty
valid_indices.append(i)
self.num_processed_samples += len(valid_indices)
if not valid_indices:
continue
# Create new batch with only valid samples (text-only)
valid_data = {
"text": [data["text"][i] for i in valid_indices],
"path": [data["path"][i] for i in valid_indices],
}
# Add fps and duration if available in data
if "fps" in data:
valid_data["fps"] = [data["fps"][i] for i in valid_indices]
if "duration" in data:
valid_data["duration"] = [
data["duration"][i] for i in valid_indices
]
batch_captions = valid_data["text"]
# Encode text using the standalone TextEncodingStage API
prompt_embeds_list, prompt_masks_list = (
self.prompt_encoding_stage.encode_text(
batch_captions,
server_args,
encoder_index=[0],
return_attention_mask=True,
)
)
prompt_embeds = prompt_embeds_list[0]
prompt_attention_masks = prompt_masks_list[0]
assert prompt_embeds.shape[0] == prompt_attention_masks.shape[0]
sampling_params = SamplingParams.from_pretrained(args.model_path)
# encode negative prompt for trajectory collection
if (
sampling_params.guidance_scale > 1
and sampling_params.negative_prompt is not None
):
negative_prompt_embeds_list, negative_prompt_masks_list = (
self.prompt_encoding_stage.encode_text(
sampling_params.negative_prompt,
server_args,
encoder_index=[0],
return_attention_mask=True,
)
)
negative_prompt_embed = negative_prompt_embeds_list[0][0]
negative_prompt_attention_mask = negative_prompt_masks_list[0][0]
else:
negative_prompt_embed = None
negative_prompt_attention_mask = None
trajectory_latents = []
trajectory_timesteps = []
trajectory_decoded = []
for i, (prompt_embed, prompt_attention_mask) in enumerate(
zip(prompt_embeds, prompt_attention_masks, strict=False)
):
prompt_embed = prompt_embed.unsqueeze(0)
prompt_attention_mask = prompt_attention_mask.unsqueeze(0)
# Collect the trajectory data (text-to-video generation)
batch = Req(
**shallow_asdict(sampling_params),
)
batch.prompt_embeds = [prompt_embed]
batch.prompt_attention_mask = [prompt_attention_mask]
batch.negative_prompt_embeds = [negative_prompt_embed]
batch.negative_attention_mask = [negative_prompt_attention_mask]
batch.num_inference_steps = 48
batch.return_trajectory_latents = True
# Enabling this will save the decoded trajectory videos.
# Used for debugging.
batch.return_trajectory_decoded = False
batch.height = args.max_height
batch.width = args.max_width
batch.fps = args.train_fps
batch.guidance_scale = 6.0
batch.do_classifier_free_guidance = True
result_batch = self.input_validation_stage(batch, server_args)
result_batch = self.timestep_preparation_stage(result_batch, server_args)
result_batch = self.latent_preparation_stage(
result_batch, server_args
)
result_batch = self.denoising_stage(result_batch, server_args)
result_batch = self.decoding_stage(result_batch, server_args)
trajectory_latents.append(result_batch.trajectory_latents.cpu())
trajectory_timesteps.append(result_batch.trajectory_timesteps.cpu())
trajectory_decoded.append(result_batch.trajectory_decoded)
# Prepare extra features for text-only processing
extra_features = {
"trajectory_latents": trajectory_latents,
"trajectory_timesteps": trajectory_timesteps,
}
if batch.return_trajectory_decoded:
for i, decoded_frames in enumerate(trajectory_decoded):
for j, decoded_frame in enumerate(decoded_frames):
save_decoded_latents_as_video(
decoded_frame,
f"decoded_videos/trajectory_decoded_{i}_{j}.mp4",
args.train_fps,
)
# Prepare batch data for Parquet dataset
batch_data: list[dict[str, Any]] = []
# Add progress bar for saving outputs
save_pbar = tqdm(
enumerate(valid_data["path"]),
desc="Saving outputs",
unit="item",
leave=False,
)
for idx, video_path in save_pbar:
video_name = os.path.basename(video_path).split(".")[0]
# Convert tensors to numpy arrays
text_embedding = prompt_embeds[idx].cpu().numpy()
# Get extra features for this sample
sample_extra_features = {}
if extra_features:
for key, value in extra_features.items():
if isinstance(value, torch.Tensor):
sample_extra_features[key] = value[idx].cpu().numpy()
else:
assert isinstance(value, list)
if isinstance(value[idx], torch.Tensor):
sample_extra_features[key] = (
value[idx].cpu().float().numpy()
)
else:
sample_extra_features[key] = value[idx]
# Create record for Parquet dataset (text-only ODE schema)
record: dict[str, Any] = ode_text_only_record_creator(
video_name=video_name,
text_embedding=text_embedding,
caption=valid_data["text"][idx],
trajectory_latents=sample_extra_features["trajectory_latents"],
trajectory_timesteps=sample_extra_features[
"trajectory_timesteps"
],
)
batch_data.append(record)
if batch_data:
write_pbar = tqdm(
total=1, desc="Writing to Parquet dataset", unit="batch"
)
table = records_to_table(batch_data, self.get_pyarrow_schema())
write_pbar.update(1)
write_pbar.close()
if not hasattr(self, "dataset_writer"):
self.dataset_writer = ParquetDatasetWriter(
out_dir=self.combined_parquet_dir,
samples_per_file=args.samples_per_file,
)
self.dataset_writer.append_table(table)
logger.info("Collected batch with %s samples", len(table))
if self.num_processed_samples >= args.flush_frequency:
written = self.dataset_writer.flush()
logger.info("Flushed %s samples to parquet", written)
self.num_processed_samples = 0
# Final flush for any remaining samples
if hasattr(self, "dataset_writer"):
written = self.dataset_writer.flush(write_remainder=True)
if written:
logger.info("Final flush wrote %s samples", written)
def forward(self, batch: Req, server_args: ServerArgs, args):
if not self.post_init_called:
self.post_init()
self.local_rank = int(os.getenv("RANK", 0))
os.makedirs(args.output_dir, exist_ok=True)
# Create directory for combined data
self.combined_parquet_dir = os.path.join(
args.output_dir, "combined_parquet_dataset"
)
os.makedirs(self.combined_parquet_dir, exist_ok=True)
# Loading dataset
train_dataset = gettextdataset(args)
self.preprocess_dataloader = DataLoader(
train_dataset,
batch_size=args.preprocess_video_batch_size,
num_workers=args.dataloader_num_workers,
)
self.preprocess_loader_iter = iter(self.preprocess_dataloader)
self.num_processed_samples = 0
# Add progress bar for video preprocessing
self.pbar = tqdm(
self.preprocess_loader_iter,
desc="Processing videos",
unit="batch",
disable=self.local_rank != 0,
)
# Initialize class variables for data sharing
self.video_data: dict[str, Any] = {} # Store video metadata and paths
self.latent_data: dict[str, Any] = {} # Store latent tensors
self.preprocess_text_and_trajectory(server_args, args)
EntryClass = PreprocessPipeline_ODE_Trajectory
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
T2V Data Preprocessing pipeline implementation.
This module contains an implementation of the T2V Data Preprocessing pipeline
using the modular pipeline architecture.
"""
from sglang.multimodal_gen.dataset.dataloader.schema import pyarrow_schema_t2v
from sglang.multimodal_gen.runtime.pipelines.preprocess.preprocess_pipeline_base import (
BasePreprocessPipeline,
)
class PreprocessPipeline_T2V(BasePreprocessPipeline):
"""T2V preprocessing pipeline implementation."""
_required_config_modules = ["text_encoder", "tokenizer", "vae"]
def get_pyarrow_schema(self):
"""Return the PyArrow schema for T2V pipeline."""
return pyarrow_schema_t2v
EntryClass = PreprocessPipeline_T2V
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Text-only Data Preprocessing pipeline implementation.
This module contains an implementation of the Text-only Data Preprocessing pipeline
using the modular pipeline architecture, based on the ODE Trajectory preprocessing.
"""
import os
from collections.abc import Iterator
from typing import Any
import torch
from torch.utils.data import DataLoader
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm import tqdm
from sglang.multimodal_gen.dataset import gettextdataset
from sglang.multimodal_gen.dataset.dataloader.parquet_io import (
ParquetDatasetWriter,
records_to_table,
)
from sglang.multimodal_gen.dataset.dataloader.record_schema import (
text_only_record_creator,
)
from sglang.multimodal_gen.dataset.dataloader.schema import pyarrow_schema_text_only
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.preprocess.preprocess_pipeline_base import (
BasePreprocessPipeline,
)
from sglang.multimodal_gen.runtime.pipelines.stages import TextEncodingStage
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class PreprocessPipeline_Text(BasePreprocessPipeline):
"""Text-only preprocessing pipeline implementation."""
_required_config_modules = ["text_encoder", "tokenizer"]
preprocess_dataloader: StatefulDataLoader
preprocess_loader_iter: Iterator[dict[str, Any]]
pbar: Any
num_processed_samples: int = 0
def get_pyarrow_schema(self):
"""Return the PyArrow schema for text-only pipeline."""
return pyarrow_schema_text_only
def create_pipeline_stages(self, server_args: ServerArgs):
"""Set up pipeline stages with proper dependency injection."""
self.add_stage(
stage_name="prompt_encoding_stage",
stage=TextEncodingStage(
text_encoders=[self.get_module("text_encoder")],
tokenizers=[self.get_module("tokenizer")],
),
)
def preprocess_text_only(self, server_args: ServerArgs, args):
"""Preprocess text-only data."""
for batch_idx, data in enumerate(self.pbar):
if data is None:
continue
with torch.inference_mode():
# For text-only processing, we only need text data
# Filter out samples without text
valid_indices = []
for i, text in enumerate(data["text"]):
if text and text.strip(): # Check if text is not empty
valid_indices.append(i)
self.num_processed_samples += len(valid_indices)
if not valid_indices:
continue
# Create new batch with only valid samples (text-only)
valid_data = {
"text": [data["text"][i] for i in valid_indices],
"path": [data["path"][i] for i in valid_indices],
}
batch_captions = valid_data["text"]
# Encode text using the standalone TextEncodingStage API
prompt_embeds_list, prompt_masks_list = (
self.prompt_encoding_stage.encode_text(
batch_captions,
server_args,
encoder_index=[0],
return_attention_mask=True,
)
)
prompt_embeds = prompt_embeds_list[0]
prompt_attention_masks = prompt_masks_list[0]
assert prompt_embeds.shape[0] == prompt_attention_masks.shape[0]
logger.info("===== prompt_embeds: %s", prompt_embeds.shape)
logger.info(
"===== prompt_attention_masks: %s", prompt_attention_masks.shape
)
# Prepare batch data for Parquet dataset
batch_data = []
# Add progress bar for saving outputs
save_pbar = tqdm(
enumerate(valid_data["path"]),
desc="Saving outputs",
unit="item",
leave=False,
)
for idx, text_path in save_pbar:
text_name = os.path.basename(text_path).split(".")[0]
# Convert tensors to numpy arrays
text_embedding = prompt_embeds[idx].cpu().numpy()
# Create record for Parquet dataset (text-only schema)
record = text_only_record_creator(
text_name=text_name,
text_embedding=text_embedding,
caption=valid_data["text"][idx],
)
batch_data.append(record)
if batch_data:
write_pbar = tqdm(
total=1, desc="Writing to Parquet dataset", unit="batch"
)
table = records_to_table(batch_data, pyarrow_schema_text_only)
write_pbar.update(1)
write_pbar.close()
if not hasattr(self, "dataset_writer"):
self.dataset_writer = ParquetDatasetWriter(
out_dir=self.combined_parquet_dir,
samples_per_file=args.samples_per_file,
)
self.dataset_writer.append_table(table)
logger.info("Collected batch with %s samples", len(table))
if self.num_processed_samples >= args.flush_frequency:
written = self.dataset_writer.flush()
logger.info("Flushed %s samples to parquet", written)
self.num_processed_samples = 0
# Final flush for any remaining samples
if hasattr(self, "dataset_writer"):
written = self.dataset_writer.flush(write_remainder=True)
if written:
logger.info("Final flush wrote %s samples", written)
# Text-only record creation moved to sglang.multimodal_gen.dataset.dataloader.record_schema
def forward(self, batch: Req, server_args: ServerArgs, args):
if not self.post_init_called:
self.post_init()
self.local_rank = int(os.getenv("RANK", 0))
os.makedirs(args.output_dir, exist_ok=True)
# Create directory for combined data
self.combined_parquet_dir = os.path.join(
args.output_dir, "combined_parquet_dataset"
)
os.makedirs(self.combined_parquet_dir, exist_ok=True)
# Loading text dataset
train_dataset = gettextdataset(args)
self.preprocess_dataloader = DataLoader(
train_dataset,
batch_size=args.preprocess_video_batch_size,
num_workers=args.dataloader_num_workers,
)
self.preprocess_loader_iter = iter(self.preprocess_dataloader)
self.num_processed_samples = 0
# Add progress bar for text preprocessing
self.pbar = tqdm(
self.preprocess_loader_iter,
desc="Processing text",
unit="batch",
disable=self.local_rank != 0,
)
# Initialize class variables for data sharing
self.text_data: dict[str, Any] = {} # Store text metadata and paths
self.preprocess_text_only(server_args, args)
EntryClass = PreprocessPipeline_Text
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
import random
from collections.abc import Callable
from typing import cast
import numpy as np
import torch
import torchvision
from einops import rearrange
from torchvision import transforms
from sglang.multimodal_gen.configs.configs import VideoLoaderType
from sglang.multimodal_gen.dataset.transform import (
CenterCropResizeVideo,
TemporalRandomCrop,
)
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import (
PreprocessBatch,
Req,
)
from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage
from sglang.multimodal_gen.runtime.server_args import ServerArgs, WorkloadType
class VideoTransformStage(PipelineStage):
"""
Crop a video in temporal dimension.
"""
def __init__(
self,
train_fps: int,
num_frames: int,
max_height: int,
max_width: int,
do_temporal_sample: bool,
) -> None:
self.train_fps = train_fps
self.num_frames = num_frames
if do_temporal_sample:
self.temporal_sample_fn: Callable | None = TemporalRandomCrop(num_frames)
else:
self.temporal_sample_fn = None
self.video_transform = transforms.Compose(
[
CenterCropResizeVideo((max_height, max_width)),
]
)
def forward(self, batch: Req, server_args: ServerArgs) -> Req:
batch = cast(PreprocessBatch, batch)
assert isinstance(batch.fps, list)
assert isinstance(batch.num_frames, list)
if batch.data_type != "video":
return batch
if len(batch.video_loader) == 0:
raise ValueError("Video loader is not set")
video_pixel_batch = []
for i in range(len(batch.video_loader)):
frame_interval = batch.fps[i] / self.train_fps
start_frame_idx = 0
frame_indices = np.arange(
start_frame_idx, batch.num_frames[i], frame_interval
).astype(int)
if len(frame_indices) > self.num_frames:
if self.temporal_sample_fn is not None:
begin_index, end_index = self.temporal_sample_fn(len(frame_indices))
frame_indices = frame_indices[begin_index:end_index]
else:
frame_indices = frame_indices[: self.num_frames]
if (
server_args.preprocess_config.video_loader_type
== VideoLoaderType.TORCHCODEC
):
video = batch.video_loader[i].get_frames_at(frame_indices).data
elif (
server_args.preprocess_config.video_loader_type
== VideoLoaderType.TORCHVISION
):
video, _, _ = torchvision.io.read_video(
batch.video_loader[i], output_format="TCHW"
)
video = video[frame_indices]
else:
raise ValueError(
f"Invalid video loader type: {server_args.preprocess_config.video_loader_type}"
)
video = self.video_transform(video)
video_pixel_batch.append(video)
video_pixel_values = torch.stack(video_pixel_batch)
video_pixel_values = rearrange(video_pixel_values, "b t c h w -> b c t h w")
video_pixel_values = video_pixel_values.to(torch.uint8)
if server_args.workload_type == WorkloadType.I2V:
batch.pil_image = video_pixel_values[:, :, 0, :, :]
video_pixel_values = video_pixel_values.float() / 255.0
batch.latents = video_pixel_values
batch.num_frames = [video_pixel_values.shape[2]] * len(batch.video_loader)
batch.height = [video_pixel_values.shape[3]] * len(batch.video_loader)
batch.width = [video_pixel_values.shape[4]] * len(batch.video_loader)
return cast(Req, batch)
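# Worked example of the sampling arithmetic above (illustrative numbers): with a
# source clip at fps=60 and train_fps=30, frame_interval is 2.0, so frame_indices
# selects every other frame; if more indices remain than self.num_frames, they are
# either randomly cropped (TemporalRandomCrop) or truncated to the first num_frames.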
class TextTransformStage(PipelineStage):
"""
Process text data according to the cfg rate.
"""
def __init__(self, cfg_uncondition_drop_rate: float, seed: int) -> None:
self.cfg_rate = cfg_uncondition_drop_rate
self.rng = random.Random(seed)
def forward(self, batch: Req, server_args: ServerArgs) -> Req:
batch = cast(PreprocessBatch, batch)
prompts = []
for prompt in batch.prompt:
if not isinstance(prompt, list):
prompt = [prompt]
prompt = self.rng.choice(prompt)
prompt = prompt if self.rng.random() > self.cfg_rate else ""
prompts.append(prompt)
batch.prompt = prompts
return cast(Req, batch)
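# Example: with cfg_uncondition_drop_rate=0.1, roughly 10% of prompts are replaced
# by the empty string so the model also sees unconditional samples, as required for
# classifier-free guidance training; when a sample carries multiple candidate
# captions, one is first chosen uniformly at random.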
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
import argparse
import os
from typing import Any
from sglang.multimodal_gen import PipelineConfig
from sglang.multimodal_gen.configs.models.vaes import WanVAEConfig
from sglang.multimodal_gen.runtime.architectures.preprocess.preprocess_pipeline_i2v import (
PreprocessPipeline_I2V,
)
from sglang.multimodal_gen.runtime.architectures.preprocess.preprocess_pipeline_ode_trajectory import (
PreprocessPipeline_ODE_Trajectory,
)
from sglang.multimodal_gen.runtime.architectures.preprocess.preprocess_pipeline_t2v import (
PreprocessPipeline_T2V,
)
from sglang.multimodal_gen.runtime.architectures.preprocess.preprocess_pipeline_text import (
PreprocessPipeline_Text,
)
from sglang.multimodal_gen.runtime.distributed import (
get_world_size,
maybe_init_distributed_environment_and_model_parallel,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
def main(args) -> None:
args.model_path = maybe_download_model(args.model_path)
maybe_init_distributed_environment_and_model_parallel(1, 1)
num_gpus = int(os.environ["WORLD_SIZE"])
assert num_gpus == 1, "Only support 1 GPU"
pipeline_config = PipelineConfig.from_pretrained(args.model_path)
kwargs: dict[str, Any] = {}
if args.preprocess_task == "text_only":
kwargs = {
"text_encoder_cpu_offload": False,
}
else:
# Full config for video/image processing
kwargs = {
"vae_precision": "fp32",
"vae_config": WanVAEConfig(load_encoder=True, load_decoder=True),
}
pipeline_config.update_config_from_dict(kwargs)
server_args = ServerArgs(
model_path=args.model_path,
num_gpus=get_world_size(),
dit_cpu_offload=False,
vae_cpu_offload=False,
text_encoder_cpu_offload=False,
pipeline_config=pipeline_config,
)
if args.preprocess_task == "t2v":
PreprocessPipeline = PreprocessPipeline_T2V
elif args.preprocess_task == "i2v":
PreprocessPipeline = PreprocessPipeline_I2V
elif args.preprocess_task == "text_only":
PreprocessPipeline = PreprocessPipeline_Text
elif args.preprocess_task == "ode_trajectory":
assert args.flow_shift is not None, "flow_shift is required for ode_trajectory"
server_args.pipeline_config.flow_shift = args.flow_shift
PreprocessPipeline = PreprocessPipeline_ODE_Trajectory
else:
raise ValueError(
f"Invalid preprocess task: {args.preprocess_task}. "
f"Valid options: t2v, i2v, ode_trajectory, text_only"
)
logger.info(
"Preprocess task: %s using %s",
args.preprocess_task,
PreprocessPipeline.__name__,
)
pipeline = PreprocessPipeline(args.model_path, server_args)
pipeline.forward(batch=None, server_args=server_args, args=args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# dataset & dataloader
parser.add_argument("--model_path", type=str, default="data/mochi")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--data_merge_path", type=str, required=True)
parser.add_argument("--num_frames", type=int, default=163)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=1,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--preprocess_video_batch_size",
type=int,
default=2,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument("--samples_per_file", type=int, default=64)
parser.add_argument(
"--flush_frequency",
type=int,
default=256,
help="how often to save to parquet files",
)
parser.add_argument(
"--num_latent_t", type=int, default=28, help="Number of latent timesteps."
)
parser.add_argument("--max_height", type=int, default=480)
parser.add_argument("--max_width", type=int, default=848)
parser.add_argument("--video_length_tolerance_range", type=int, default=2.0)
parser.add_argument("--group_frame", action="store_true") # TODO
parser.add_argument("--group_resolution", action="store_true") # TODO
parser.add_argument("--flow_shift", type=float, default=None)
parser.add_argument(
"--preprocess_task",
type=str,
default="t2v",
choices=["t2v", "i2v", "text_only", "ode_trajectory"],
help="Type of preprocessing task to run",
)
parser.add_argument("--train_fps", type=int, default=30)
parser.add_argument("--use_image_num", type=int, default=0)
parser.add_argument("--text_max_length", type=int, default=256)
parser.add_argument("--speed_factor", type=float, default=1.0)
parser.add_argument("--drop_short_ratio", type=float, default=1.0)
parser.add_argument("--do_temporal_sample", default=False, action="store_true")
# text encoder & vae & diffusion model
parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
parser.add_argument("--cache_dir", type=str, default="./cache_dir")
parser.add_argument("--training_cfg_rate", type=float, default=0.0)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="The output directory where the model predictions and checkpoints will be written.",
)
args = parser.parse_args()
main(args)
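# Hedged example invocation (the script path, model path, and data paths below are
# placeholders, not verified values; WORLD_SIZE is read from the environment, so set
# it explicitly or launch via torchrun):
#
#   WORLD_SIZE=1 RANK=0 python preprocess.py \
#       --model_path <hf-or-local-wan-model-path> \
#       --data_merge_path data/merge.txt \
#       --preprocess_task t2v \
#       --output_dir data/preprocessed \
#       --max_height 480 --max_width 848 --train_fps 30 \
#       --samples_per_file 64 --flush_frequency 256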
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from sglang.multimodal_gen.runtime.distributed import (
maybe_init_distributed_environment_and_model_parallel,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.runtime.workflow.workflow_base import WorkflowBase
from sglang.multimodal_gen.utils import FlexibleArgumentParser
logger = init_logger(__name__)
def main(server_args: ServerArgs) -> None:
maybe_init_distributed_environment_and_model_parallel(1, 1)
preprocess_workflow_cls = WorkflowBase.get_workflow_cls(server_args)
preprocess_workflow = preprocess_workflow_cls(server_args)
preprocess_workflow.run()
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser = ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
main(server_args)
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import (
ComposedPipelineBase,
)
from sglang.multimodal_gen.runtime.pipelines.preprocess.preprocess_stages import (
TextTransformStage,
VideoTransformStage,
)
from sglang.multimodal_gen.runtime.pipelines.stages import (
EncodingStage,
ImageEncodingStage,
TextEncodingStage,
)
from sglang.multimodal_gen.runtime.pipelines.stages.image_encoding import (
ImageVAEEncodingStage,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
class PreprocessPipelineI2V(ComposedPipelineBase):
_required_config_modules = [
"image_encoder",
"image_processor",
"text_encoder",
"tokenizer",
"vae",
]
def create_pipeline_stages(self, server_args: ServerArgs):
assert server_args.preprocess_config is not None
self.add_stage(
stage_name="text_transform_stage",
stage=TextTransformStage(
cfg_uncondition_drop_rate=server_args.preprocess_config.training_cfg_rate,
seed=server_args.preprocess_config.seed,
),
)
self.add_stage(
stage_name="prompt_encoding_stage",
stage=TextEncodingStage(
text_encoders=[self.get_module("text_encoder")],
tokenizers=[self.get_module("tokenizer")],
),
)
self.add_stage(
stage_name="video_transform_stage",
stage=VideoTransformStage(
train_fps=server_args.preprocess_config.train_fps,
num_frames=server_args.preprocess_config.num_frames,
max_height=server_args.preprocess_config.max_height,
max_width=server_args.preprocess_config.max_width,
do_temporal_sample=server_args.preprocess_config.do_temporal_sample,
),
)
if (
self.get_module("image_encoder") is not None
and self.get_module("image_processor") is not None
):
self.add_stage(
stage_name="image_encoding_stage",
stage=ImageEncodingStage(
image_encoder=self.get_module("image_encoder"),
image_processor=self.get_module("image_processor"),
),
)
self.add_stage(
stage_name="image_vae_encoding_stage",
stage=ImageVAEEncodingStage(
vae=self.get_module("vae"),
),
)
self.add_stage(
stage_name="video_encoding_stage",
stage=EncodingStage(
vae=self.get_module("vae"),
),
)
class PreprocessPipelineT2V(ComposedPipelineBase):
_required_config_modules = ["text_encoder", "tokenizer", "vae"]
def create_pipeline_stages(self, server_args: ServerArgs):
assert server_args.preprocess_config is not None
self.add_stage(
stage_name="text_transform_stage",
stage=TextTransformStage(
cfg_uncondition_drop_rate=server_args.preprocess_config.training_cfg_rate,
seed=server_args.preprocess_config.seed,
),
)
self.add_stage(
stage_name="prompt_encoding_stage",
stage=TextEncodingStage(
text_encoders=[self.get_module("text_encoder")],
tokenizers=[self.get_module("tokenizer")],
),
)
self.add_stage(
stage_name="video_transform_stage",
stage=VideoTransformStage(
train_fps=server_args.preprocess_config.train_fps,
num_frames=server_args.preprocess_config.num_frames,
max_height=server_args.preprocess_config.max_height,
max_width=server_args.preprocess_config.max_width,
do_temporal_sample=server_args.preprocess_config.do_temporal_sample,
),
)
self.add_stage(
stage_name="video_encoding_stage",
stage=EncodingStage(
vae=self.get_module("vae"),
),
)
EntryClass = [PreprocessPipelineI2V, PreprocessPipelineT2V]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from sglang.multimodal_gen.runtime.distributed.communication_op import *
from sglang.multimodal_gen.runtime.distributed.group_coordinator import (
get_local_torch_device,
)
from sglang.multimodal_gen.runtime.distributed.parallel_state import (
cleanup_dist_env_and_memory,
get_dp_group,
get_dp_rank,
get_dp_world_size,
get_sp_group,
get_sp_parallel_rank,
get_sp_world_size,
get_tp_group,
get_tp_rank,
get_tp_world_size,
get_world_group,
get_world_rank,
get_world_size,
init_distributed_environment,
initialize_model_parallel,
maybe_init_distributed_environment_and_model_parallel,
model_parallel_is_initialized,
)
from sglang.multimodal_gen.runtime.distributed.utils import *
__all__ = [
# Initialization
"init_distributed_environment",
"initialize_model_parallel",
"cleanup_dist_env_and_memory",
"model_parallel_is_initialized",
"maybe_init_distributed_environment_and_model_parallel",
# World group
"get_world_group",
"get_world_rank",
"get_world_size",
# Data parallel group
"get_dp_group",
"get_dp_rank",
"get_dp_world_size",
# Sequence parallel group
"get_sp_group",
"get_sp_parallel_rank",
"get_sp_world_size",
# Tensor parallel group
"get_tp_group",
"get_tp_rank",
"get_tp_world_size",
# Get torch device
"get_local_torch_device",
]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/communication_op.py
import torch
import torch.distributed
from sglang.multimodal_gen.runtime.distributed.parallel_state import (
get_cfg_group,
get_sp_group,
get_tp_group,
)
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_)
def tensor_model_parallel_all_gather(
input_: torch.Tensor, dim: int = -1
) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_tp_group().all_gather(input_, dim)
# TODO: remove model, make it sequence_parallel
def sequence_model_parallel_all_to_all_4D(
input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1
) -> torch.Tensor:
"""All-to-all communication of 4D tensors (e.g. QKV matrices) across sequence parallel group."""
return get_sp_group().all_to_all_4D(input_, scatter_dim, gather_dim)
def sequence_model_parallel_all_gather(
input_: torch.Tensor, dim: int = -1
) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_sp_group().all_gather(input_, dim)
def cfg_model_parallel_all_gather(
input_: torch.Tensor, dim: int = -1, separate_tensors: bool = False
) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_cfg_group().all_gather(input_, dim, separate_tensors)
def cfg_model_parallel_all_reduce(
input_: torch.Tensor,
op: torch._C._distributed_c10d.ReduceOp = torch._C._distributed_c10d.ReduceOp.SUM,
) -> torch.Tensor:
"""All-reduce the input tensor across CFG parallel group."""
return get_cfg_group().all_reduce(input_, op=op)
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/base_device_communicator.py
from typing import Any
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup, ReduceOp
class DistributedAutograd:
"""Collection of autograd functions for distributed operations.
This class provides custom autograd functions for distributed operations like all_reduce,
all_gather, and all_to_all. Each operation is implemented as a static inner class with
proper forward and backward implementations.
"""
class AllReduce(torch.autograd.Function):
"""Differentiable all_reduce operation.
The gradient of all_reduce is another all_reduce operation since the operation
combines values from all ranks equally.
"""
@staticmethod
def forward(
ctx: Any,
group: ProcessGroup,
input_: Tensor,
op: dist.ReduceOp | None = None,
) -> Tensor:
ctx.group = group
ctx.op = op
output = input_.clone()
dist.all_reduce(output, group=group, op=op)
return output
@staticmethod
def backward(ctx: Any, grad_output: Tensor) -> tuple[None, Tensor, None]:
grad_output = grad_output.clone()
dist.all_reduce(grad_output, group=ctx.group, op=ctx.op)
return None, grad_output, None
class AllGather(torch.autograd.Function):
"""Differentiable all_gather operation.
The operation gathers tensors from all ranks and concatenates them along a specified dimension.
The backward pass uses reduce_scatter to efficiently distribute gradients back to source ranks.
"""
@staticmethod
def forward(
ctx: Any, group: ProcessGroup, input_: Tensor, world_size: int, dim: int
) -> Tensor:
ctx.group = group
ctx.world_size = world_size
ctx.dim = dim
ctx.input_shape = input_.shape
input_size = input_.size()
output_size = (input_size[0] * world_size,) + input_size[1:]
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
dist.all_gather_into_tensor(output_tensor, input_, group=group)
output_tensor = output_tensor.reshape((world_size,) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(
input_size[:dim]
+ (world_size * input_size[dim],)
+ input_size[dim + 1 :]
)
return output_tensor
@staticmethod
def backward(ctx: Any, grad_output: Tensor) -> tuple[None, Tensor, None, None]:
# Split the gradient tensor along the gathered dimension
dim_size = grad_output.size(ctx.dim) // ctx.world_size
grad_chunks = grad_output.reshape(
grad_output.shape[: ctx.dim]
+ (ctx.world_size, dim_size)
+ grad_output.shape[ctx.dim + 1 :]
)
grad_chunks = grad_chunks.movedim(ctx.dim, 0)
# Each rank only needs its corresponding gradient
grad_input = torch.empty(
ctx.input_shape, dtype=grad_output.dtype, device=grad_output.device
)
dist.reduce_scatter_tensor(
grad_input, grad_chunks.contiguous(), group=ctx.group
)
return None, grad_input, None, None
class AllToAll4D(torch.autograd.Function):
"""Differentiable all_to_all operation specialized for 4D tensors.
This operation is particularly useful for attention operations where we need to
redistribute data across ranks for efficient parallel processing.
The operation supports two modes:
1. scatter_dim=2, gather_dim=1: Used for redistributing attention heads
2. scatter_dim=1, gather_dim=2: Used for redistributing sequence dimensions
"""
@staticmethod
def forward(
ctx: Any,
group: ProcessGroup,
input_: Tensor,
world_size: int,
scatter_dim: int,
gather_dim: int,
) -> Tensor:
ctx.group = group
ctx.world_size = world_size
ctx.scatter_dim = scatter_dim
ctx.gather_dim = gather_dim
if world_size == 1:
return input_
assert (
input_.dim() == 4
), f"input must be 4D tensor, got {input_.dim()} and shape {input_.shape}"
if scatter_dim == 2 and gather_dim == 1:
bs, shard_seqlen, hn, hd = input_.shape
seqlen = shard_seqlen * world_size
shard_hn = hn // world_size
input_ = input_.transpose(0, 2).contiguous() # hn, shard_seqlen, bs, hd
output = torch.empty_like(input_)
dist.all_to_all_single(
output, input_, group=group
) # hn, shard_seqlen, bs, hd
output = torch.cat(
output.split(shard_hn), dim=1
) # sharded hn, seqlen, bs, hd
output = output.transpose(
0, 2
).contiguous() # bs, seqlen, sharded_hn, hd
return output
elif scatter_dim == 1 and gather_dim == 2:
bs, seqlen, shard_hn, hd = input_.shape
hn = shard_hn * world_size
shard_seqlen = seqlen // world_size
input_ = input_.transpose(0, 2).contiguous() # shard_hn, seqlen, bs, hd
input_ = (
input_.reshape(shard_hn, world_size, shard_seqlen, bs, hd)
.transpose(0, 1)
.reshape(shard_hn * world_size, shard_seqlen, bs, hd)
.contiguous()
)
output = torch.empty_like(input_)
dist.all_to_all_single(output, input_, group=group)
output = output.transpose(
0, 2
).contiguous() # bs, shard_seqlen, hn, hd
return output
else:
raise RuntimeError(
f"Invalid scatter_dim={scatter_dim}, gather_dim={gather_dim}. "
f"Only (scatter_dim=2, gather_dim=1) and (scatter_dim=1, gather_dim=2) are supported."
)
@staticmethod
def backward(
ctx: Any, grad_output: Tensor
) -> tuple[None, Tensor, None, None, None]:
if ctx.world_size == 1:
return None, grad_output, None, None, None
# For backward pass, we swap scatter_dim and gather_dim
output = DistributedAutograd.AllToAll4D.apply(
ctx.group, grad_output, ctx.world_size, ctx.gather_dim, ctx.scatter_dim
)
return None, output, None, None, None
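# Shape walk-through for AllToAll4D with scatter_dim=2, gather_dim=1 (a sketch,
# assuming world_size=2, bs=1, shard_seqlen=4, hn=8, hd=64):
#   input:  (1, 4, 8, 64)   sequence sharded, all heads local
#   output: (1, 8, 4, 64)   full sequence, heads sharded across the 2 ranks
# The backward pass simply swaps scatter_dim and gather_dim, so gradients are
# redistributed back to the original layout.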
class DeviceCommunicatorBase:
"""
Base class for device-specific communicator with autograd support.
It can use the `cpu_group` to initialize the communicator.
If the device has PyTorch integration (PyTorch can recognize its
communication backend), the `device_group` will also be given.
"""
def __init__(
self,
cpu_group: ProcessGroup,
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
def all_reduce(
self, input_: torch.Tensor, op: dist.ReduceOp | None = ReduceOp.SUM
) -> torch.Tensor:
"""Performs an all_reduce operation with gradient support."""
return DistributedAutograd.AllReduce.apply(self.device_group, input_, op)
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""Performs an all_gather operation with gradient support."""
if dim < 0:
dim += input_.dim()
return DistributedAutograd.AllGather.apply(
self.device_group, input_, self.world_size, dim
)
def all_to_all_4D(
self, input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1
) -> torch.Tensor:
"""Performs a 4D all-to-all operation with gradient support."""
return DistributedAutograd.AllToAll4D.apply(
self.device_group, input_, self.world_size, scatter_dim, gather_dim
)
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> torch.Tensor | None:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
assert (
-input_.dim() <= dim < input_.dim()
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(
input_, gather_list, dst=self.ranks[dst], group=self.device_group
)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(
self, size: torch.Size, dtype: torch.dtype, src: int | None = None
) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self) -> None:
pass
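# Minimal usage sketch for DeviceCommunicatorBase (assumes torch.distributed is
# already initialized and `group` is a CPU ProcessGroup for the participating
# ranks; the variable names here are illustrative only):
#
#   comm = DeviceCommunicatorBase(cpu_group=group, device=torch.device("cpu"),
#                                 device_group=group, unique_name="tp")
#   summed = comm.all_reduce(torch.ones(4))          # differentiable all-reduce
#   gathered = comm.all_gather(torch.ones(2, 3), 0)  # concat along dim 0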
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/cpu_communicator.py
import os
import torch
from torch.distributed import ProcessGroup
from .base_device_communicator import DeviceCommunicatorBase
class CpuCommunicator(DeviceCommunicatorBase):
def __init__(
self,
cpu_group: ProcessGroup,
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
):
from sglang.multimodal_gen.runtime.platforms import current_platform
from sglang.multimodal_gen.runtime.platforms.interface import CpuArchEnum
super().__init__(cpu_group, device, device_group, unique_name)
self.dist_module = torch.distributed
if (
(current_platform.get_cpu_architecture() == CpuArchEnum.X86)
and hasattr(torch.ops._C, "init_shm_manager")
and unique_name.startswith("tp")
):
self.dist_module = _CPUSHMDistributed(self)
def all_reduce(
self,
input_: torch.Tensor,
op: torch.distributed.ReduceOp | None = torch.distributed.ReduceOp.SUM,
) -> torch.Tensor:
self.dist_module.all_reduce(input_, group=self.device_group, op=op)
return input_
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> torch.Tensor | None:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
assert (
-input_.dim() <= dim < input_.dim()
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
self.dist_module.gather(
input_, gather_list, dst=self.ranks[dst], group=self.device_group
)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
# All-gather.
self.dist_module.all_gather_into_tensor(
output_tensor, input_, group=self.device_group
)
# Reshape
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(
input_size[:dim]
+ (self.world_size * input_size[dim],)
+ input_size[dim + 1 :]
)
return output_tensor
class _CPUSHMDistributed:
def __init__(self, communicator: CpuCommunicator):
instance_identifier = os.environ["VLLM_DIST_IDENT"]
unique_name = communicator.unique_name
instance_identifier = f"{instance_identifier}-{unique_name}"
self.communicator = communicator
group_ranks = [str(rank) for rank in self.communicator.ranks]
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
self.handle = self._init_cpu_shm()
def _init_cpu_shm(self) -> int:
handle = torch.ops._C.init_shm_manager(
self.group_name,
self.communicator.world_size,
self.communicator.rank,
)
torch.distributed.barrier(self.communicator.device_group)
torch.ops._C.join_shm_manager(
handle,
self.group_name,
)
torch.distributed.barrier(self.communicator.device_group)
return int(handle)
def all_reduce(
self, input: torch.Tensor, group: ProcessGroup | None = None
) -> None:
torch.ops._C.shm_allreduce(self.handle, input)
def gather(
self,
input: torch.Tensor,
gather_list: list[torch.Tensor] | None,
dst: int = -1,
group: ProcessGroup | None = None,
) -> None:
# Note: different from the torch gather, here we use local dst rank.
torch.ops._C.shm_gather(
self.handle,
input,
gather_list,
torch.distributed.get_group_rank(group, dst),
)
def all_gather_into_tensor(
self,
output: torch.Tensor,
input: torch.Tensor,
group: ProcessGroup | None = None,
) -> None:
torch.ops._C.shm_all_gather(self.handle, input, output)
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/cuda_communicator.py
import torch
from torch.distributed import ProcessGroup
from sglang.multimodal_gen.runtime.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase,
)
class CudaCommunicator(DeviceCommunicatorBase):
def __init__(
self,
cpu_group: ProcessGroup,
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name)
from sglang.multimodal_gen.runtime.distributed.device_communicators.pynccl import (
PyNcclCommunicator,
)
self.pynccl_comm: PyNcclCommunicator | None = None
if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
)
def all_reduce(self, input_, op: torch.distributed.ReduceOp | None = None):
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_, op=op)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group, op=op)
return out
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor, dst)
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(
self, size: torch.Size, dtype: torch.dtype, src: int | None = None
) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor, src)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self) -> None:
if self.pynccl_comm is not None:
self.pynccl_comm = None
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/pynccl.py
# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from sglang.multimodal_gen.runtime.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary,
buffer_type,
cudaStream_t,
ncclComm_t,
ncclDataTypeEnum,
ncclRedOpTypeEnum,
ncclUniqueId,
)
from sglang.multimodal_gen.runtime.distributed.utils import StatelessProcessGroup
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import current_stream
logger = init_logger(__name__)
class PyNcclCommunicator:
def __init__(
self,
group: ProcessGroup | StatelessProcessGroup,
device: int | str | torch.device,
library_path: str | None = None,
):
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the PyNcclCommunicator to. If None,
it will be bound to f"cuda:{local_rank}".
library_path: the path to the NCCL library. If None, it will
use the default library path.
It is the caller's responsibility to make sure each communicator
is bound to a unique device.
"""
if not isinstance(group, StatelessProcessGroup):
assert dist.is_initialized()
assert (
dist.get_backend(group) != dist.Backend.NCCL
), "PyNcclCommunicator should be attached to a non-NCCL group."
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
else:
self.rank = group.rank
self.world_size = group.world_size
self.group = group
# if world_size == 1, no need to create communicator
if self.world_size == 1:
self.available = False
self.disabled = True
return
try:
self.nccl = NCCLLibrary(library_path)
except Exception:
# disable because of missing NCCL library
# e.g. in a non-GPU environment
self.available = False
self.disabled = True
return
self.available = True
self.disabled = False
logger.info("sgl-diffusion is using nccl==%s", self.nccl.ncclGetVersion())
if self.rank == 0:
# get the unique id from NCCL
self.unique_id = self.nccl.ncclGetUniqueId()
else:
# construct an empty unique id
self.unique_id = ncclUniqueId()
if not isinstance(group, StatelessProcessGroup):
tensor = torch.ByteTensor(list(self.unique_id.internal))
ranks = dist.get_process_group_ranks(group)
# arg `src` in `broadcast` is the global rank
dist.broadcast(tensor, src=ranks[0], group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
else:
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
# nccl communicator and stream will use this device
# `torch.cuda.device` is a context manager that changes the
# current cuda device to the specified one
with torch.cuda.device(device):
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
self.world_size, self.unique_id, self.rank
)
stream = current_stream()
# A small all_reduce for warmup.
data = torch.zeros(1, device=device)
self.all_reduce(data)
if stream is not None:
stream.synchronize()
del data
def all_reduce(
self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
) -> torch.Tensor:
if self.disabled:
return None
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert in_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {in_tensor.device}"
)
out_tensor = torch.empty_like(in_tensor)
if stream is None:
stream = current_stream()
self.nccl.ncclAllReduce(
buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
ncclDataTypeEnum.from_torch(in_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op),
self.comm,
cudaStream_t(stream.cuda_stream),
)
return out_tensor
def all_gather(
self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}"
)
if stream is None:
stream = current_stream()
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
self.comm,
cudaStream_t(stream.cuda_stream),
)
def reduce_scatter(
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None,
):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}"
)
if stream is None:
stream = current_stream()
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
output_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op),
self.comm,
cudaStream_t(stream.cuda_stream),
)
def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
if stream is None:
stream = current_stream()
self.nccl.ncclSend(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
dst,
self.comm,
cudaStream_t(stream.cuda_stream),
)
def recv(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
if stream is None:
stream = current_stream()
self.nccl.ncclRecv(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
src,
self.comm,
cudaStream_t(stream.cuda_stream),
)
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
if stream is None:
stream = current_stream()
if src == self.rank:
sendbuff = buffer_type(tensor.data_ptr())
# NCCL requires the sender also to have a receive buffer
recvbuff = buffer_type(tensor.data_ptr())
else:
sendbuff = buffer_type()
recvbuff = buffer_type(tensor.data_ptr())
self.nccl.ncclBroadcast(
sendbuff,
recvbuff,
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
src,
self.comm,
cudaStream_t(stream.cuda_stream),
)
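# Minimal usage sketch for PyNcclCommunicator (a sketch, assuming a Gloo-backed
# process group `cpu_group` already exists and each rank owns one CUDA device):
#
#   comm = PyNcclCommunicator(group=cpu_group, device=torch.cuda.current_device())
#   if not comm.disabled:
#       out = comm.all_reduce(torch.ones(4, device=comm.device))
#
# Note that the group passed in must NOT be an NCCL group; this class drives
# NCCL itself through the NCCLLibrary ctypes wrapper defined below.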
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/pynccl_wrapper.py
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# may invoke other CUDA APIs that are not allowed while capturing a CUDA
# graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related to NCCL versions and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `SGL_DIFFUSION_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
# TODO(will): support SGL_DIFFUSION_NCCL_SO_PATH
import ctypes
import platform
from dataclasses import dataclass
from typing import Any
import torch
from torch.distributed import ReduceOp
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import find_nccl_library
logger = init_logger(__name__)
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
ncclResult_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p
class ncclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
cudaStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p
ncclDataType_t = ctypes.c_int
class ncclDataTypeEnum:
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
ncclInt32 = 2
ncclInt = 2
ncclUint32 = 3
ncclInt64 = 4
ncclUint64 = 5
ncclFloat16 = 6
ncclHalf = 6
ncclFloat32 = 7
ncclFloat = 7
ncclFloat64 = 8
ncclDouble = 8
ncclBfloat16 = 9
ncclNumTypes = 10
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
return cls.ncclUint8
if dtype == torch.int32:
return cls.ncclInt32
if dtype == torch.int64:
return cls.ncclInt64
if dtype == torch.float16:
return cls.ncclFloat16
if dtype == torch.float32:
return cls.ncclFloat32
if dtype == torch.float64:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
ncclRedOp_t = ctypes.c_int
class ncclRedOpTypeEnum:
ncclSum = 0
ncclProd = 1
ncclMax = 2
ncclMin = 3
ncclAvg = 4
ncclNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
return cls.ncclProd
if op == ReduceOp.MAX:
return cls.ncclMax
if op == ReduceOp.MIN:
return cls.ncclMin
if op == ReduceOp.AVG:
return cls.ncclAvg
raise ValueError(f"Unsupported op: {op}")
@dataclass
class Function:
name: str
restype: Any
argtypes: list[Any]
class NCCLLibrary:
exported_functions = [
# const char* ncclGetErrorString(ncclResult_t result)
Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
# ncclResult_t ncclGetVersion(int *version);
Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]),
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]),
# ncclResult_t ncclCommInitRank(
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
Function(
"ncclCommInitRank",
ncclResult_t,
[ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int],
),
# ncclResult_t ncclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function(
"ncclAllReduce",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclRedOp_t,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclAllGather(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function(
"ncclAllGather",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclReduceScatter(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function(
"ncclReduceScatter",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclRedOp_t,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclSend(
# const void* sendbuff, size_t count, ncclDataType_t datatype,
# int dest, ncclComm_t comm, cudaStream_t stream);
Function(
"ncclSend",
ncclResult_t,
[
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ctypes.c_int,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclRecv(
# void* recvbuff, size_t count, ncclDataType_t datatype,
# int src, ncclComm_t comm, cudaStream_t stream);
Function(
"ncclRecv",
ncclResult_t,
[
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ctypes.c_int,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclBroadcast(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, int root, ncclComm_t comm,
# cudaStream_t stream);
Function(
"ncclBroadcast",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ctypes.c_int,
ncclComm_t,
cudaStream_t,
],
),
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
# it is better not to call it at all.
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the corresponding dictionary
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
def __init__(self, so_file: str | None = None):
so_file = so_file or find_nccl_library()
try:
if so_file not in NCCLLibrary.path_to_dict_mapping:
lib = ctypes.CDLL(so_file)
NCCLLibrary.path_to_library_cache[so_file] = lib
self.lib = NCCLLibrary.path_to_library_cache[so_file]
except Exception as e:
logger.error(
"Failed to load NCCL library from %s ."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise, the nccl library might not exist, be corrupted "
"or it does not support the current platform %s."
"If you already have the library, please set the "
"environment variable SGL_DIFFUSION_NCCL_SO_PATH"
" to point to the correct nccl library path.",
so_file,
platform.platform(),
)
raise e
if so_file not in NCCLLibrary.path_to_dict_mapping:
_funcs: dict[str, Any] = {}
for func in NCCLLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
def ncclGetErrorString(self, result: ncclResult_t) -> str:
return str(self._funcs["ncclGetErrorString"](result).decode("utf-8"))
def NCCL_CHECK(self, result: ncclResult_t) -> None:
if result != 0:
error_str = self.ncclGetErrorString(result)
raise RuntimeError(f"NCCL error: {error_str}")
def ncclGetVersion(self) -> str:
version = ctypes.c_int()
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
version_str = str(version.value)
# something like 21903 --> "2.19.3"
major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0")
patch = version_str[3:].lstrip("0")
return f"{major}.{minor}.{patch}"
def ncclGetUniqueId(self) -> ncclUniqueId:
unique_id = ncclUniqueId()
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id)))
return unique_id
def ncclCommInitRank(
self, world_size: int, unique_id: ncclUniqueId, rank: int
) -> ncclComm_t:
comm = ncclComm_t()
self.NCCL_CHECK(
self._funcs["ncclCommInitRank"](
ctypes.byref(comm), world_size, unique_id, rank
)
)
return comm
def ncclAllReduce(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
op: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(
self._funcs["ncclAllReduce"](
sendbuff, recvbuff, count, datatype, op, comm, stream
)
)
def ncclReduceScatter(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
op: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(
self._funcs["ncclReduceScatter"](
sendbuff, recvbuff, count, datatype, op, comm, stream
)
)
def ncclAllGather(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
# `datatype` actually should be `ncclDataType_t`
# which is an alias of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(
self._funcs["ncclAllGather"](
sendbuff, recvbuff, count, datatype, comm, stream
)
)
def ncclSend(
self,
sendbuff: buffer_type,
count: int,
datatype: int,
dest: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
self.NCCL_CHECK(
self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream)
)
def ncclRecv(
self,
recvbuff: buffer_type,
count: int,
datatype: int,
src: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
self.NCCL_CHECK(
self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)
)
def ncclBroadcast(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
root: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
self.NCCL_CHECK(
self._funcs["ncclBroadcast"](
sendbuff, recvbuff, count, datatype, root, comm, stream
)
)
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
__all__ = [
"NCCLLibrary",
"ncclDataTypeEnum",
"ncclRedOpTypeEnum",
"ncclUniqueId",
"ncclComm_t",
"cudaStream_t",
"buffer_type",
]