Unverified commit c0863477 authored by Kane, committed by GitHub

Mlu590 (#520)



1. Fixed the earlier code merge conflicts; tests pass.
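The changes below share one pattern: weight loaders accept any accelerator device via `device.type in ["cuda", "mlu", "npu"]`, hard-coded `"cuda"` literals become a configurable `self.device`, and backend-specific maintenance calls (`synchronize()`, `empty_cache()`) are dispatched through the matching `torch.<backend>` module. A minimal sketch of that dispatch, assuming the Cambricon (`torch.mlu`) and Ascend (`torch.npu`) plugins are installed when those backends are used; the helper names below are illustrative, not functions defined in this repository:

```python
import torch


def detect_run_device() -> str:
    """Illustrative: pick the backend name that run_device keys off of."""
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "mlu") and torch.mlu.is_available():  # Cambricon plugin (assumed installed)
        return "mlu"
    if hasattr(torch, "npu") and torch.npu.is_available():  # Ascend plugin (assumed installed)
        return "npu"
    return "cpu"


def empty_device_cache(run_device: str) -> None:
    """Illustrative: release cached allocator blocks on whichever backend is active."""
    backend = getattr(torch, run_device, None)
    if backend is not None and hasattr(backend, "empty_cache"):
        backend.empty_cache()  # torch.cuda / torch.mlu / torch.npu all expose empty_cache()
```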

---------
Co-authored-by: Yang Yong (雍洋) <yongyang1030@163.com>
parent 47b3ce2f
@@ -117,6 +117,9 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         elif hasattr(torch, "mlu") and torch.mlu.is_available():
             torch.mlu.synchronize()
             self.config["run_device"] = "mlu"
+        elif hasattr(torch, "npu") and torch.npu.is_available():
+            torch.npu.synchronize()
+            self.config["run_device"] = "npu"
 @ATTN_WEIGHT_REGISTER("ulysses-4090")
......
@@ -35,7 +35,7 @@ class Conv3dWeight(Conv3dWeightTemplate):
     def load(self, weight_dict):
         device = weight_dict[self.weight_name].device
-        if device.type == "cuda":
+        if device.type in ["cuda", "mlu", "npu"]:
             self.weight = weight_dict[self.weight_name]
             if self.bias_name is not None:
                 self.bias = weight_dict[self.bias_name]
......
@@ -22,7 +22,7 @@ class EmbeddingWeightTemplate(metaclass=ABCMeta):
             self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type == "cuda":
+            if device.type in ["cuda", "mlu", "npu"]:
                 self.weight = weight_dict[self.weight_name]
             elif device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
......
@@ -296,7 +296,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
         else:
             device = weight_dict[self.bias_name].device
-            if device.type == "cuda":
+            if device.type in ["cuda", "mlu", "npu"]:
                 self.bias = weight_dict[self.bias_name]
             elif device.type == "cpu":
                 bias_shape = weight_dict[self.bias_name].shape
@@ -362,7 +362,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
             device = weight_dict[self.weight_name].device
-            if device.type == "cuda":
+            if device.type in ["cuda", "mlu", "npu"]:
                 self.weight = weight_dict[self.weight_name]
                 self.weight_scale = weight_dict[self.weight_scale_name]
             elif device.type == "cpu":
@@ -387,7 +387,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
             device = weight_dict[self.weight_name].device
-            if device.type == "cuda":
+            if device.type in ["cuda", "mlu", "npu"]:
                 self.weight = weight_dict[self.weight_name]
                 self.weight_scale = weight_dict[self.weight_scale_name]
             elif device.type == "cpu":
@@ -412,7 +412,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             weight_global_scale = weight_dict[f"{self.weight_name}_global_scale"]
             alpha = 1.0 / (input_global_scale * weight_global_scale)
-            if device.type == "cuda":
+            if device.type in ["cuda", "mlu", "npu"]:
                 self.weight = weight_dict[self.weight_name]
                 self.weight_scale = weight_dict[self.weight_scale_name]
                 self.input_global_scale = input_global_scale
@@ -1172,8 +1172,8 @@ class MMWeightWint8channelAint8channeldynamicMlu(MMWeightQuantTemplate):
     Kernel: mlu
     """
-    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
-        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.load_func = self.load_int8_perchannel_sym
         self.weight_need_transpose = False
         self.act_quant_func = self.act_quant_int8_perchannel_sym_tmo
......
@@ -32,7 +32,7 @@ class LNWeightTemplate(metaclass=ABCMeta):
         else:
             if self.weight_name is not None:
                 device = weight_dict[self.weight_name].device
-                if device.type == "cuda":
+                if device.type in ["cuda", "mlu", "npu"]:
                     self.weight = weight_dict[self.weight_name]
                     if self.bias_name is not None:
                         self.bias = weight_dict[self.bias_name]
......
@@ -337,6 +337,8 @@ def maybe_contiguous(x):
 def triton_autotune_configs():
+    if not torch.cuda.is_available():
+        return []
     # Return configs with a valid warp count for the current device
     configs = []
     # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
......
@@ -29,7 +29,7 @@ class DefaultTensor:
             self.tensor_cuda_buffer = weight_dict[self.tensor_name].cuda()
         else:
             device = weight_dict[self.tensor_name].device
-            if device.type == "cuda":
+            if device.type in ["cuda", "mlu", "npu"]:
                 self.tensor = weight_dict[self.tensor_name]
             elif device.type == "cpu":
                 tensor_shape = weight_dict[self.tensor_name].shape
......
@@ -158,7 +158,7 @@ class ByT5TextEncoder:
     def __init__(
         self,
         config,
-        device=torch.cuda.current_device(),
+        device=torch.device("cpu"),
         checkpoint_path=None,
         byt5_max_length=256,
         cpu_offload=False,
@@ -277,8 +277,8 @@ class ByT5TextEncoder:
         formatted_text = self.prompt_format.format_prompt(glyph_texts, text_styles)
         text_ids, text_mask = self.get_byt5_text_tokens(self.byt5_tokenizer, self.byt5_max_length, formatted_text)
-        text_ids = text_ids.to("cuda")
-        text_mask = text_mask.to("cuda")
+        text_ids = text_ids.to(device)
+        text_mask = text_mask.to(device)
         byt5_outputs = self.byt5_model(text_ids, attention_mask=text_mask.float())
         byt5_embeddings = byt5_outputs[0]
@@ -300,12 +300,12 @@ class ByT5TextEncoder:
         negative_masks = []
         for prompt in prompt_list:
-            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, "cuda")
+            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, self.device)
             positive_embeddings.append(pos_emb)
             positive_masks.append(pos_mask)
             if self.enable_cfg:  # TODO: split cfg out; better suited to parallelism
-                neg_emb, neg_mask = self._process_single_byt5_prompt("", "cuda")
+                neg_emb, neg_mask = self._process_single_byt5_prompt("", self.device)
                 negative_embeddings.append(neg_emb)
                 negative_masks.append(neg_mask)
@@ -327,8 +327,8 @@ class ByT5TextEncoder:
     @torch.no_grad()
     def infer(self, prompts):
         if self.cpu_offload:
-            self.byt5_model = self.byt5_model.to("cuda")
-            self.byt5_mapper = self.byt5_mapper.to("cuda")
+            self.byt5_model = self.byt5_model.to(self.device)
+            self.byt5_mapper = self.byt5_mapper.to(self.device)
         byt5_embeddings, byt5_masks = self._prepare_byt5_embeddings(prompts)
         byt5_features = self.byt5_mapper(byt5_embeddings.to(torch.bfloat16))
         if self.cpu_offload:
......
@@ -144,7 +144,13 @@ def load_text_encoder(
            continue
        new_w_dict[key.replace("model.", "")] = weight_dict[key]
    del weight_dict
-    torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    elif "mlu" in str(device):
+        torch.mlu.empty_cache()
+    elif "npu" in str(device):
+        torch.npu.empty_cache()
    gc.collect()
    text_encoder.load_state_dict(new_w_dict, assign=True)
@@ -545,7 +551,7 @@ class Qwen25VL_TextEncoder:
        self,
        text_len=1000,
        dtype=torch.float16,
-        device=torch.cuda.current_device(),
+        device=torch.device("cpu"),
        checkpoint_path=None,
        cpu_offload=False,
        qwen25vl_quantized=False,
@@ -583,20 +589,20 @@ class Qwen25VL_TextEncoder:
    def infer(self, texts):
        if self.cpu_offload:
-            self.text_encoder = self.text_encoder.to("cuda")
+            self.text_encoder = self.text_encoder.to(self.device)
        text_inputs = self.text_encoder.text2tokens(texts, data_type="video", max_length=self.text_len)
-        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device="cuda")
+        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=self.device)
        if self.cpu_offload:
            self.text_encoder = self.text_encoder.to("cpu")
        prompt_embeds = prompt_outputs.hidden_state
        attention_mask = prompt_outputs.attention_mask
        if attention_mask is not None:
-            attention_mask = attention_mask.cuda()
+            attention_mask = attention_mask.to(self.device)
            _, seq_len = attention_mask.shape
            attention_mask = attention_mask.repeat(1, self.num_videos_per_prompt)
            attention_mask = attention_mask.view(self.num_videos_per_prompt, seq_len)
-        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device="cuda")
+        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
        seq_len = prompt_embeds.shape[1]
        # duplicate text embeddings for each generation per prompt, using mps friendly method
......
@@ -175,7 +175,7 @@ class VisionEncoder(nn.Module):
        if isinstance(images, np.ndarray):
            # Preprocess images if they're numpy arrays
-            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device="cuda", dtype=self.model.dtype)
+            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=self.device, dtype=self.model.dtype)
        else:
            # Assume already preprocessed
            preprocessed = images
@@ -229,7 +229,7 @@ class SiglipVisionEncoder:
    def __init__(
        self,
        config,
-        device=torch.cuda.current_device(),
+        device=torch.device("cpu"),
        checkpoint_path=None,
        cpu_offload=False,
    ):
......
@@ -62,7 +62,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
        if self.cpu_offload:
            self.device = torch.device("cpu")
        else:
-            self.device = torch.device(self.run_device)
+            self.device = torch.device(self.config.get("run_device", "cuda"))
        self.dtype = torch.bfloat16
        self.load()
@@ -95,7 +95,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
    @torch.no_grad()
    def infer(self, text, image_list=None):
        if self.cpu_offload:
-            self.text_encoder.to(self.run_device)
+            self.text_encoder.to(self.device)
        if image_list is not None:
            condition_image_list = []
@@ -130,7 +130,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
                images=condition_image_list,
                padding=True,
                return_tensors="pt",
-            ).to(torch.device(self.run_device))
+            ).to(torch.device(self.device))
            encoder_hidden_states = self.text_encoder(
                input_ids=model_inputs.input_ids,
@@ -153,7 +153,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
            txt = [template.format(e) for e in text]
            image_info = {}
-            model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(torch.device(self.run_device))
+            model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(self.device)
            encoder_hidden_states = self.text_encoder(
                input_ids=model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
@@ -169,7 +169,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
        prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
        encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
-        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.run_device)
+        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
        prompt_embeds_mask = encoder_attention_mask
        _, seq_len, _ = prompt_embeds.shape
@@ -180,12 +180,9 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
        if self.cpu_offload:
            self.text_encoder.to(torch.device("cpu"))
-            if "mlu" in str(self.device):
-                torch.mlu.empty_cache()
-            elif "cuda" in str(self.device):
-                torch.cuda.empty_cache()
-            elif "npu" in str(self.device):
-                torch.npu.empty_cache()
+            if hasattr(torch, self.config.get("run_device", "cuda")):
+                torch_module = getattr(torch, self.config.get("run_device", "cuda"))
+                torch_module.empty_cache()
            gc.collect()
        return prompt_embeds, prompt_embeds_mask, image_info
@@ -252,7 +252,7 @@ class AudioAdapter(nn.Module):
        quantized: bool = False,
        quant_scheme: str = None,
        cpu_offload: bool = False,
-        run_device=torch.device("cuda"),
+        device=torch.device("cpu"),
    ):
        super().__init__()
        self.cpu_offload = cpu_offload
@@ -263,7 +263,7 @@ class AudioAdapter(nn.Module):
            mlp_dims=mlp_dims,
            transformer_layers=projection_transformer_layers,
        )
-        self.run_device = run_device
+        self.device = torch.device(device)
        # self.num_tokens = num_tokens * 4
        self.num_tokens_x4 = num_tokens * 4
        self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
@@ -302,10 +302,10 @@ class AudioAdapter(nn.Module):
    @torch.no_grad()
    def forward_audio_proj(self, audio_feat, latent_frame):
        if self.cpu_offload:
-            self.audio_proj.to(self.run_device)
+            self.audio_proj.to(self.device)
        x = self.audio_proj(audio_feat, latent_frame)
        x = self.rearange_audio_features(x)
-        x = x + self.audio_pe.to(self.run_device)
+        x = x + self.audio_pe.to(self.device)
        if self.cpu_offload:
            self.audio_proj.to("cpu")
        return x
@@ -5,15 +5,14 @@ from lightx2v.utils.envs import *
 class SekoAudioEncoderModel:
-    def __init__(self, model_path, audio_sr, cpu_offload, run_device):
+    def __init__(self, model_path, audio_sr, cpu_offload, device):
        self.model_path = model_path
        self.audio_sr = audio_sr
        self.cpu_offload = cpu_offload
        if self.cpu_offload:
            self.device = torch.device("cpu")
        else:
-            self.device = torch.device(run_device)
-        self.run_device = run_device
+            self.device = torch.device(device)
        self.load()
    def load(self):
@@ -27,13 +26,13 @@ class SekoAudioEncoderModel:
            self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
    def to_cuda(self):
-        self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)
+        self.audio_feature_encoder = self.audio_feature_encoder.to(self.device)
    @torch.no_grad()
    def infer(self, audio_segment):
-        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.run_device).to(dtype=GET_DTYPE())
+        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.device).to(dtype=GET_DTYPE())
        if self.cpu_offload:
-            self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)
+            self.audio_feature_encoder = self.audio_feature_encoder.to(self.device)
        audio_feat = self.audio_feature_encoder(audio_feat, return_dict=True).last_hidden_state
        if self.cpu_offload:
            self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
......
@@ -744,8 +744,7 @@ class T5EncoderModel:
        self,
        text_len,
        dtype=torch.bfloat16,
-        device=torch.device("cuda"),
-        run_device=torch.device("cuda"),
+        device=torch.device("cpu"),
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
@@ -758,7 +757,6 @@ class T5EncoderModel:
        self.text_len = text_len
        self.dtype = dtype
        self.device = device
-        self.run_device = run_device
        if t5_quantized_ckpt is not None and t5_quantized:
            self.checkpoint_path = t5_quantized_ckpt
        else:
@@ -807,8 +805,8 @@ class T5EncoderModel:
    def infer(self, texts):
        ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
-        ids = ids.to(self.run_device)
-        mask = mask.to(self.run_device)
+        ids = ids.to(self.device)
+        mask = mask.to(self.device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        with torch.no_grad():
......
@@ -292,7 +292,7 @@ class VisionTransformer(nn.Module):
        b = x.size(0)
        # embeddings
-        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
+        x = self.patch_embedding(x.type(self.patch_embedding.weight.type())).flatten(2).permute(0, 2, 1)
        if self.pool_type in ("token", "token_fc"):
            x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
        if interpolation:
......
@@ -10,7 +10,7 @@ except ImportError:
    flash_attn_varlen_func_v3 = None
    logger.info("flash_attn_varlen_func_v3 not available")
-if torch.cuda.get_device_capability(0) in [(8, 9), (12, 0)]:
+if torch.cuda.is_available() and torch.cuda.get_device_capability(0) in [(8, 9), (12, 0)]:
    try:
        from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
    except ImportError:
......
@@ -68,6 +68,7 @@ class HunyuanVideo15PreInfer:
        self.heads_num = config["heads_num"]
        self.frequency_embedding_size = 256
        self.max_period = 10000
+        self.device = torch.device(self.config.get("run_device", "cuda"))
    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
@@ -154,7 +155,7 @@ class HunyuanVideo15PreInfer:
            byt5_txt = byt5_txt + weights.cond_type_embedding.apply(torch.ones_like(byt5_txt[:, :, 0], device=byt5_txt.device, dtype=torch.long))
            txt, text_mask = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask, zero_feat=True)
-            siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=torch.device("cuda")))
+            siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=self.device))
            txt, text_mask = self.reorder_txt_token(siglip_output, txt, siglip_mask, text_mask)
            txt = txt[:, : text_mask.sum(), :]
......
@@ -3,7 +3,11 @@ from typing import Tuple
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
+try:
+    from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
+except Exception as e:
+    apply_rope_with_cos_sin_cache_inplace = None
 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
@@ -96,6 +100,7 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
        self.config = config
        self.double_blocks_num = config["mm_double_blocks_depth"]
        self.heads_num = config["heads_num"]
+        self.device = torch.device(self.config.get("run_device", "cuda"))
        if self.config["seq_parallel"]:
            self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
        else:
@@ -215,7 +220,7 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
        key = torch.cat([img_k, txt_k], dim=1)
        value = torch.cat([img_v, txt_v], dim=1)
        seqlen = query.shape[1]
-        cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to("cuda", non_blocking=True)
+        cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(self.device, non_blocking=True)
        if self.config["seq_parallel"]:
            attn_out = weights.self_attention_parallel.apply(
......
@@ -339,6 +339,8 @@ def maybe_contiguous(x):
 def triton_autotune_configs():
+    if not torch.cuda.is_available():
+        return []
     # Return configs with a valid warp count for the current device
     configs = []
     # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
......
@@ -176,8 +176,8 @@ class HunyuanVideo15Model(CompiledMethodsMixin):
    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-        if self.device.type == "cuda" and dist.is_initialized():
-            device = torch.device("cuda:{}".format(dist.get_rank()))
+        if self.device.type != "cpu" and dist.is_initialized():
+            device = torch.device("{}:{}".format(self.device.type, dist.get_rank()))
        else:
            device = self.device
......