Unverified Commit c0863477 authored by Kane, committed by GitHub

Mlu590 (#520)



1. Resolved the earlier merge conflicts in the code; tests pass.
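The recurring change in this diff replaces hard-coded `torch.cuda` / `"cuda"` usages with the `run_device` value from the config, so the same code paths can run on MLU (and NPU) backends. A minimal sketch of the cache-clearing pattern that the diff inlines at each call site (the helper name below is illustrative, and it assumes the backend extension such as torch_mlu or torch_npu registers a `torch.<run_device>` module exposing `empty_cache()`):

```python
import torch

def empty_device_cache(run_device: str = "cuda") -> None:
    # run_device comes from config.get("run_device", "cuda"), e.g. "cuda", "mlu" or "npu".
    # hasattr guards against the backend extension not being installed.
    if hasattr(torch, run_device):
        getattr(torch, run_device).empty_cache()
```

The same device string is also threaded into `torch.device(...)`, `init_device_mesh(...)` and, via `**kwds`, the RoPE helpers in posemb_layers, so tensors are created on the configured device instead of being pinned to CUDA.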

---------
Co-authored-by: Yang Yong (雍洋) <yongyang1030@163.com>
parent 47b3ce2f
@@ -111,7 +111,7 @@ def apply_attn(block_weight, hidden_states, encoder_hidden_states, image_rotary_
     if attn_type == "torch_sdpa":
         joint_hidden_states = block_weight.attn.calculate.apply(q=joint_query, k=joint_key, v=joint_value)
-    elif attn_type in ["flash_attn3", "sage_attn2", "mlu_flash_attn"]:
+    elif attn_type in ["flash_attn3", "sage_attn2", "mlu_flash_attn", "flash_attn2", "mlu_sage_attn"]:
         joint_query = joint_query.squeeze(0)
         joint_key = joint_key.squeeze(0)
         joint_value = joint_value.squeeze(0)
...
@@ -44,7 +44,7 @@ class WanModel(CompiledMethodsMixin):
         super().__init__()
         self.model_path = model_path
         self.config = config
-        self.run_device = self.config.get("run_device", "cuda")
+        self.device = self.config.get("run_device", "cuda")
         self.cpu_offload = self.config.get("cpu_offload", False)
         self.offload_granularity = self.config.get("offload_granularity", "block")
         self.model_type = model_type
...
@@ -63,7 +63,7 @@ class DefaultRunner(BaseRunner):
         if self.config["cpu_offload"]:
             self.init_device = torch.device("cpu")
         else:
-            self.init_device = torch.device(self.run_device)
+            self.init_device = torch.device(self.config.get("run_device", "cuda"))

     def load_vfi_model(self):
         if self.config["video_frame_interpolation"].get("algo", None) == "rife":
@@ -205,7 +205,7 @@ class DefaultRunner(BaseRunner):
         if GET_RECORDER_MODE():
             width, height = img_ori.size
             monitor_cli.lightx2v_input_image_len.observe(width * height)
-        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
+        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(self.init_device)
         self.input_info.original_size = img_ori.size
         return img, img_ori
...
@@ -71,7 +71,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if qwen25vl_offload:
             qwen25vl_device = torch.device("cpu")
         else:
-            qwen25vl_device = torch.device("cuda")
+            qwen25vl_device = torch.device(self.config.get("run_device", "cuda"))
         qwen25vl_quantized = self.config.get("qwen25vl_quantized", False)
         qwen25vl_quant_scheme = self.config.get("qwen25vl_quant_scheme", None)
@@ -93,7 +93,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if byt5_offload:
             byt5_device = torch.device("cpu")
         else:
-            byt5_device = torch.device("cuda")
+            byt5_device = torch.device(self.config.get("run_device", "cuda"))
         byt5 = ByT5TextEncoder(config=self.config, device=byt5_device, checkpoint_path=self.config["model_path"], cpu_offload=byt5_offload)
         text_encoders = [text_encoder, byt5]
@@ -229,7 +229,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if siglip_offload:
             siglip_device = torch.device("cpu")
         else:
-            siglip_device = torch.device("cuda")
+            siglip_device = torch.device(self.config.get("run_device", "cuda"))
         image_encoder = SiglipVisionEncoder(
             config=self.config,
             device=siglip_device,
@@ -244,7 +244,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device("cuda")
+            vae_device = torch.device(self.config.get("run_device", "cuda"))
         vae_config = {
             "checkpoint_path": self.config["model_path"],
@@ -263,7 +263,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device("cuda")
+            vae_device = torch.device(self.config.get("run_device", "cuda"))
         vae_config = {
             "checkpoint_path": self.config["model_path"],
@@ -273,7 +273,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         }
         if self.config.get("use_tae", False):
             tae_path = self.config["tae_path"]
-            vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to("cuda")
+            vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to(self.config.get("run_device", "cuda"))
         else:
             vae_decoder = self.vae_cls(**vae_config)
         return vae_decoder
@@ -348,7 +348,8 @@ class HunyuanVideo15Runner(DefaultRunner):
         self.model_sr.scheduler.step_post()
         del self.inputs_sr
-        torch.cuda.empty_cache()
+        torch_ext_module = getattr(torch, self.config.get("run_device", "cuda"))
+        torch_ext_module.empty_cache()
         self.config_sr["is_sr_running"] = False
         return self.model_sr.scheduler.latents
@@ -366,10 +367,11 @@ class HunyuanVideo15Runner(DefaultRunner):
         text_encoder_output = self.run_text_encoder(self.input_info)
         # vision_states is all zero, because we don't have any image input
-        siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).cuda()
-        siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device("cuda"))
-        torch.cuda.empty_cache()
+        siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).to(self.config.get("run_device", "cuda"))
+        siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device(self.config.get("run_device", "cuda")))
+        torch_ext_module = getattr(torch, self.config.get("run_device", "cuda"))
+        torch_ext_module.empty_cache()
         gc.collect()
         return {
             "text_encoder_output": text_encoder_output,
@@ -396,7 +398,8 @@ class HunyuanVideo15Runner(DefaultRunner):
         siglip_output, siglip_mask = self.run_image_encoder(img_ori) if self.config.get("use_image_encoder", True) else None
         cond_latents = self.run_vae_encoder(img_ori)
         text_encoder_output = self.run_text_encoder(self.input_info)
-        torch.cuda.empty_cache()
+        torch_ext_module = getattr(torch, self.config.get("run_device", "cuda"))
+        torch_ext_module.empty_cache()
         gc.collect()
         return {
             "text_encoder_output": text_encoder_output,
@@ -422,9 +425,9 @@ class HunyuanVideo15Runner(DefaultRunner):
         target_height = self.target_height
         input_image_np = self.resize_and_center_crop(first_frame, target_width=target_width, target_height=target_height)
-        vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device("cuda"), dtype=torch.bfloat16)
+        vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device(self.config.get("run_device", "cuda")), dtype=torch.bfloat16)
         image_encoder_output = self.image_encoder.infer(vision_states)
-        image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device("cuda"))
+        image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device(self.config.get("run_device", "cuda")))
         return image_encoder_output, image_encoder_mask

     def resize_and_center_crop(self, image, target_width, target_height):
@@ -475,7 +478,7 @@ class HunyuanVideo15Runner(DefaultRunner):
             ]
         )
-        ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).cuda()
+        ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).to(self.config.get("run_device", "cuda"))
         cond_latents = self.vae_encoder.encode(ref_images_pixel_values.to(GET_DTYPE()))
         return cond_latents
@@ -85,7 +85,9 @@ class QwenImageRunner(DefaultRunner):
     def _run_input_encoder_local_t2i(self):
         prompt = self.input_info.prompt
         text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
-        torch.cuda.empty_cache()
+        if hasattr(torch, self.config.get("run_device", "cuda")):
+            torch_module = getattr(torch, self.config.get("run_device", "cuda"))
+            torch_module.empty_cache()
         gc.collect()
         return {
             "text_encoder_output": text_encoder_output,
@@ -100,7 +102,7 @@ class QwenImageRunner(DefaultRunner):
         if GET_RECORDER_MODE():
             width, height = img_ori.size
             monitor_cli.lightx2v_input_image_len.observe(width * height)
-        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
+        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(self.init_device)
         self.input_info.original_size.append(img_ori.size)
         return img, img_ori
@@ -119,8 +121,9 @@ class QwenImageRunner(DefaultRunner):
         for vae_image in text_encoder_output["image_info"]["vae_image_list"]:
             image_encoder_output = self.run_vae_encoder(image=vae_image)
             image_encoder_output_list.append(image_encoder_output)
-        torch.cuda.empty_cache()
+        if hasattr(torch, self.config.get("run_device", "cuda")):
+            torch_module = getattr(torch, self.config.get("run_device", "cuda"))
+            torch_module.empty_cache()
         gc.collect()
         return {
             "text_encoder_output": text_encoder_output,
@@ -235,7 +238,9 @@ class QwenImageRunner(DefaultRunner):
         images = self.vae.decode(latents, self.input_info)
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_decoder
-            torch.cuda.empty_cache()
+            if hasattr(torch, self.config.get("run_device", "cuda")):
+                torch_module = getattr(torch, self.config.get("run_device", "cuda"))
+                torch_module.empty_cache()
             gc.collect()
         return images
@@ -254,7 +259,9 @@ class QwenImageRunner(DefaultRunner):
         image.save(f"{input_info.save_result_path}")
         del latents, generator
-        torch.cuda.empty_cache()
+        if hasattr(torch, self.config.get("run_device", "cuda")):
+            torch_module = getattr(torch, self.config.get("run_device", "cuda"))
+            torch_module.empty_cache()
         gc.collect()
         # Return (images, audio) - audio is None for default runner
...
@@ -835,7 +835,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def load_audio_encoder(self):
         audio_encoder_path = self.config.get("audio_encoder_path", os.path.join(self.config["model_path"], "TencentGameMate-chinese-hubert-large"))
         audio_encoder_offload = self.config.get("audio_encoder_cpu_offload", self.config.get("cpu_offload", False))
-        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, run_device=self.config.get("run_device", "cuda"))
+        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, device=self.config.get("run_device", "cuda"))
         return model

     def load_audio_adapter(self):
@@ -843,8 +843,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         if audio_adapter_offload:
             device = torch.device("cpu")
         else:
-            device = torch.device(self.run_device)
+            device = torch.device(self.config.get("run_device", "cuda"))
         audio_adapter = AudioAdapter(
             attention_head_dim=self.config["dim"] // self.config["num_heads"],
             num_attention_heads=self.config["num_heads"],
@@ -857,7 +856,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             quantized=self.config.get("adapter_quantized", False),
             quant_scheme=self.config.get("adapter_quant_scheme", None),
             cpu_offload=audio_adapter_offload,
-            run_device=self.run_device,
+            device=self.config.get("run_device", "cuda"),
         )
         audio_adapter.to(device)
@@ -893,7 +892,7 @@ class Wan22AudioRunner(WanAudioRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device("cuda")
+            vae_device = torch.device(self.config.get("run_device", "cuda"))
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
@@ -910,7 +909,7 @@ class Wan22AudioRunner(WanAudioRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device("cuda")
+            vae_device = torch.device(self.config.get("run_device", "cuda"))
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
...
@@ -65,7 +65,7 @@ class WanRunner(DefaultRunner):
         if clip_offload:
             clip_device = torch.device("cpu")
         else:
-            clip_device = torch.device(self.run_device)
+            clip_device = torch.device(self.init_device)
         # quant_config
         clip_quantized = self.config.get("clip_quantized", False)
         if clip_quantized:
@@ -122,7 +122,6 @@ class WanRunner(DefaultRunner):
         text_encoder = T5EncoderModel(
             text_len=self.config["text_len"],
             dtype=torch.bfloat16,
-            run_device=self.run_device,
             device=t5_device,
             checkpoint_path=t5_original_ckpt,
             tokenizer_path=tokenizer_path,
@@ -142,7 +141,7 @@ class WanRunner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.run_device)
+            vae_device = torch.device(self.init_device)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
@@ -321,7 +320,7 @@ class WanRunner(DefaultRunner):
             self.config["target_video_length"],
             lat_h,
             lat_w,
-            device=torch.device("cuda"),
+            device=torch.device(self.config.get("run_device", "cuda")),
         )
         if last_frame is not None:
             msk[:, 1:-1] = 0
@@ -343,7 +342,7 @@ class WanRunner(DefaultRunner):
                     torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                 ],
                 dim=1,
-            ).cuda()
+            ).to(self.init_device)
         else:
             vae_input = torch.concat(
                 [
@@ -351,7 +350,7 @@ class WanRunner(DefaultRunner):
                     torch.zeros(3, self.config["target_video_length"] - 1, h, w),
                 ],
                 dim=1,
-            ).cuda()
+            ).to(self.init_device)
         vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE()))
@@ -534,7 +533,7 @@ class Wan22DenseRunner(WanRunner):
         assert img.width == ow and img.height == oh
         # to tensor
-        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda().unsqueeze(1)
+        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.init_device).unsqueeze(1)
         vae_encoder_out = self.get_vae_encoder_output(img)
         latent_w, latent_h = ow // self.config["vae_stride"][2], oh // self.config["vae_stride"][1]
         latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)
...
@@ -174,6 +174,7 @@ def get_nd_rotary_pos_embed(
     use_real=False,
     theta_rescale_factor: Union[float, List[float]] = 1.0,
     interpolation_factor: Union[float, List[float]] = 1.0,
+    **kwds,
 ):
     """
     This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
@@ -218,6 +219,7 @@ def get_nd_rotary_pos_embed(
             use_real=use_real,
             theta_rescale_factor=theta_rescale_factor[i],
             interpolation_factor=interpolation_factor[i],
+            **kwds,
         )  # 2 x [WHD, rope_dim_list[i]]
         embs.append(emb)
@@ -237,6 +239,7 @@ def get_1d_rotary_pos_embed(
     use_real: bool = False,
     theta_rescale_factor: float = 1.0,
     interpolation_factor: float = 1.0,
+    **kwds,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     Precompute the frequency tensor for complex exponential (cis) with given dimensions.
@@ -268,7 +271,8 @@ def get_1d_rotary_pos_embed(
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
     # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
-    freqs = torch.outer(pos * interpolation_factor, freqs).cuda()  # [S, D/2]
+    device = kwds.get("device", "cuda")
+    freqs = torch.outer(pos * interpolation_factor, freqs).to(device)  # [S, D/2]
     if use_real:
         freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
         freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
...
@@ -11,7 +11,7 @@ from .posemb_layers import get_nd_rotary_pos_embed
 class HunyuanVideo15Scheduler(BaseScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.device = torch.device("cuda")
+        self.device = torch.device(self.config.get("run_device", "cuda"))
         self.reverse = True
         self.num_train_timesteps = 1000
         self.sample_shift = self.config["sample_shift"]
@@ -127,13 +127,7 @@ class HunyuanVideo15Scheduler(BaseScheduler):
         if rope_dim_list is None:
             rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
         assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
-        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
-            rope_dim_list,
-            rope_sizes,
-            theta=self.config["rope_theta"],
-            use_real=True,
-            theta_rescale_factor=1,
-        )
+        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, rope_sizes, theta=self.config["rope_theta"], use_real=True, theta_rescale_factor=1, device=self.device)
         cos_half = freqs_cos[:, ::2].contiguous()
         sin_half = freqs_sin[:, ::2].contiguous()
         cos_sin = torch.cat([cos_half, sin_half], dim=-1)
...
@@ -184,7 +184,7 @@ def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_
         torch.Tensor: Causal attention mask.
     """
     seq_len = n_frame * n_hw
-    mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device="cuda")
+    mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
     for i in range(seq_len):
         i_frame = i // n_hw
         mask[i, : (i_frame + 1) * n_hw] = 0
...
@@ -1330,7 +1330,9 @@ class WanVAE:
     def device_synchronize(
         self,
     ):
-        if "cuda" in str(self.run_device):
+        if "cuda" in str(self.device):
             torch.cuda.synchronize()
-        elif "mlu" in str(self.run_device):
+        elif "mlu" in str(self.device):
             torch.mlu.synchronize()
+        elif "npu" in str(self.device):
+            torch.npu.synchronize()
@@ -85,6 +85,8 @@ class _ProfilingContext:
             torch.cuda.synchronize()
         elif hasattr(torch, "mlu") and torch.mlu.is_available():
             torch.mlu.synchronize()
+        elif hasattr(torch, "npu") and torch.npu.is_available():
+            torch.npu.synchronize()
         return
...
@@ -92,7 +92,8 @@ def set_parallel_config(config):
     cfg_p_size = config["parallel"].get("cfg_p_size", 1)
     seq_p_size = config["parallel"].get("seq_p_size", 1)
     assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size must be equal to world_size"
-    config["device_mesh"] = init_device_mesh("cuda", (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))
+    device_str = config.get("run_device", "cuda")
+    config["device_mesh"] = init_device_mesh(device_str, (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))
     if config["parallel"] and config["parallel"].get("seq_p_size", False) and config["parallel"]["seq_p_size"] > 1:
         config["seq_parallel"] = True
@@ -100,7 +101,7 @@ def set_parallel_config(config):
     if config.get("enable_cfg", False) and config["parallel"] and config["parallel"].get("cfg_p_size", False) and config["parallel"]["cfg_p_size"] > 1:
         config["cfg_parallel"] = True
     # warmup dist
-    _a = torch.zeros([1]).to(f"cuda:{dist.get_rank()}")
+    _a = torch.zeros([1]).to(f"{device_str}:{dist.get_rank()}")
     dist.all_reduce(_a)
...