Commit e2778d0d authored by litzh

Initial commit
"""
Wan2.2 distilled model image-to-video generation example.
This example demonstrates how to use LightX2V with the Wan2.2 distilled model for I2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for Wan2.2 distilled I2V task
# For wan2.1, use model_cls="wan2.1_distill"
pipe = LightX2VPipeline(
model_path="/path/to/wan2.2/Wan2.2-I2V-A14B",
model_cls="wan2.2_moe_distill",
task="i2v",
    # Distilled weights: for wan2.1, specify only dit_original_ckpt="/path/to/wan2.1_i2v_720p_lightx2v_4step.safetensors"
low_noise_original_ckpt="/path/to/wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors",
high_noise_original_ckpt="/path/to/wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(
# config_json="../configs/wan22/wan_moe_i2v_distill.json"
# )
# Enable offloading to significantly reduce VRAM usage
# Suitable for RTX 30/40/50 consumer GPUs
pipe.enable_offload(
cpu_offload=True,
offload_granularity="block",
text_encoder_offload=True,
image_encoder_offload=False,
vae_offload=False,
)
# Create generator manually with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=4,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=81,
guidance_scale=1,
sample_shift=5.0,
)
seed = 42
prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
save_result_path = "/path/to/save_results/output.mp4"
image_path = "/path/to/img_0.jpg"
pipe.generate(
seed=seed,
image_path=image_path,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
"""
Wan2.1 distilled model image-to-video generation example with NVFP4 quantization.
This example demonstrates how to use LightX2V with the Wan2.1 distilled model and NVFP4-quantized weights for I2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for the Wan2.1 distilled I2V task
# For the non-distilled model, use model_cls="wan2.1"
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.1-I2V-14B-480P",
model_cls="wan2.1_distill",
task="i2v",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(
# config_json="../configs/wan22/wan_moe_i2v.json"
# )
# Enable offloading to significantly reduce VRAM usage with minimal speed impact
# Suitable for RTX 30/40/50 consumer GPUs
pipe.enable_offload(
cpu_offload=True,
offload_granularity="block", # For Wan models, supports both "block" and "phase"
text_encoder_offload=True,
image_encoder_offload=False,
vae_offload=False,
)
pipe.enable_quantize(
    dit_quantized=True,
    dit_quantized_ckpt="lightx2v/Wan-NVFP4/wan2.1_i2v_480p_nvfp4_lightx2v_4step.safetensors",
    quant_scheme="nvfp4",
)
# Create generator manually with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=4,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=81,
    guidance_scale=1.0,  # Distilled 4-step inference uses guidance_scale=1.0; the non-distilled wan2.1 model typically uses a larger scalar (e.g., 5.0)
sample_shift=5.0,
)
# Generation parameters
seed = 42
prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
image_path = "/path/to/img_0.jpg"
save_result_path = "/path/to/save_results/output.mp4"
# Generate video
pipe.generate(
seed=seed,
image_path=image_path,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
"""
Wan2.2 distilled model with LoRA image-to-video generation example.
This example demonstrates how to use LightX2V with the Wan2.2 distilled model and LoRA weights for I2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for Wan2.2 distilled I2V task with LoRA
# For wan2.1, use model_cls="wan2.1_distill"
pipe = LightX2VPipeline(
model_path="/path/to/wan2.2/Wan2.2-I2V-A14B",
model_cls="wan2.2_moe_distill",
task="i2v",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(
# config_json="../configs/wan22/wan_moe_i2v_distill_with_lora.json"
# )
# Enable offloading to significantly reduce VRAM usage
# Suitable for RTX 30/40/50 consumer GPUs
pipe.enable_offload(
cpu_offload=True,
offload_granularity="block",
text_encoder_offload=True,
image_encoder_offload=False,
vae_offload=False,
)
# Load distilled LoRA weights
pipe.enable_lora(
[
{"name": "high_noise_model", "path": "lightx2v/Wan2.2-Distill-Loras/wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors", "strength": 1.0},
{"name": "low_noise_model", "path": "lightx2v/Wan2.2-Distill-Loras/wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors", "strength": 1.0},
],
    lora_dynamic_apply=False,  # If True, apply LoRA weights dynamically during inference (saves memory but is slower); default is False
)
# Create generator with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=4,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=81,
guidance_scale=1,
sample_shift=5.0,
)
seed = 42
prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
save_result_path = "/path/to/save_results/output.mp4"
image_path = "/path/to/img_0.jpg"
pipe.generate(
seed=seed,
image_path=image_path,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
"""
Wan2.1 text-to-video generation example.
This example demonstrates how to use LightX2V with the Wan2.1 model for T2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for Wan2.1 T2V task
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.1-T2V-14B",
model_cls="wan2.1",
task="t2v",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(config_json="../configs/wan/wan_t2v.json")
# Create generator with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=50,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=81,
guidance_scale=5.0,
sample_shift=5.0,
)
seed = 42
prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
save_result_path = "/path/to/save_results/output.mp4"
pipe.generate(
seed=seed,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
"""
Wan2.1 distilled model text-to-video generation example with NVFP4 quantization.
This example demonstrates how to use LightX2V with the Wan2.1-T2V-1.3B distilled model and NVFP4-quantized weights for T2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for Wan2.1 T2V task
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.1-T2V-1.3B",
model_cls="wan2.1_distill",
task="t2v",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(config_json="../configs/wan/wan_t2v.json")
pipe.enable_quantize(
    dit_quantized=True,
    dit_quantized_ckpt="lightx2v/Wan-NVFP4/wan2.1_t2v_1_3b_nvfp4_lightx2v_4step.safetensors",
    quant_scheme="nvfp4",
)
# Create generator with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=4,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=81,
guidance_scale=1.0,
sample_shift=5.0,
)
seed = 42
prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
save_result_path = "/path/to/save_results/output.mp4"
pipe.generate(
seed=seed,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
"""
Wan2.2-5B text-to-video generation example.
This example demonstrates how to use LightX2V with the Wan2.2-TI2V-5B model for T2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for the Wan2.2-5B TI2V model (T2V task)
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.2-TI2V-5B",
model_cls="wan2.2",
task="t2v",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(config_json="../configs/wan22/wan_ti2v_t2v.json")
# Create generator with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=50,
height=704,
width=1280,
num_frames=121,
fps=25,
guidance_scale=5.0,
sample_shift=5.0,
)
seed = 42
prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
save_result_path = "/path/to/save_results/output.mp4"
pipe.generate(
seed=seed,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
"""
Wan2.1 VACE (All-in-One Video Creation and Editing) generation example.
This example demonstrates how to use LightX2V with the Wan2.1 VACE model for reference-image-conditioned video generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for VACE task
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.1-VACE-1.3B",
src_ref_images="../assets/inputs/imgs/girl.png,../assets/inputs/imgs/snake.png",
model_cls="wan2.1_vace",
task="vace",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(
# config_json="../configs/wan/wan_vace.json"
# )
# Optional: enable offloading to significantly reduce VRAM usage
# Suitable for RTX 30/40/50 consumer GPUs
# pipe.enable_offload(
# cpu_offload=True,
# offload_granularity="block",
# text_encoder_offload=True,
# image_encoder_offload=False,
# vae_offload=False,
# )
# Create generator with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=40,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=81,
guidance_scale=5,
sample_shift=16,
)
seed = 42
prompt = "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
save_result_path = "/path/to/save_results/output.mp4"
pipe.generate(
seed=seed,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
import json
import os
# Paths
CONFIG_PATH = "/workspace/LightX2V/configs/worldplay/worldplay_ar_i2v_480p.json"
MODEL_PATH = "/data/nvme1/models/hunyuan/HunyuanVideo-1.5"
ACTION_CKPT = "/data/nvme1/models/hunyuan/HY-WorldPlay/ar_model/diffusion_pytorch_model.safetensors"
IMAGE_PATH = "/workspace/HY-WorldPlay/assets/img/test.png"
OUTPUT_PATH = "/workspace/LightX2V/save_results/HY-WorldPlay/"
# Input parameters
PROMPT = "A paved pathway leads towards a stone arch bridge spanning a calm body of water. Lush green trees and foliage line the path and the far bank of the water. A traditional-style pavilion with a tiered, reddish-brown roof sits on the far shore. The water reflects the surrounding greenery and the sky. The scene is bathed in soft, natural light, creating a tranquil and serene atmosphere."
SEED = 1
POSE = "d-31"
os.makedirs(OUTPUT_PATH, exist_ok=True)
def main():
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
from lightx2v.utils.lockable_dict import LockableDict
from lightx2v.utils.registry_factory import RUNNER_REGISTER
# Load config from JSON
with open(CONFIG_PATH, "r") as f:
config_dict = json.load(f)
# Add runtime paths
config_dict["model_path"] = MODEL_PATH
config_dict["action_ckpt"] = ACTION_CKPT
config_dict["transformer_model_path"] = os.path.join(MODEL_PATH, "transformer/480p_i2v")
config = LockableDict(config_dict)
runner = RUNNER_REGISTER[config["model_cls"]](config)
runner.init_modules()
# Prepare input info
input_data = {
"seed": SEED,
"prompt": PROMPT,
"prompt_enhanced": "",
"negative_prompt": "",
"image_path": IMAGE_PATH,
"save_result_path": os.path.join(OUTPUT_PATH, "worldplay_ar_test.mp4"),
"return_result_tensor": False,
"pose": POSE,
}
input_info = init_empty_input_info("i2v")
update_input_info_from_dict(input_info, input_data)
result = runner.run_pipeline(input_info)
return result
if __name__ == "__main__":
main()
import json
import os
# Paths
CONFIG_PATH = "/workspace/LightX2V/configs/worldplay/worldplay_bi_i2v_480p.json"
MODEL_PATH = "/data/nvme1/models/hunyuan/hf_cache/hub/models--tencent--HunyuanVideo-1.5/snapshots/9b49404b3f5df2a8f0b31df27a0c7ab872e7b038"
ACTION_CKPT = "/data/nvme1/models/hunyuan/HY-WorldPlay/bidirectional_model/diffusion_pytorch_model.safetensors"
IMAGE_PATH = "/workspace/HY-WorldPlay/assets/img/test.png"
OUTPUT_PATH = "/workspace/LightX2V/save_results/HY-WorldPlay/"
# Input parameters
PROMPT = "A paved pathway leads towards a stone arch bridge spanning a calm body of water. Lush green trees and foliage line the path and the far bank of the water. A traditional-style pavilion with a tiered, reddish-brown roof sits on the far shore. The water reflects the surrounding greenery and the sky. The scene is bathed in soft, natural light, creating a tranquil and serene atmosphere."
SEED = 1
POSE = "s-31"
os.makedirs(OUTPUT_PATH, exist_ok=True)
def main():
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
from lightx2v.utils.lockable_dict import LockableDict
from lightx2v.utils.registry_factory import RUNNER_REGISTER
# Load config from JSON
with open(CONFIG_PATH, "r") as f:
config_dict = json.load(f)
# Add runtime paths
config_dict["model_path"] = MODEL_PATH
config_dict["action_ckpt"] = ACTION_CKPT
config_dict["transformer_model_path"] = os.path.join(MODEL_PATH, "transformer/480p_i2v")
config = LockableDict(config_dict)
runner = RUNNER_REGISTER[config["model_cls"]](config)
runner.init_modules()
# Prepare input info
input_data = {
"seed": SEED,
"prompt": PROMPT,
"prompt_enhanced": "",
"negative_prompt": "",
"image_path": IMAGE_PATH,
"save_result_path": os.path.join(OUTPUT_PATH, "worldplay_bi_test.mp4"),
"return_result_tensor": False,
"pose": POSE,
}
input_info = init_empty_input_info("i2v")
update_input_info_from_dict(input_info, input_data)
result = runner.run_pipeline(input_info)
return result
if __name__ == "__main__":
main()
import json
import os
# Paths
CONFIG_PATH = "/workspace/LightX2V/configs/worldplay/worldplay_distill_i2v_480p.json"
MODEL_PATH = "/data/nvme1/models/hunyuan/hf_cache/hub/models--tencent--HunyuanVideo-1.5/snapshots/9b49404b3f5df2a8f0b31df27a0c7ab872e7b038"
ACTION_CKPT = "/data/nvme1/models/hunyuan/HY-WorldPlay/ar_distilled_action_model/diffusion_pytorch_model.safetensors"
IMAGE_PATH = "/workspace/HY-WorldPlay/assets/img/test.png"
OUTPUT_PATH = "/workspace/LightX2V/save_results/HY-WorldPlay/"
# Input parameters
PROMPT = "A paved pathway leads towards a stone arch bridge spanning a calm body of water. Lush green trees and foliage line the path and the far bank of the water. A traditional-style pavilion with a tiered, reddish-brown roof sits on the far shore. The water reflects the surrounding greenery and the sky. The scene is bathed in soft, natural light, creating a tranquil and serene atmosphere. The pathway is composed of large, rectangular stones, and the bridge is constructed of light gray stone. The overall composition emphasizes the peaceful and harmonious nature of the landscape."
SEED = 1
POSE = "w-31"
os.makedirs(OUTPUT_PATH, exist_ok=True)
def main():
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
from lightx2v.utils.lockable_dict import LockableDict
from lightx2v.utils.registry_factory import RUNNER_REGISTER
# Load config from JSON
with open(CONFIG_PATH, "r") as f:
config_dict = json.load(f)
# Add runtime paths
config_dict["model_path"] = MODEL_PATH
config_dict["action_ckpt"] = ACTION_CKPT
config_dict["transformer_model_path"] = os.path.join(MODEL_PATH, "transformer/480p_i2v")
config = LockableDict(config_dict)
runner = RUNNER_REGISTER[config["model_cls"]](config)
runner.init_modules()
# Prepare input info
input_data = {
"seed": SEED,
"prompt": PROMPT,
"prompt_enhanced": "",
"negative_prompt": "",
"image_path": IMAGE_PATH,
"save_result_path": os.path.join(OUTPUT_PATH, "worldplay_distill_test.mp4"),
"return_result_tensor": False,
"pose": POSE,
}
input_info = init_empty_input_info("i2v")
update_input_info_from_dict(input_info, input_data)
result = runner.run_pipeline(input_info)
return result
if __name__ == "__main__":
main()
"""
Z-Image text-to-image generation example.
This example demonstrates how to use LightX2V with the Z-Image-Turbo model for T2I generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for the Z-Image-Turbo T2I task
pipe = LightX2VPipeline(
model_path="Tongyi-MAI/Z-Image-Turbo",
model_cls="z_image",
task="t2i",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(
# config_json="../configs/z_image/z_image_turbo_t2i.json"
# )
# Load FP8-quantized DiT weights (and, optionally, an INT4-quantized Qwen3 text encoder)
pipe.enable_quantize(
dit_quantized=True,
dit_quantized_ckpt="lightx2v/Z-Image-Turbo-Quantized/z_image_turbo_scaled_fp8_e4m3fn.safetensors",
quant_scheme="fp8-sgl",
# text_encoder_quantized=True,
# text_encoder_quantized_ckpt="JunHowie/Qwen3-4B-GPTQ-Int4",
# text_encoder_quant_scheme="int4"
)
# Enable offloading to significantly reduce VRAM usage with minimal speed impact
# Suitable for RTX 30/40/50 consumer GPUs
pipe.enable_offload(
cpu_offload=True,
offload_granularity="model", # ["model", "block"]
)
# Create generator manually with specified parameters
pipe.create_generator(
attn_mode="flash_attn3",
aspect_ratio="16:9",
infer_steps=9,
guidance_scale=1,
)
# Generation parameters
seed = 42
prompt = 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition, Ultra HD, 4K, cinematic composition.'
negative_prompt = ""
save_result_path = "/path/to/save_results/output.png"
# Generate image
pipe.generate(
seed=seed,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
__version__ = "0.1.0"
__author__ = "LightX2V Contributors"
__license__ = "Apache 2.0"
import lightx2v_platform.set_ai_device
from lightx2v import common, deploy, models, utils
from lightx2v.pipeline import LightX2VPipeline
__all__ = [
"__version__",
"__author__",
"__license__",
"models",
"common",
"deploy",
"utils",
"LightX2VPipeline",
]
class WeightModule:
def __init__(self):
self._modules = {}
self._parameters = {}
def is_empty(self):
return len(self._modules) == 0 and len(self._parameters) == 0
def add_module(self, name, module):
self._modules[name] = module
setattr(self, name, module)
def register_parameter(self, name, param):
self._parameters[name] = param
setattr(self, name, param)
def load(self, weight_dict):
for _, module in self._modules.items():
if hasattr(module, "load"):
module.load(weight_dict)
for _, parameter in self._parameters.items():
if hasattr(parameter, "load"):
parameter.load(weight_dict)
def register_diff(self, weight_dict):
for _, module in self._modules.items():
if hasattr(module, "register_diff"):
module.register_diff(weight_dict)
for _, parameter in self._parameters.items():
if hasattr(parameter, "register_diff"):
parameter.register_diff(weight_dict)
def register_lora(self, weight_dict, strength):
for _, module in self._modules.items():
if hasattr(module, "register_lora"):
module.register_lora(weight_dict, strength)
for _, parameter in self._parameters.items():
if hasattr(parameter, "register_lora"):
parameter.register_lora(weight_dict, strength)
def update_lora(self, weight_dict, strength):
for _, module in self._modules.items():
if hasattr(module, "update_lora"):
module.update_lora(weight_dict, strength)
for _, parameter in self._parameters.items():
if hasattr(parameter, "update_lora"):
parameter.update_lora(weight_dict, strength)
def remove_lora(self):
for _, module in self._modules.items():
if hasattr(module, "remove_lora"):
module.remove_lora()
for _, parameter in self._parameters.items():
if hasattr(parameter, "remove_lora"):
parameter.remove_lora()
def state_dict(self, destination=None):
if destination is None:
destination = {}
for _, param in self._parameters.items():
if param is not None:
param.state_dict(destination)
for _, module in self._modules.items():
if module is not None:
module.state_dict(destination)
return destination
def load_state_dict(self, destination, block_index, adapter_block_index=None):
if destination is None:
destination = {}
for _, param in self._parameters.items():
if param is not None:
param.load_state_dict(destination, block_index, adapter_block_index)
for _, module in self._modules.items():
if module is not None:
module.load_state_dict(destination, block_index, adapter_block_index)
return destination
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
for _, param in self._parameters.items():
if param is not None:
param.load_state_dict_from_disk(block_index, adapter_block_index)
for _, module in self._modules.items():
if module is not None:
module.load_state_dict_from_disk(block_index, adapter_block_index)
def named_parameters(self, prefix=""):
for name, param in self._parameters.items():
if param is not None:
yield prefix + name, param
for name, module in self._modules.items():
if module is not None:
yield from module.named_parameters(prefix + name + ".")
def to_cpu(self, non_blocking=False):
for name, param in self._parameters.items():
if param is not None:
if hasattr(param, "cpu"):
self._parameters[name] = param.cpu()
setattr(self, name, self._parameters[name])
elif hasattr(param, "to_cpu"):
self._parameters[name].to_cpu()
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if isinstance(module, WeightModuleList):
for i in range(len(module)):
for m in module[i]._modules.values():
if m is not None and hasattr(m, "to_cpu"):
m.to_cpu()
for m in module[i]._parameters.values():
if m is not None and hasattr(m, "to_cpu"):
m.to_cpu()
else:
if module is not None and hasattr(module, "to_cpu"):
module.to_cpu()
def to_cuda(self, non_blocking=False):
for name, param in self._parameters.items():
if param is not None:
if hasattr(param, "cuda"):
self._parameters[name] = param.cuda()
elif hasattr(param, "to_cuda"):
self._parameters[name].to_cuda()
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if isinstance(module, WeightModuleList):
for i in range(len(module)):
for m in module[i]._modules.values():
if m is not None and hasattr(m, "to_cuda"):
m.to_cuda()
for m in module[i]._parameters.values():
if m is not None and hasattr(m, "to_cuda"):
m.to_cuda()
else:
if module is not None and hasattr(module, "to_cuda"):
module.to_cuda()
def to_cpu_async(self, non_blocking=True):
for name, param in self._parameters.items():
if param is not None:
if hasattr(param, "cpu"):
self._parameters[name] = param.cpu(non_blocking=True)
setattr(self, name, self._parameters[name])
elif hasattr(param, "to_cpu"):
self._parameters[name].to_cpu(non_blocking=True)
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if isinstance(module, WeightModuleList):
for i in range(len(module)):
for m in module[i]._modules.values():
if m is not None and hasattr(m, "to_cpu"):
m.to_cpu(non_blocking=True)
for m in module[i]._parameters.values():
if m is not None and hasattr(m, "to_cpu"):
m.to_cpu(non_blocking=True)
else:
if module is not None and hasattr(module, "to_cpu"):
module.to_cpu(non_blocking=True)
def to_cuda_async(self, non_blocking=True):
for name, param in self._parameters.items():
if param is not None:
if hasattr(param, "cuda"):
self._parameters[name] = param.cuda(non_blocking=True)
elif hasattr(param, "to_cuda"):
self._parameters[name].to_cuda(non_blocking=True)
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if isinstance(module, WeightModuleList):
for i in range(len(module)):
for m in module[i]._modules.values():
if m is not None and hasattr(m, "to_cuda"):
m.to_cuda(non_blocking=True)
for m in module[i]._parameters.values():
if m is not None and hasattr(m, "to_cuda"):
m.to_cuda(non_blocking=True)
else:
if module is not None and hasattr(module, "to_cuda"):
module.to_cuda(non_blocking=True)
class WeightModuleList(WeightModule):
def __init__(self, modules=None):
super().__init__()
self._list = []
if modules is not None:
            for module in modules:
                self.append(module)
def append(self, module):
idx = len(self._list)
self._list.append(module)
self.add_module(str(idx), module)
def __getitem__(self, idx):
return self._list[idx]
def __setitem__(self, idx, module):
self._list[idx] = module
self.add_module(str(idx), module)
def __len__(self):
return len(self._list)
def __iter__(self):
return iter(self._list)
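# --- Illustrative sketch (not part of the original file) ---
# A minimal, self-contained example of how these containers compose. ToyParam is a
# hypothetical stand-in for the real parameter wrappers: WeightModule relies only on
# duck typing (load / state_dict / to_cpu / ...), so any object implementing the
# relevant methods can be registered.
if __name__ == "__main__":

    class ToyParam:
        def __init__(self, key):
            self.key = key
            self.value = None

        def load(self, weight_dict):
            # Pull this parameter's value out of a flat checkpoint-style dict.
            self.value = weight_dict[self.key]

        def state_dict(self, destination):
            destination[self.key] = self.value

    class ToyLinear(WeightModule):
        def __init__(self, key):
            super().__init__()
            self.register_parameter("weight", ToyParam(key))

    blocks = WeightModuleList([ToyLinear(f"blocks.{i}.weight") for i in range(2)])
    blocks.load({"blocks.0.weight": [1.0], "blocks.1.weight": [2.0]})
    print(dict(blocks.named_parameters()))  # keys: '0.weight', '1.weight'
    print(blocks.state_dict())  # {'blocks.0.weight': [1.0], 'blocks.1.weight': [2.0]}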
from concurrent.futures import ThreadPoolExecutor
import torch
from loguru import logger
from packaging.version import parse
from tqdm import tqdm
from lightx2v.utils.profiler import ExcludedProfilingContext
from lightx2v_platform.base.global_var import AI_DEVICE
torch_device_module = getattr(torch, AI_DEVICE)
class WeightAsyncStreamManager(object):
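    """Manages asynchronous weight prefetching for block- or phase-granularity offloading.

    With block granularity, two device buffers (``cuda_buffers``) are double-buffered:
    the compute stream consumes buffer 0 while ``prefetch_weights`` fills buffer 1 on a
    separate load stream, and ``swap_blocks`` synchronizes both streams before swapping
    the buffers. With phase granularity, ``prefetch_phase`` refills the per-phase device
    buffer and ``swap_phases`` performs the corresponding synchronization. Optional CPU
    buffers (``cpu_buffers``) together with ``init_lazy_load`` / ``start_prefetch_block``
    add disk-to-CPU prefetching through a thread pool for lazily loaded checkpoints.
    """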
def __init__(self, offload_granularity):
self.offload_granularity = offload_granularity
self.init_stream = torch_device_module.Stream(priority=0)
self.need_init_first_buffer = True
self.lazy_load = False
torch_version = parse(torch.__version__.split("+")[0])
if AI_DEVICE == "cuda" and torch_version >= parse("2.7"):
self.cuda_load_stream = torch_device_module.Stream(priority=1)
self.compute_stream = torch_device_module.Stream(priority=1)
else:
self.cuda_load_stream = torch_device_module.Stream(priority=0)
self.compute_stream = torch_device_module.Stream(priority=-1)
def init_cpu_buffer(self, blocks_cpu_buffer=None, phases_cpu_buffer=None):
self.need_init_first_buffer = True
if self.offload_granularity == "block":
assert blocks_cpu_buffer is not None
self.cpu_buffers = [blocks_cpu_buffer[i] for i in range(len(blocks_cpu_buffer))]
elif self.offload_granularity == "phase":
assert phases_cpu_buffer is not None
self.cpu_buffers = [phases_cpu_buffer[i] for i in range(len(phases_cpu_buffer))]
else:
raise NotImplementedError
def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
self.need_init_first_buffer = True
if self.offload_granularity == "block":
assert blocks_cuda_buffer is not None
self.cuda_buffers = [blocks_cuda_buffer[i] for i in range(len(blocks_cuda_buffer))]
elif self.offload_granularity == "phase":
assert phases_cuda_buffer is not None
self.cuda_buffers = [phases_cuda_buffer[i] for i in range(len(phases_cuda_buffer))]
else:
raise NotImplementedError
def init_first_buffer(self, blocks, adapter_block_idx=None):
with torch_device_module.stream(self.init_stream):
if hasattr(self, "cpu_buffers"):
if self.offload_granularity == "block":
self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0].state_dict(), 0, adapter_block_idx)
else:
self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0][0].state_dict(), 0, adapter_block_idx)
else:
if self.offload_granularity == "block":
self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
else:
self.cuda_buffers[0].load_state_dict(blocks[0].compute_phases[0].state_dict(), 0, adapter_block_idx)
self.init_stream.synchronize()
self.need_init_first_buffer = False
def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None):
with torch_device_module.stream(self.cuda_load_stream):
if hasattr(self, "cpu_buffers"):
self.cuda_buffers[1].load_state_dict(self.cpu_buffers[0].state_dict(), block_idx, adapter_block_idx)
else:
self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)
def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
with torch_device_module.stream(self.cuda_load_stream):
if hasattr(self, "cpu_buffers"):
self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[0][phase_idx].state_dict(), block_idx, adapter_block_idx)
else:
self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)
def swap_blocks(self):
self.cuda_load_stream.synchronize()
self.compute_stream.synchronize()
self.cuda_buffers[0], self.cuda_buffers[1] = (
self.cuda_buffers[1],
self.cuda_buffers[0],
)
def swap_phases(self):
self.cuda_load_stream.synchronize()
self.compute_stream.synchronize()
@ExcludedProfilingContext("🔥 warm_up_cpu_buffers")
def warm_up_cpu_buffers(self, blocks_num):
logger.info("🔥 Warming up cpu buffers...")
for i in tqdm(range(blocks_num)):
for phase in self.cpu_buffers[0]:
phase.load_state_dict_from_disk(i, None)
for phase in self.cpu_buffers[1]:
phase.load_state_dict_from_disk(i, None)
for phase in self.cpu_buffers[0]:
phase.load_state_dict_from_disk(0, None)
for phase in self.cpu_buffers[1]:
phase.load_state_dict_from_disk(1, None)
logger.info("✅ CPU buffers warm-up completed.")
def init_lazy_load(self, num_workers=6):
self.lazy_load = True
self.executor = ThreadPoolExecutor(max_workers=num_workers)
self.prefetch_futures = []
self.prefetch_block_idx = -1
def start_prefetch_block(self, block_idx, adapter_block_idx=None):
self.prefetch_block_idx = block_idx
self.prefetch_futures = []
if self.offload_granularity == "block":
future = self.executor.submit(self.cpu_buffers[1].load_state_dict_from_disk, block_idx, adapter_block_idx)
self.prefetch_futures.append(future)
else:
for phase in self.cpu_buffers[1]:
future = self.executor.submit(phase.load_state_dict_from_disk, block_idx, adapter_block_idx)
self.prefetch_futures.append(future)
def swap_cpu_buffers(self):
# import time
# wait_start = time.time()
# already_done = all(f.done() for f in self.prefetch_futures)
for f in self.prefetch_futures:
f.result()
# wait_time = time.time() - wait_start
# logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}")
self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]
def __del__(self):
if hasattr(self, "executor") and self.executor is not None:
for f in self.prefetch_futures:
if not f.done():
f.result()
self.executor.shutdown(wait=False)
self.executor = None
logger.debug("ThreadPoolExecutor shut down successfully.")
from .attn import *
from .conv import *
from .embedding import *
from .mm import *
from .norm import *
from .tensor import *
from .draft_attn import DraftAttnWeight
from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
from .general_sparse_attn import GeneralSparseAttnWeight
from .nbhd_attn import NbhdAttnWeight, NbhdAttnWeightFlashInfer
from .radial_attn import RadialAttnWeight
from .ring_attn import RingAttnWeight
from .sage_attn import SageAttn2Weight, SageAttn3Weight
from .sla_attn import SlaAttnWeight
from .sparse_mask_generator import NbhdMaskGenerator, SlaMaskGenerator, SvgMaskGenerator
from .sparse_operator import MagiOperator, SlaTritonOperator
from .spassage_attn import SageAttnWeight
from .svg2_attn import Svg2AttnWeight
from .svg_attn import SvgAttnWeight
from .torch_sdpa import TorchSDPAWeight
from .ulysses_attn import Ulysses4090AttnWeight, UlyssesAttnWeight
import math
import torch
import torch.nn.functional as F
from loguru import logger
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
try:
    from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
except ImportError:
    magi_ffa_func = None
# Keep the name bound even if neither flash_attn2 nor flash_attn3 is installed,
# so callers can check for None instead of hitting a NameError.
flash_attn_varlen_func = None
try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as _func
    flash_attn_varlen_func = _func
except ImportError:
    logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
    try:
        from flash_attn_interface import flash_attn_varlen_func as _func
        flash_attn_varlen_func = _func
    except ImportError:
        logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
@ATTN_WEIGHT_REGISTER("draft_attn")
class DraftAttnWeight(AttnWeightTemplate):
sparsity_ratio = 0.75
reorg_idx_dict = {}
restore_idx_dict = {}
bucket_offsets_dict = {}
def __init__(self):
self.config = {}
@staticmethod
def build_grid_gather_index_and_bucket_fast(H, W, pool_h, pool_w, seqlen):
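        # Partition one H x W frame into pool_h x pool_w buckets (Gh x Gw of them),
        # build a gather index that reorders tokens bucket by bucket for every frame,
        # and return the per-bucket sizes plus their cumulative offsets within a frame.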
Gh = (H + pool_h - 1) // pool_h
Gw = (W + pool_w - 1) // pool_w
# Single frame
gather_single = []
bucket_sizes_single = []
for gh in range(Gh):
h0 = gh * pool_h
h1 = min(h0 + pool_h, H)
block_h = h1 - h0
for gw in range(Gw):
w0 = gw * pool_w
w1 = min(w0 + pool_w, W)
block_w = w1 - w0
# bucket size
bucket_size = block_h * block_w
bucket_sizes_single.append(bucket_size)
# gather index
for i in range(h0, h1):
row_base = i * W
for j in range(w0, w1):
gather_single.append(row_base + j)
bucket_sizes = []
bucket_offsets = [0]
running = 0
# bucket + offsets
for sz in bucket_sizes_single:
bucket_sizes.append(sz)
running += sz
bucket_offsets.append(running)
frame_num = seqlen // (H * W)
gather_index = []
for f in range(frame_num):
frame_base = f * H * W
# index
gather_index.extend(idx + frame_base for idx in gather_single)
return gather_index, bucket_sizes, bucket_offsets
@classmethod
@torch.compiler.disable
def prepare_reorg_idx_and_bucket_offset(cls, seqlen, frame_h, frame_w, pool_h, pool_w, device):
if (seqlen, frame_h, frame_w) in cls.reorg_idx_dict:
return
reorg_idx, bucket_sizes, bucket_offsets = cls.build_grid_gather_index_and_bucket_fast(
H=frame_h,
W=frame_w,
pool_h=pool_h,
pool_w=pool_w,
seqlen=seqlen,
)
reorg_idx = torch.tensor(reorg_idx, dtype=torch.long, device=device)
restore_idx = torch.empty_like(reorg_idx)
restore_idx[reorg_idx] = torch.arange(reorg_idx.numel(), device=device)
cls.reorg_idx_dict[(seqlen, frame_h, frame_w)] = reorg_idx
cls.restore_idx_dict[(seqlen, frame_h, frame_w)] = restore_idx
cls.bucket_offsets_dict[(seqlen, frame_h, frame_w)] = torch.tensor(bucket_offsets, dtype=torch.int32, device=device)
logger.info(f"DraftAttnWeight: reorg_idx len: {len(reorg_idx)}")
logger.info(f"DraftAttnWeight: bucket_sizes: {bucket_sizes}")
logger.info(f"DraftAttnWeight: bucket_offsets: {bucket_offsets}")
logger.info(f"DraftAttnWeight: using sparsity ratio {cls.sparsity_ratio}")
def sample_qk_attention_2d(
self,
q: torch.Tensor,
k: torch.Tensor,
frame_h: int,
frame_w: int,
pool_h: int,
pool_w: int,
):
L, H, D = q.shape
frame_tokens = frame_h * frame_w
assert L % frame_tokens == 0, "L must be multiple of frame_h*frame_w"
num_frames = L // frame_tokens
# 1) Slice out the video part and reshape to frames:
# [L, H, D] → [num_frames, frame_h, frame_w, H, D]
q_vid = q.view(num_frames, frame_h, frame_w, H, D)
k_vid = k.view(num_frames, frame_h, frame_w, H, D)
# 2) Permute & merge (num_frames, H*D) into channel dim:
# → [num_frames, H*D, frame_h, frame_w]
q_vid = q_vid.permute(0, 3, 4, 1, 2).reshape(num_frames, H * D, frame_h, frame_w)
k_vid = k_vid.permute(0, 3, 4, 1, 2).reshape(num_frames, H * D, frame_h, frame_w)
# 3) 2D avg‐pool each frame (ceil_mode ensures we cover the edges):
# → [num_frames, H*D, S_h, S_w]
q_pooled = F.avg_pool2d(q_vid, kernel_size=(pool_h, pool_w), stride=(pool_h, pool_w), ceil_mode=True)
k_pooled = F.avg_pool2d(k_vid, kernel_size=(pool_h, pool_w), stride=(pool_h, pool_w), ceil_mode=True)
S_h, S_w = q_pooled.shape[-2:]
S = num_frames * S_h * S_w
# 4) Un‐merge channel back to [S, H, D]:
# → [num_frames, H, D, S_h, S_w] → [S, H, D]
def unmerge(x):
x = x.reshape(num_frames, H, D, S_h, S_w)
return x.permute(0, 3, 4, 1, 2).reshape(S, H, D)
sampled_q = unmerge(q_pooled)
sampled_k = unmerge(k_pooled)
# 5) Compute per‐head scaled dot‐prod attention:
# [S, H, D] → [H, S, D]
q_heads = sampled_q.permute(1, 0, 2)
k_heads = sampled_k.permute(1, 0, 2)
# → [H, S, S]
scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D)
attn_map = torch.softmax(scores, dim=-1)
return attn_map
def attention_percentile_mask_headwise(self, attn_map: torch.Tensor, r: float) -> torch.BoolTensor:
"""
Build a mask per head so that each head keeps its top-r fraction of entries as True.
Args:
attn_map: Tensor of shape [H, S, S], attention scores (e.g. after softmax).
r: float in (0,1), fraction of entries *per head* to keep True.
Returns:
mask: BoolTensor of shape [H, S, S], where for each head h,
mask[h].float().mean() ≈ r.
"""
H, S, _ = attn_map.shape
flat = attn_map.reshape(H, -1) # [H, S*S]
n = flat.shape[1]
k = int((1.0 - r) * n)
if k == 0:
return torch.ones_like(attn_map, dtype=torch.bool)
if k >= n:
return torch.zeros_like(attn_map, dtype=torch.bool)
# Calculate threshold for each head independently
thresholds = torch.kthvalue(flat, k, dim=1).values # [H]
mask = attn_map >= thresholds[:, None, None] # broadcasting
return mask
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
block_idx=0,
scheduler=None,
**kwargs,
):
if block_idx < 1:
if cu_seqlens_q is not None:
cu_seqlens_q = cu_seqlens_q.to(q.device)
if cu_seqlens_kv is not None:
cu_seqlens_kv = cu_seqlens_kv.to(k.device)
out = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
)
return out.reshape(out.shape[0], -1)
if scheduler is not None:
frame_h = scheduler.latents.shape[2] // scheduler.patch_size[1]
frame_w = scheduler.latents.shape[3] // scheduler.patch_size[2]
else:
frame_h, frame_w = 32, 48
seqlen, head_num, head_dim = q.shape
frame_size = frame_h * frame_w
num_frames = seqlen // frame_size
pool_h, pool_w = (8, 16) if frame_h < frame_w else (16, 8)
self.prepare_reorg_idx_and_bucket_offset(
seqlen=seqlen,
frame_h=frame_h,
frame_w=frame_w,
pool_h=pool_h,
pool_w=pool_w,
device=q.device,
)
attn = self.sample_qk_attention_2d(
q,
k,
frame_h=frame_h,
frame_w=frame_w,
pool_h=pool_h,
pool_w=pool_w,
)
mask = self.attention_percentile_mask_headwise(attn, 1 - self.sparsity_ratio)
# sink mask
mask_size_pre_frame = mask.shape[1] // num_frames
mask[:, :, :mask_size_pre_frame] = True
# diagonal mask
block_indices = torch.arange(mask.shape[1], device=mask.device) // mask_size_pre_frame
mask |= block_indices[:, None] == block_indices[None, :]
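        # Translate each selected (head, query-bucket, key-bucket) entry of the coarse
        # mask into a contiguous token range in the head-flattened sequence: heads are
        # laid out back to back (offset h * seqlen), each frame spans frame_size tokens,
        # and bucket_offsets gives the start of every pooled bucket within a frame.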
h_indices, i_indices, j_indices = torch.nonzero(mask, as_tuple=True) # [N, 3] -> [head, i, j]
bucket_offsets = self.bucket_offsets_dict[(seqlen, frame_h, frame_w)]
base_offset = h_indices * seqlen
q_frame_base = (i_indices // mask_size_pre_frame) * frame_size
q_bucket_idx = i_indices % mask_size_pre_frame
q_start = base_offset + q_frame_base + bucket_offsets[q_bucket_idx]
q_end = base_offset + q_frame_base + bucket_offsets[q_bucket_idx + 1]
k_frame_base = (j_indices // mask_size_pre_frame) * frame_size
k_bucket_idx = j_indices % mask_size_pre_frame
k_start = base_offset + k_frame_base + bucket_offsets[k_bucket_idx]
k_end = base_offset + k_frame_base + bucket_offsets[k_bucket_idx + 1]
q_ranges = torch.stack([q_start, q_end], dim=1).to(dtype=torch.int32)
k_ranges = torch.stack([k_start, k_end], dim=1).to(dtype=torch.int32)
attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device=q.device)
reorg_idx = self.reorg_idx_dict[(seqlen, frame_h, frame_w)]
q = q[reorg_idx]
k = k[reorg_idx]
v = v[reorg_idx]
q = q.permute(1, 0, 2).reshape(head_num * seqlen, 1, head_dim)
k = k.permute(1, 0, 2).reshape(head_num * seqlen, 1, head_dim)
v = v.permute(1, 0, 2).reshape(head_num * seqlen, 1, head_dim)
out = magi_ffa_func(
q,
k,
v,
q_ranges=q_ranges,
k_ranges=k_ranges,
attn_type_map=attn_type_map,
auto_range_merge=True,
)[0]
out = out.reshape(head_num, seqlen, head_dim).permute(1, 0, 2)
restore_idx = self.restore_idx_dict[(seqlen, frame_h, frame_w)]
out = out[restore_idx]
return out.reshape(out.shape[0], -1)
from loguru import logger
try:
import flash_attn # noqa: F401
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
flash_attn_varlen_func = None
try:
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
flash_attn_varlen_func_v3 = None
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
@ATTN_WEIGHT_REGISTER("flash_attn2")
class FlashAttn2Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
if len(q.shape) == 3:
bs = 1
elif len(q.shape) == 4:
bs = q.shape[0]
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
x = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
).reshape(bs * max_seqlen_q, -1)
return x
@ATTN_WEIGHT_REGISTER("flash_attn3")
class FlashAttn3Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
if len(q.shape) == 3:
bs = 1
elif len(q.shape) == 4:
bs = q.shape[0]
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
x = flash_attn_varlen_func_v3(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
).reshape(bs * max_seqlen_q, -1)
return x
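# --- Illustrative usage sketch (not part of the original file) ---
# Shows how one of the attention backends above can be applied to packed varlen
# q/k/v tensors in the [total_tokens, num_heads, head_dim] layout. Guarded so it
# only runs when flash_attn and a CUDA device are available; the shapes are
# arbitrary example values.
if __name__ == "__main__":
    import torch

    if flash_attn_varlen_func is not None and torch.cuda.is_available():
        attn = FlashAttn2Weight()  # registered above under the name "flash_attn2"
        seqlen, num_heads, head_dim = 128, 8, 64
        q = torch.randn(seqlen, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        # A single sequence of length `seqlen`, expressed as cumulative sequence lengths.
        cu_seqlens = torch.tensor([0, seqlen], dtype=torch.int32, device="cuda")
        out = attn.apply(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_kv=cu_seqlens,
            max_seqlen_q=seqlen,
            max_seqlen_kv=seqlen,
        )
        print(out.shape)  # torch.Size([128, 512]) -> [total_tokens, num_heads * head_dim]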