Unverified commit 4c0a9a0d, authored by Gu Shiqiao and committed by GitHub

Fix device bugs (#527)

parent fbb19ffc
{
"infer_steps": 4,
"target_fps": 16,
"video_duration": 360,
"audio_sr": 16000,
"target_video_length": 81,
"resize_mode": "adaptive",
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
"sample_guide_scale": 1.0,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false,
"use_31_block": false,
"run_device": "mlu",
"rope_type": "torch",
"modulate_type": "torch"
}
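
Note: the new "run_device", "rope_type", and "modulate_type" keys above are read with config.get(..., default) throughout the diff below, so omitting them keeps the previous CUDA/flashinfer/triton behaviour. A minimal sketch of how such a config might be consumed (the file name and helper are hypothetical, not part of this commit):

import json

import torch

def resolve_run_device(config_path):
    # Hypothetical helper: fall back to "cuda" when "run_device" is absent,
    # mirroring config.get("run_device", "cuda") used in the runners below.
    with open(config_path) as f:
        config = json.load(f)
    return torch.device(config.get("run_device", "cuda"))

# e.g. resolve_run_device("wan_mlu.json") would give an MLU device,
# assuming a torch build/extension that registers the "mlu" backend.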
@@ -442,6 +442,25 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             else:
                 raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+        if self.bias_name is not None:
+            if self.create_cuda_buffer:
+                # move to cuda buffer
+                self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
+            else:
+                device = weight_dict[self.bias_name].device
+                if device.type == "cuda":
+                    self.bias = weight_dict[self.bias_name]
+                elif device.type == "cpu":
+                    bias_shape = weight_dict[self.bias_name].shape
+                    bias_dtype = weight_dict[self.bias_name].dtype
+                    self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                    self.pin_bias.copy_(weight_dict[self.bias_name])
+                else:
+                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+        else:
+            self.bias = None
+            self.pin_bias = None
     def load_fp8_perblock128_sym(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
             self.weight = weight_dict[self.weight_name]
...
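
The added bias path mirrors the existing weight path: when the tensor arrives on CPU it is staged into page-locked (pinned) host memory so a later host-to-device copy can be issued with non_blocking=True. A minimal sketch of that pattern, under the assumption of a CUDA-enabled build (the helper name is hypothetical):

import torch

def stage_to_pinned(t):
    # Keep a page-locked CPU copy so the eventual H2D transfer can overlap
    # with compute; pin_memory only helps on a CUDA-enabled build.
    pinned = torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
    pinned.copy_(t)
    return pinned

# later: pinned.to("cuda", non_blocking=True)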
@@ -3,7 +3,11 @@ import argparse
 import torch
 import torch.distributed as dist
 from loguru import logger
-from torch.distributed import ProcessGroupNCCL
+try:
+    from torch.distributed import ProcessGroupNCCL
+except ImportError:
+    ProcessGroupNCCL = None
 from lightx2v.common.ops import *
 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner  # noqa: F401
...
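
Wrapping the ProcessGroupNCCL import keeps the entry point importable on builds without NCCL (CPU-only wheels or non-CUDA backends such as the MLU config above); code that still needs NCCL-specific options has to check for None first. A sketch of such a guard (the helper is hypothetical, not part of this commit):

try:
    from torch.distributed import ProcessGroupNCCL
except ImportError:
    ProcessGroupNCCL = None

def make_nccl_options():
    # Only build NCCL-specific process-group options when the binding exists.
    if ProcessGroupNCCL is None:
        return None
    opts = ProcessGroupNCCL.Options()
    opts.is_high_priority_stream = True
    return opts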
@@ -159,13 +159,14 @@ class ByT5TextEncoder:
         self,
         config,
         device=torch.device("cpu"),
+        run_device=torch.device("cuda"),
         checkpoint_path=None,
         byt5_max_length=256,
         cpu_offload=False,
     ):
         self.cpu_offload = cpu_offload
         self.config = config
-        self.device = device
+        self.run_device = run_device
         self.byt5_max_length = byt5_max_length
         self.enable_cfg = config.get("enable_cfg", False)
         byT5_google_path = os.path.join(checkpoint_path, "text_encoder", "byt5-small")
@@ -300,12 +301,12 @@ class ByT5TextEncoder:
         negative_masks = []
         for prompt in prompt_list:
-            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, self.device)
+            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, self.run_device)
             positive_embeddings.append(pos_emb)
             positive_masks.append(pos_mask)
             if self.enable_cfg:  # TODO: split cfg out; better suited for parallel execution
-                neg_emb, neg_mask = self._process_single_byt5_prompt("", self.device)
+                neg_emb, neg_mask = self._process_single_byt5_prompt("", self.run_device)
                 negative_embeddings.append(neg_emb)
                 negative_masks.append(neg_mask)
@@ -327,8 +328,8 @@ class ByT5TextEncoder:
     @torch.no_grad()
     def infer(self, prompts):
         if self.cpu_offload:
-            self.byt5_model = self.byt5_model.to(self.device)
-            self.byt5_mapper = self.byt5_mapper.to(self.device)
+            self.byt5_model = self.byt5_model.to(self.run_device)
+            self.byt5_mapper = self.byt5_mapper.to(self.run_device)
         byt5_embeddings, byt5_masks = self._prepare_byt5_embeddings(prompts)
         byt5_features = self.byt5_mapper(byt5_embeddings.to(torch.bfloat16))
         if self.cpu_offload:
...
@@ -552,6 +552,7 @@ class Qwen25VL_TextEncoder:
         text_len=1000,
         dtype=torch.float16,
         device=torch.device("cpu"),
+        run_device=torch.device("cuda"),
         checkpoint_path=None,
         cpu_offload=False,
         qwen25vl_quantized=False,
@@ -560,7 +561,7 @@ class Qwen25VL_TextEncoder:
     ):
         self.text_len = text_len
         self.dtype = dtype
-        self.device = device
+        self.run_device = run_device
         self.cpu_offload = cpu_offload
         self.qwen25vl_quantized = qwen25vl_quantized
         self.qwen25vl_quant_scheme = qwen25vl_quant_scheme
@@ -589,20 +590,20 @@ class Qwen25VL_TextEncoder:
     def infer(self, texts):
         if self.cpu_offload:
-            self.text_encoder = self.text_encoder.to(self.device)
+            self.text_encoder = self.text_encoder.to(self.run_device)
         text_inputs = self.text_encoder.text2tokens(texts, data_type="video", max_length=self.text_len)
-        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=self.device)
+        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=self.run_device)
         if self.cpu_offload:
             self.text_encoder = self.text_encoder.to("cpu")
         prompt_embeds = prompt_outputs.hidden_state
         attention_mask = prompt_outputs.attention_mask
         if attention_mask is not None:
-            attention_mask = attention_mask.to(self.device)
+            attention_mask = attention_mask.to(self.run_device)
             _, seq_len = attention_mask.shape
             attention_mask = attention_mask.repeat(1, self.num_videos_per_prompt)
             attention_mask = attention_mask.view(self.num_videos_per_prompt, seq_len)
-        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
+        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.run_device)
         seq_len = prompt_embeds.shape[1]
         # duplicate text embeddings for each generation per prompt, using mps friendly method
...
@@ -95,6 +95,7 @@ class VisionEncoder(nn.Module):
         output_key: Optional[str] = None,
         logger=None,
         device=None,
+        run_device=None,
         cpu_offload=False,
     ):
         super().__init__()
@@ -120,6 +121,7 @@ class VisionEncoder(nn.Module):
         )
         self.dtype = self.model.dtype
         self.device = self.model.device
+        self.run_device = run_device
         self.processor, self.processor_path = load_image_processor(
             processor_type=self.processor_type,
@@ -175,7 +177,7 @@ class VisionEncoder(nn.Module):
         if isinstance(images, np.ndarray):
             # Preprocess images if they're numpy arrays
-            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=self.device, dtype=self.model.dtype)
+            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=self.run_device, dtype=self.model.dtype)
         else:
             # Assume already preprocessed
             preprocessed = images
@@ -230,11 +232,13 @@ class SiglipVisionEncoder:
         self,
         config,
         device=torch.device("cpu"),
+        run_device=torch.device("cuda"),
         checkpoint_path=None,
         cpu_offload=False,
     ):
         self.config = config
         self.device = device
+        self.run_device = run_device
         self.cpu_offload = cpu_offload
         self.vision_states_dim = 1152
         vision_encoder_path = os.path.join(checkpoint_path, "vision_encoder", "siglip")
@@ -248,6 +252,7 @@ class SiglipVisionEncoder:
             output_key=None,
             logger=None,
             device=self.device,
+            run_device=self.run_device,
             cpu_offload=self.cpu_offload,
         )
@@ -265,7 +270,7 @@ class SiglipVisionEncoder:
     @torch.no_grad()
     def infer(self, vision_states):
         if self.cpu_offload:
-            self.vision_in = self.vision_in.to("cuda")
+            self.vision_in = self.vision_in.to(self.run_device)
         vision_states = self.vision_in(vision_states)
         if self.cpu_offload:
             self.vision_in = self.vision_in.to("cpu")
...
@@ -252,7 +252,7 @@ class AudioAdapter(nn.Module):
         quantized: bool = False,
         quant_scheme: str = None,
         cpu_offload: bool = False,
-        device=torch.device("cpu"),
+        run_device=torch.device("cuda"),
     ):
         super().__init__()
         self.cpu_offload = cpu_offload
@@ -263,7 +263,7 @@ class AudioAdapter(nn.Module):
             mlp_dims=mlp_dims,
             transformer_layers=projection_transformer_layers,
         )
-        self.device = torch.device(device)
+        self.run_device = run_device
         # self.num_tokens = num_tokens * 4
         self.num_tokens_x4 = num_tokens * 4
         self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
@@ -302,10 +302,10 @@ class AudioAdapter(nn.Module):
     @torch.no_grad()
     def forward_audio_proj(self, audio_feat, latent_frame):
         if self.cpu_offload:
-            self.audio_proj.to(self.device)
+            self.audio_proj.to(self.run_device)
         x = self.audio_proj(audio_feat, latent_frame)
         x = self.rearange_audio_features(x)
-        x = x + self.audio_pe.to(self.device)
+        x = x + self.audio_pe.to(self.run_device)
         if self.cpu_offload:
             self.audio_proj.to("cpu")
         return x
@@ -5,14 +5,15 @@ from lightx2v.utils.envs import *
 class SekoAudioEncoderModel:
-    def __init__(self, model_path, audio_sr, cpu_offload, device):
+    def __init__(self, model_path, audio_sr, cpu_offload, run_device):
         self.model_path = model_path
         self.audio_sr = audio_sr
         self.cpu_offload = cpu_offload
         if self.cpu_offload:
             self.device = torch.device("cpu")
         else:
-            self.device = torch.device(device)
+            self.device = torch.device(run_device)
+        self.run_device = run_device
         self.load()
     def load(self):
@@ -26,13 +27,13 @@ class SekoAudioEncoderModel:
             self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
     def to_cuda(self):
-        self.audio_feature_encoder = self.audio_feature_encoder.to(self.device)
+        self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)
     @torch.no_grad()
     def infer(self, audio_segment):
-        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.device).to(dtype=GET_DTYPE())
+        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.run_device).to(dtype=GET_DTYPE())
         if self.cpu_offload:
-            self.audio_feature_encoder = self.audio_feature_encoder.to(self.device)
+            self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)
         audio_feat = self.audio_feature_encoder(audio_feat, return_dict=True).last_hidden_state
         if self.cpu_offload:
             self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
...
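
With the keyword renamed to run_device, construction matches the call in WanAudioRunner.load_audio_encoder further down; a minimal usage sketch assuming a config dict with the same keys as the JSON at the top:

audio_encoder = SekoAudioEncoderModel(
    model_path="TencentGameMate-chinese-hubert-large",  # illustrative path, see load_audio_encoder below
    audio_sr=config["audio_sr"],                        # e.g. 16000
    cpu_offload=config.get("cpu_offload", False),
    run_device=config.get("run_device", "cuda"),        # "mlu" in the sample config above
)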
@@ -744,7 +744,8 @@ class T5EncoderModel:
         self,
         text_len,
         dtype=torch.bfloat16,
-        device=torch.device("cpu"),
+        device=torch.device("cuda"),
+        run_device=torch.device("cuda"),
         checkpoint_path=None,
         tokenizer_path=None,
         shard_fn=None,
@@ -757,6 +758,7 @@ class T5EncoderModel:
         self.text_len = text_len
         self.dtype = dtype
         self.device = device
+        self.run_device = run_device
         if t5_quantized_ckpt is not None and t5_quantized:
             self.checkpoint_path = t5_quantized_ckpt
         else:
@@ -805,8 +807,8 @@ class T5EncoderModel:
     def infer(self, texts):
         ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
-        ids = ids.to(self.device)
-        mask = mask.to(self.device)
+        ids = ids.to(self.run_device)
+        mask = mask.to(self.run_device)
         seq_lens = mask.gt(0).sum(dim=1).long()
         with torch.no_grad():
...
@@ -428,7 +428,6 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
     def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False, run_device=torch.device("cuda")):
         self.dtype = dtype
-        self.device = device
         self.run_device = run_device
         self.quantized = clip_quantized
         self.cpu_offload = cpu_offload
...
@@ -68,7 +68,7 @@ class HunyuanVideo15PreInfer:
         self.heads_num = config["heads_num"]
         self.frequency_embedding_size = 256
         self.max_period = 10000
-        self.device = torch.device(self.config.get("run_device", "cuda"))
+        self.run_device = torch.device(self.config.get("run_device", "cuda"))
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
@@ -155,7 +155,7 @@ class HunyuanVideo15PreInfer:
         byt5_txt = byt5_txt + weights.cond_type_embedding.apply(torch.ones_like(byt5_txt[:, :, 0], device=byt5_txt.device, dtype=torch.long))
         txt, text_mask = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask, zero_feat=True)
-        siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=self.device))
+        siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=self.run_device))
         txt, text_mask = self.reorder_txt_token(siglip_output, txt, siglip_mask, text_mask)
         txt = txt[:, : text_mask.sum(), :]
...
@@ -100,7 +100,7 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
         self.config = config
         self.double_blocks_num = config["mm_double_blocks_depth"]
         self.heads_num = config["heads_num"]
-        self.device = torch.device(self.config.get("run_device", "cuda"))
+        self.run_device = torch.device(self.config.get("run_device", "cuda"))
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
             self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False)
@@ -222,7 +222,7 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
         key = torch.cat([img_k, txt_k], dim=1)
         value = torch.cat([img_v, txt_v], dim=1)
         seqlen = query.shape[1]
-        cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(self.device, non_blocking=True)
+        cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(self.run_device, non_blocking=True)
         if self.config["seq_parallel"]:
             attn_out = weights.self_attention_parallel.apply(
...
@@ -9,6 +9,10 @@ from .triton_ops import fuse_scale_shift_kernel
 from .utils import apply_wan_rope_with_chunk, apply_wan_rope_with_flashinfer, apply_wan_rope_with_torch
+def modulate(x, scale, shift):
+    return x * (1 + scale.squeeze()) + shift.squeeze()
 class WanTransformerInfer(BaseTransformerInfer):
     def __init__(self, config):
         self.config = config
@@ -21,6 +25,10 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.head_dim = config["dim"] // config["num_heads"]
         self.window_size = config.get("window_size", (-1, -1))
         self.parallel_attention = None
+        if self.config.get("modulate_type", "triton") == "triton":
+            self.modulate_func = fuse_scale_shift_kernel
+        else:
+            self.modulate_func = modulate
         if self.config.get("rope_type", "flashinfer") == "flashinfer":
             if self.config.get("rope_chunk", False):
                 self.apply_rope_func = partial(apply_wan_rope_with_chunk, chunk_size=self.config.get("rope_chunk_size", 100), rope_func=apply_wan_rope_with_flashinfer)
@@ -146,7 +154,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         norm1_out = phase.norm1.apply(x)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm1_out = norm1_out.to(self.sensitive_layer_dtype)
-        norm1_out = fuse_scale_shift_kernel(norm1_out, scale=scale_msa, shift=shift_msa).squeeze(0)
+        norm1_out = self.modulate_func(norm1_out, scale=scale_msa, shift=shift_msa).squeeze()
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm1_out = norm1_out.to(self.infer_dtype)
@@ -285,7 +293,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         norm2_out = phase.norm2.apply(x)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm2_out = norm2_out.to(self.sensitive_layer_dtype)
-        norm2_out = fuse_scale_shift_kernel(norm2_out, scale=c_scale_msa, shift=c_shift_msa).squeeze(0)
+        norm2_out = self.modulate_func(norm2_out, scale=c_scale_msa, shift=c_shift_msa).squeeze()
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm2_out = norm2_out.to(self.infer_dtype)
...
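
The plain-torch modulate fallback reproduces the fused Triton scale/shift kernel with broadcasting only, which is what lets "modulate_type": "torch" run on backends without Triton support. A small self-contained check of the broadcasting (tensor sizes are illustrative):

import torch

def modulate(x, scale, shift):
    # Same fallback as above: x * (1 + scale) + shift with singleton dims squeezed.
    return x * (1 + scale.squeeze()) + shift.squeeze()

x = torch.randn(1, 8, 16)        # (batch, tokens, dim), sizes chosen for illustration
scale = torch.randn(1, 1, 16)
shift = torch.randn(1, 1, 16)
out = modulate(x, scale, shift)  # scale/shift broadcast over the token axis
assert out.shape == (1, 8, 16)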
 import torch
 import torch.distributed as dist
-from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
+try:
+    from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
+except ImportError:
+    apply_rope_with_cos_sin_cache_inplace = None
 from lightx2v.utils.envs import *
...
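
As with ProcessGroupNCCL above, the flashinfer rope kernel is now an optional import; configs that set "rope_type": "torch" (as in the JSON at the top) never touch it. A hypothetical selection helper showing how the fallback could be chosen automatically:

try:
    from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
except ImportError:
    apply_rope_with_cos_sin_cache_inplace = None

def pick_rope_type(config):
    # Honour the configured rope_type, but drop to the torch implementation
    # when flashinfer is not installed (assumption: "torch" is always available).
    requested = config.get("rope_type", "flashinfer")
    if requested == "flashinfer" and apply_rope_with_cos_sin_cache_inplace is None:
        return "torch"
    return requested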
@@ -71,7 +71,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if qwen25vl_offload:
             qwen25vl_device = torch.device("cpu")
         else:
-            qwen25vl_device = torch.device(self.config.get("run_device", "cuda"))
+            qwen25vl_device = torch.device(self.run_device)
         qwen25vl_quantized = self.config.get("qwen25vl_quantized", False)
         qwen25vl_quant_scheme = self.config.get("qwen25vl_quant_scheme", None)
@@ -82,6 +82,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         text_encoder = Qwen25VL_TextEncoder(
             dtype=torch.float16,
             device=qwen25vl_device,
+            run_device=self.run_device,
             checkpoint_path=text_encoder_path,
             cpu_offload=qwen25vl_offload,
             qwen25vl_quantized=qwen25vl_quantized,
@@ -93,9 +94,9 @@ class HunyuanVideo15Runner(DefaultRunner):
         if byt5_offload:
             byt5_device = torch.device("cpu")
         else:
-            byt5_device = torch.device(self.config.get("run_device", "cuda"))
-        byt5 = ByT5TextEncoder(config=self.config, device=byt5_device, checkpoint_path=self.config["model_path"], cpu_offload=byt5_offload)
+            byt5_device = torch.device(self.run_device)
+        byt5 = ByT5TextEncoder(config=self.config, device=byt5_device, run_device=self.run_device, checkpoint_path=self.config["model_path"], cpu_offload=byt5_offload)
         text_encoders = [text_encoder, byt5]
         return text_encoders
@@ -229,10 +230,11 @@ class HunyuanVideo15Runner(DefaultRunner):
         if siglip_offload:
             siglip_device = torch.device("cpu")
         else:
-            siglip_device = torch.device(self.config.get("run_device", "cuda"))
+            siglip_device = torch.device(self.run_device)
         image_encoder = SiglipVisionEncoder(
             config=self.config,
             device=siglip_device,
+            run_device=self.run_device,
             checkpoint_path=self.config["model_path"],
             cpu_offload=siglip_offload,
         )
@@ -244,7 +246,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.config.get("run_device", "cuda"))
+            vae_device = torch.device(self.run_device)
         vae_config = {
             "checkpoint_path": self.config["model_path"],
@@ -263,7 +265,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.config.get("run_device", "cuda"))
+            vae_device = torch.device(self.run_device)
         vae_config = {
             "checkpoint_path": self.config["model_path"],
@@ -273,7 +275,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         }
         if self.config.get("use_tae", False):
             tae_path = self.config["tae_path"]
-            vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to(self.config.get("run_device", "cuda"))
+            vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to(self.run_device)
         else:
             vae_decoder = self.vae_cls(**vae_config)
         return vae_decoder
@@ -348,7 +350,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         self.model_sr.scheduler.step_post()
         del self.inputs_sr
-        torch_ext_module = getattr(torch, self.config.get("run_device", "cuda"))
+        torch_ext_module = getattr(torch, self.run_device)
         torch_ext_module.empty_cache()
         self.config_sr["is_sr_running"] = False
@@ -367,10 +369,10 @@ class HunyuanVideo15Runner(DefaultRunner):
         text_encoder_output = self.run_text_encoder(self.input_info)
         # vision_states is all zero, because we don't have any image input
-        siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).to(self.config.get("run_device", "cuda"))
-        siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device(self.config.get("run_device", "cuda")))
-        torch_ext_module = getattr(torch, self.config.get("run_device", "cuda"))
+        siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).to(self.run_device)
+        siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device(self.run_device))
+        torch_ext_module = getattr(torch, self.run_device)
         torch_ext_module.empty_cache()
         gc.collect()
         return {
@@ -398,7 +400,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         siglip_output, siglip_mask = self.run_image_encoder(img_ori) if self.config.get("use_image_encoder", True) else None
         cond_latents = self.run_vae_encoder(img_ori)
         text_encoder_output = self.run_text_encoder(self.input_info)
-        torch_ext_module = getattr(torch, self.config.get("run_device", "cuda"))
+        torch_ext_module = getattr(torch, self.run_device)
         torch_ext_module.empty_cache()
         gc.collect()
         return {
@@ -425,9 +427,9 @@ class HunyuanVideo15Runner(DefaultRunner):
         target_height = self.target_height
         input_image_np = self.resize_and_center_crop(first_frame, target_width=target_width, target_height=target_height)
-        vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device(self.config.get("run_device", "cuda")), dtype=torch.bfloat16)
+        vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device(self.run_device), dtype=torch.bfloat16)
         image_encoder_output = self.image_encoder.infer(vision_states)
-        image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device(self.config.get("run_device", "cuda")))
+        image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device(self.run_device))
         return image_encoder_output, image_encoder_mask
     def resize_and_center_crop(self, image, target_width, target_height):
@@ -478,7 +480,6 @@ class HunyuanVideo15Runner(DefaultRunner):
             ]
         )
-        ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).to(self.config.get("run_device", "cuda"))
+        ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).to(self.run_device)
         cond_latents = self.vae_encoder.encode(ref_images_pixel_values.to(GET_DTYPE()))
         return cond_latents
@@ -85,8 +85,8 @@ class QwenImageRunner(DefaultRunner):
     def _run_input_encoder_local_t2i(self):
         prompt = self.input_info.prompt
         text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
-        if hasattr(torch, self.config.get("run_device", "cuda")):
-            torch_module = getattr(torch, self.config.get("run_device", "cuda"))
+        if hasattr(torch, self.run_device):
+            torch_module = getattr(torch, self.run_device)
             torch_module.empty_cache()
         gc.collect()
         return {
@@ -102,7 +102,7 @@ class QwenImageRunner(DefaultRunner):
         if GET_RECORDER_MODE():
             width, height = img_ori.size
             monitor_cli.lightx2v_input_image_len.observe(width * height)
-        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(self.init_device)
+        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(self.run_device)
         self.input_info.original_size.append(img_ori.size)
         return img, img_ori
@@ -121,8 +121,8 @@ class QwenImageRunner(DefaultRunner):
         for vae_image in text_encoder_output["image_info"]["vae_image_list"]:
             image_encoder_output = self.run_vae_encoder(image=vae_image)
             image_encoder_output_list.append(image_encoder_output)
-        if hasattr(torch, self.config.get("run_device", "cuda")):
-            torch_module = getattr(torch, self.config.get("run_device", "cuda"))
+        if hasattr(torch, self.run_device):
+            torch_module = getattr(torch, self.run_device)
             torch_module.empty_cache()
         gc.collect()
         return {
@@ -238,8 +238,8 @@ class QwenImageRunner(DefaultRunner):
         images = self.vae.decode(latents, self.input_info)
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_decoder
-        if hasattr(torch, self.config.get("run_device", "cuda")):
-            torch_module = getattr(torch, self.config.get("run_device", "cuda"))
+        if hasattr(torch, self.run_device):
+            torch_module = getattr(torch, self.run_device)
             torch_module.empty_cache()
         gc.collect()
         return images
@@ -259,8 +259,8 @@ class QwenImageRunner(DefaultRunner):
         image.save(f"{input_info.save_result_path}")
         del latents, generator
-        if hasattr(torch, self.config.get("run_device", "cuda")):
-            torch_module = getattr(torch, self.config.get("run_device", "cuda"))
+        if hasattr(torch, self.run_device):
+            torch_module = getattr(torch, self.run_device)
             torch_module.empty_cache()
         gc.collect()
...
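
The repeated hasattr/getattr(torch, self.run_device).empty_cache() blocks rely on each backend exposing a torch submodule named after its device string (torch.cuda, torch.mlu, ...). A compact sketch of that pattern (the helper name is hypothetical):

import gc

import torch

def empty_device_cache(run_device="cuda"):
    # Clear the caching allocator on whatever backend `run_device` names,
    # if torch exposes a submodule of that name with an empty_cache().
    if hasattr(torch, run_device):
        backend = getattr(torch, run_device)
        if hasattr(backend, "empty_cache"):
            backend.empty_cache()
    gc.collect()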
@@ -835,7 +835,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def load_audio_encoder(self):
         audio_encoder_path = self.config.get("audio_encoder_path", os.path.join(self.config["model_path"], "TencentGameMate-chinese-hubert-large"))
         audio_encoder_offload = self.config.get("audio_encoder_cpu_offload", self.config.get("cpu_offload", False))
-        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, device=self.config.get("run_device", "cuda"))
+        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, run_device=self.config.get("run_device", "cuda"))
         return model
     def load_audio_adapter(self):
@@ -843,7 +843,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         if audio_adapter_offload:
             device = torch.device("cpu")
         else:
-            device = torch.device(self.config.get("run_device", "cuda"))
+            device = torch.device(self.run_device)
         audio_adapter = AudioAdapter(
             attention_head_dim=self.config["dim"] // self.config["num_heads"],
             num_attention_heads=self.config["num_heads"],
@@ -856,7 +856,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             quantized=self.config.get("adapter_quantized", False),
             quant_scheme=self.config.get("adapter_quant_scheme", None),
             cpu_offload=audio_adapter_offload,
-            device=self.config.get("run_device", "cuda"),
+            run_device=self.run_device,
         )
         audio_adapter.to(device)
@@ -892,7 +892,7 @@ class Wan22AudioRunner(WanAudioRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.config.get("run_device", "cuda"))
+            vae_device = torch.device(self.run_device)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
@@ -909,7 +909,7 @@ class Wan22AudioRunner(WanAudioRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.config.get("run_device", "cuda"))
+            vae_device = torch.device(self.run_device)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
...
@@ -65,7 +65,7 @@ class WanRunner(DefaultRunner):
         if clip_offload:
             clip_device = torch.device("cpu")
         else:
-            clip_device = torch.device(self.init_device)
+            clip_device = torch.device(self.run_device)
         # quant_config
         clip_quantized = self.config.get("clip_quantized", False)
         if clip_quantized:
@@ -123,6 +123,7 @@ class WanRunner(DefaultRunner):
             text_len=self.config["text_len"],
             dtype=torch.bfloat16,
             device=t5_device,
+            run_device=self.run_device,
             checkpoint_path=t5_original_ckpt,
             tokenizer_path=tokenizer_path,
             shard_fn=None,
@@ -141,7 +142,7 @@ class WanRunner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.init_device)
+            vae_device = torch.device(self.run_device)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
@@ -320,7 +321,7 @@ class WanRunner(DefaultRunner):
             self.config["target_video_length"],
             lat_h,
             lat_w,
-            device=torch.device(self.config.get("run_device", "cuda")),
+            device=torch.device(self.run_device),
         )
         if last_frame is not None:
             msk[:, 1:-1] = 0
@@ -342,7 +343,7 @@ class WanRunner(DefaultRunner):
                     torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                 ],
                 dim=1,
-            ).to(self.init_device)
+            ).to(self.run_device)
         else:
             vae_input = torch.concat(
                 [
@@ -350,7 +351,7 @@ class WanRunner(DefaultRunner):
                     torch.zeros(3, self.config["target_video_length"] - 1, h, w),
                 ],
                 dim=1,
-            ).to(self.init_device)
+            ).to(self.run_device)
         vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE()))
@@ -533,7 +534,7 @@ class Wan22DenseRunner(WanRunner):
         assert img.width == ow and img.height == oh
         # to tensor
-        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.init_device).unsqueeze(1)
+        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.run_device).unsqueeze(1)
         vae_encoder_out = self.get_vae_encoder_output(img)
         latent_w, latent_h = ow // self.config["vae_stride"][2], oh // self.config["vae_stride"][1]
         latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)
...
@@ -271,8 +271,8 @@ def get_1d_rotary_pos_embed(
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
     # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
-    device = kwds.get("device", "cuda")
-    freqs = torch.outer(pos * interpolation_factor, freqs).to(device)  # [S, D/2]
+    run_device = kwds.get("run_device", "cuda")
+    freqs = torch.outer(pos * interpolation_factor, freqs).to(run_device)  # [S, D/2]
     if use_real:
         freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
         freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
...
@@ -11,7 +11,7 @@ from .posemb_layers import get_nd_rotary_pos_embed
 class HunyuanVideo15Scheduler(BaseScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.device = torch.device(self.config.get("run_device", "cuda"))
+        self.run_device = torch.device(self.config.get("run_device", "cuda"))
         self.reverse = True
         self.num_train_timesteps = 1000
         self.sample_shift = self.config["sample_shift"]
@@ -25,13 +25,13 @@ class HunyuanVideo15Scheduler(BaseScheduler):
     def prepare(self, seed, latent_shape, image_encoder_output=None):
         self.prepare_latents(seed, latent_shape, dtype=torch.bfloat16)
-        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
+        self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)
         self.multitask_mask = self.get_task_mask(self.config["task"], latent_shape[-3])
         self.cond_latents_concat, self.mask_concat = self._prepare_cond_latents_and_mask(self.config["task"], image_encoder_output["cond_latents"], self.latents, self.multitask_mask, self.reorg_token)
         self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3]))
     def prepare_latents(self, seed, latent_shape, dtype=torch.bfloat16):
-        self.generator = torch.Generator(device=self.device).manual_seed(seed)
+        self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
         self.latents = torch.randn(
             1,
             latent_shape[0],
@@ -39,7 +39,7 @@ class HunyuanVideo15Scheduler(BaseScheduler):
             latent_shape[2],
             latent_shape[3],
             dtype=dtype,
-            device=self.device,
+            device=self.run_device,
             generator=self.generator,
         )
@@ -127,7 +127,7 @@ class HunyuanVideo15Scheduler(BaseScheduler):
         if rope_dim_list is None:
             rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
         assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
-        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, rope_sizes, theta=self.config["rope_theta"], use_real=True, theta_rescale_factor=1, device=self.device)
+        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, rope_sizes, theta=self.config["rope_theta"], use_real=True, theta_rescale_factor=1, device=self.run_device)
         cos_half = freqs_cos[:, ::2].contiguous()
         sin_half = freqs_sin[:, ::2].contiguous()
         cos_sin = torch.cat([cos_half, sin_half], dim=-1)
@@ -151,7 +151,7 @@ class HunyuanVideo15SRScheduler(HunyuanVideo15Scheduler):
         dtype = lq_latents.dtype
         device = lq_latents.device
         self.prepare_latents(seed, latent_shape, lq_latents, dtype=dtype)
-        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
+        self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)
         self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3]))
         tgt_shape = latent_shape[-2:]
...