Unverified commit fcc2a411, authored by Kane and committed by GitHub

Mlu590 deployment (#453)

Feature:
    1. Added MLU590 bfloat16 single-GPU and multi-GPU inference.
    2. Added MLU590 int8 inference.
(A brief usage sketch follows the commit metadata below.)
parent 989a30a0
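The diff below registers the new MLU kernels through the existing registries (ATTN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER) and routes device selection through config["run_device"]. As a quick orientation, here is a minimal usage sketch of the new attention path. It is not part of the commit; the import path, tensor shapes, and the "mlu" device string are assumptions, and torch_mlu / torch_mlu_ops must be installed for tmo.flash_attention to exist:

import torch
from lightx2v.common.ops.attn.flash_attn import MluFlashAttnWeight  # path assumed from the package layout in this diff

# Single-sample inputs in the (seq_len, num_heads, head_dim) layout the weight class accepts.
q = torch.randn(1024, 12, 128, dtype=torch.bfloat16, device="mlu")
k = torch.randn(1024, 12, 128, dtype=torch.bfloat16, device="mlu")
v = torch.randn(1024, 12, 128, dtype=torch.bfloat16, device="mlu")
cu = torch.tensor([0, 1024], dtype=torch.int32, device="mlu")  # cumulative sequence lengths for one sample

attn = MluFlashAttnWeight()  # registered below under the name "mlu_flash_attn"
out = attn.apply(q, k, v, cu_seqlens_q=cu, cu_seqlens_kv=cu, max_seqlen_q=1024, max_seqlen_kv=1024)
# apply() flattens the heads, returning a tensor of shape (1 * 1024, 12 * 128)

For multi-card runs, the launcher change below initializes torch.distributed with the "cncl" backend and calls torch.mlu.set_device() when config["run_device"] contains "mlu"; the int8 path is selected by the quantization scheme name "int8-tmo".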
-from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
+from .flash_attn import FlashAttn2Weight, FlashAttn3Weight, MluFlashAttnWeight
from .nbhd_attn import NbhdAttnWeight, NbhdAttnWeightFlashInfer
from .radial_attn import RadialAttnWeight
from .ring_attn import RingAttnWeight
......
import math
from loguru import logger
try:
@@ -13,6 +15,12 @@ except ImportError:
    logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
    flash_attn_varlen_func_v3 = None

try:
    import torch_mlu_ops as tmo
except ImportError:
    logger.info("torch_mlu_ops not found.")
    tmo = None

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
@@ -80,3 +88,35 @@ class FlashAttn3Weight(AttnWeightTemplate):
            max_seqlen_kv,
        ).reshape(bs * max_seqlen_q, -1)
        return x
@ATTN_WEIGHT_REGISTER("mlu_flash_attn")
class MluFlashAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, **kws):
if len(q.shape) == 3:
bs = 1
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
elif len(q.shape) == 4:
bs = q.shape[0]
softmax_scale = 1 / math.sqrt(q.shape[-1])
x = tmo.flash_attention(
q=q,
k=k,
v=v,
cu_seq_lens_q=cu_seqlens_q,
cu_seq_lens_kv=cu_seqlens_kv,
max_seq_len_q=max_seqlen_q,
max_seq_len_kv=max_seqlen_kv,
softmax_scale=softmax_scale,
return_lse=False,
out_dtype=q.dtype,
is_causal=False,
out=None,
alibi_slope=None,
attn_bias=None,
)
x = x.reshape(bs * max_seqlen_q, -1)
return x
import math
import torch
from loguru import logger
@@ -5,7 +7,7 @@ from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate

-if torch.cuda.get_device_capability(0) in [(8, 9), (12, 0)]:
+if torch.cuda.is_available() and torch.cuda.get_device_capability(0) in [(8, 9), (12, 0)]:
    try:
        from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
    except ImportError:
@@ -24,6 +26,12 @@ except ImportError:
    logger.info("sageattn3 not found, please install sageattention first")
    sageattn3_blackwell = None

try:
    import torch_mlu_ops as tmo
except ImportError:
    tmo = None
    logger.info("torch_mlu_ops not found.")

@ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate):
@@ -81,3 +89,22 @@ class SageAttn3Weight(AttnWeightTemplate):
        x = sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2).reshape(bs * max_seqlen_q, -1)
        return x
@ATTN_WEIGHT_REGISTER("mlu_sage_attn")
class MluSageAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, **kws):
if len(q.shape) == 3:
bs = 1
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
elif len(q.shape) == 4:
bs = q.shape[0]
softmax_scale = 1 / math.sqrt(q.shape[-1])
x = tmo.sage_attn(
q=q, k=k, v=v, cu_seq_lens_q=None, cu_seq_lens_kv=None, max_seq_len_kv=max_seqlen_kv, max_seq_len_q=max_seqlen_q, is_causal=False, compute_dtype=torch.bfloat16, softmax_scale=softmax_scale
)
x = x.reshape(bs * max_seqlen_q, -1)
return x
@@ -53,7 +53,7 @@ class UlyssesAttnWeight(AttnWeightTemplate):
        img_q = all2all_seq2head(img_q, group=seq_p_group)
        img_k = all2all_seq2head(img_k, group=seq_p_group)
        img_v = all2all_seq2head(img_v, group=seq_p_group)
-        torch.cuda.synchronize()  # make sure the CUDA ops have completed
+        self.device_synchronize()  # make sure the device ops have completed
        # Process the text query, key and value; select the heads owned by the current rank
        txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
@@ -66,7 +66,7 @@ class UlyssesAttnWeight(AttnWeightTemplate):
        v = torch.cat((img_v, txt_v), dim=0)
        # Initialize the cumulative sequence length tensor
-        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device="cuda")
+        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=self.config.get("run_device", "cuda"))
        s = txt_qkv_len + img_q.shape[0]  # total length of the text and image tokens
        s1 = s  # end position of the current sample
        cu_seqlens_qkv[1] = s1  # set the cumulative sequence length
@@ -100,9 +100,19 @@ class UlyssesAttnWeight(AttnWeightTemplate):
        img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # reshape the image attention result
        img_attn = all2all_head2seq(img_attn, group=seq_p_group)  # convert the head layout back to the sequence layout
        img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
-        torch.cuda.synchronize()  # make sure the CUDA ops have completed
+        self.device_synchronize()  # make sure the device ops have completed
        return img_attn
    def device_synchronize(self):
        # Synchronize whichever accelerator is present and remember it in the config,
        # so later tensor allocations (e.g. cu_seqlens_qkv) land on the same device.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            self.config["run_device"] = "cuda"
        elif hasattr(torch, "mlu") and torch.mlu.is_available():
            torch.mlu.synchronize()
            self.config["run_device"] = "mlu"
@ATTN_WEIGHT_REGISTER("ulysses-4090") @ATTN_WEIGHT_REGISTER("ulysses-4090")
class Ulysses4090AttnWeight(AttnWeightTemplate): class Ulysses4090AttnWeight(AttnWeightTemplate):
......
@@ -65,6 +65,11 @@ try:
except ImportError:
    marlin_cuda_quant = None

try:
    import torch_mlu_ops as tmo
except ImportError:
    tmo = None

class MMWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
@@ -121,7 +126,7 @@ class MMWeight(MMWeightTemplate):
                self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
        else:
            device = weight_dict[self.weight_name].device
-            if device.type == "cuda":
+            if device.type in ["cuda", "mlu", "npu"]:
                self.weight = weight_dict[self.weight_name].t()
                if self.bias_name is not None:
                    self.bias = weight_dict[self.bias_name]
@@ -266,7 +271,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                self.weight_scale_cuda_buffer = weight_dict[self.weight_scale_name].float().cuda()
        else:
            device = weight_dict[self.weight_name].device
-            if device.type == "cuda":
+            if device.type in ["cuda", "mlu", "npu"]:
                self.weight = weight_dict[self.weight_name]
                self.weight_scale = weight_dict[self.weight_scale_name].float()
            elif device.type == "cpu":
@@ -330,7 +335,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
            self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
        else:
            device = weight_dict[self.weight_name].device
-            if device.type == "cuda":
+            if device.type in ["cuda", "mlu", "npu"]:
                self.weight = weight_dict[self.weight_name]
                self.weight_scale = weight_dict[self.weight_scale_name]
            elif device.type == "cpu":
@@ -1014,3 +1019,31 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor
@MM_WEIGHT_REGISTER("int8-tmo")
class MMWeightWint8channelAint8channeldynamicMlu(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Mlu
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: mlu
"""
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
self.load_func = self.load_int8_perchannel_sym
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_int8_perchannel_sym_tmo
def act_quant_int8_perchannel_sym_tmo(self, x):
input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
return input_tensor_quant, input_tensor_scale
def apply(self, input_tensor):
dtype = input_tensor.dtype
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = tmo.scaled_matmul(input_tensor_quant, self.weight.contiguous(), input_tensor_scale, self.weight_scale.squeeze(-1), output_dtype=dtype, use_hp_active=True)
return output_tensor
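For orientation, the scheme documented above amounts to symmetric per-channel int8 quantization of both operands followed by an int8 GEMM and a rescale. Below is a rough pure-PyTorch reference of what calls like tmo.scaled_quantize / tmo.scaled_matmul are expected to compute; it is an assumption for illustration only (the real torch_mlu_ops kernels are fused and run on the MLU, and the weight is assumed here to be stored as (out_features, in_features)):

import torch

def scaled_quantize_reference(x):
    # Symmetric per-channel (per-row) int8 quantization: one scale per row of x.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q, scale.squeeze(-1).to(torch.float32)

def scaled_matmul_reference(a_q, b_q, a_scale, b_scale, output_dtype=torch.bfloat16):
    # int8 x int8 matmul followed by rescaling with the two per-channel scale vectors.
    # float32 is used here for portability; a real kernel would accumulate in int32.
    acc = a_q.to(torch.float32) @ b_q.to(torch.float32).t()
    return (acc * a_scale[:, None] * b_scale[None, :]).to(output_dtype)

In the class above, the activation is quantized dynamically on every call, while the int8 weight and its per-channel scale are loaded ahead of time by load_int8_perchannel_sym.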
@@ -30,7 +30,7 @@ class RMSWeightTemplate(metaclass=ABCMeta):
                self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
        else:
            device = weight_dict[self.weight_name].device
-            if device.type == "cuda":
+            if device.type in ["cuda", "mlu", "npu"]:
                self.weight = weight_dict[self.weight_name]
            elif device.type == "cpu":
                weight_shape = weight_dict[self.weight_name].shape
......
@@ -96,8 +96,13 @@ def main():
    config = set_config(args)
    if config["parallel"]:
-        dist.init_process_group(backend="nccl")
-        torch.cuda.set_device(dist.get_rank())
        run_device = config.get("run_device", "cuda")
        if "cuda" in run_device:
            dist.init_process_group(backend="nccl")
            torch.cuda.set_device(dist.get_rank())
        elif "mlu" in run_device:
            dist.init_process_group(backend="cncl")
            torch.mlu.set_device(dist.get_rank())
        set_parallel_config(config)
    print_config(config)
......
@@ -26,6 +26,11 @@ try:
except ImportError:
    fp8_linear = None

try:
    import torch_mlu_ops as tmo
except ImportError:
    tmo = None

class VllmQuantLinearInt8(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
@@ -310,3 +315,19 @@ class Q8FQuantLinearFp8(nn.Module):
        self.weight_scale = maybe_cast(self.weight_scale)
        self.bias = maybe_cast(self.bias)
        return self
class MluQuantLinearInt8(VllmQuantLinearInt8):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
        super().__init__(in_features, out_features, bias, dtype)

    def act_quant_func(self, x):
        # Dynamic per-channel symmetric int8 quantization of the activation via torch_mlu_ops.
        input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
        return input_tensor_quant, input_tensor_scale

    def forward(self, input_tensor):
        input_tensor = input_tensor.squeeze(0)
        dtype = input_tensor.dtype
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = tmo.scaled_matmul(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale.squeeze(-1), output_dtype=dtype)
        return output_tensor.unsqueeze(0)
@@ -61,7 +61,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
        if self.cpu_offload:
            self.device = torch.device("cpu")
        else:
-            self.device = torch.device("cuda")
+            self.device = torch.device(self.config.get("run_device", "cuda"))
        self.dtype = torch.bfloat16
        self.load()
@@ -69,7 +69,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
    def load(self):
        self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(self.config["model_path"], "text_encoder"), torch_dtype=torch.bfloat16)
        if not self.cpu_offload:
-            self.text_encoder = self.text_encoder.to("cuda")
+            self.text_encoder = self.text_encoder.to(self.device)
        self.tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(self.config["model_path"], "tokenizer"))
        if self.config["task"] == "i2i":
@@ -95,7 +95,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
    @torch.no_grad()
    def infer(self, text, image_list=None):
        if self.cpu_offload:
-            self.text_encoder.to(torch.device("cuda"))
+            self.text_encoder.to(self.device)
        if image_list is not None:
            condition_image_list = []
@@ -130,7 +130,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
                images=condition_image_list,
                padding=True,
                return_tensors="pt",
-            ).to(torch.device("cuda"))
+            ).to(torch.device(self.device))
            encoder_hidden_states = self.text_encoder(
                input_ids=model_inputs.input_ids,
@@ -153,7 +153,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
            txt = [template.format(e) for e in text]
            image_info = {}
-            model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(torch.device("cuda"))
+            model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(self.device)
            encoder_hidden_states = self.text_encoder(
                input_ids=model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
@@ -169,7 +169,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
        prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
        encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
-        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=torch.device("cuda"))
+        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
        prompt_embeds_mask = encoder_attention_mask
        _, seq_len, _ = prompt_embeds.shape
@@ -180,7 +180,12 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
        if self.cpu_offload:
            self.text_encoder.to(torch.device("cpu"))
-            torch.cuda.empty_cache()
+            if "mlu" in str(self.device):
+                torch.mlu.empty_cache()
+            elif "cuda" in str(self.device):
+                torch.cuda.empty_cache()
+            elif "npu" in str(self.device):
+                torch.npu.empty_cache()
            gc.collect()
        return prompt_embeds, prompt_embeds_mask, image_info
@@ -252,6 +252,7 @@ class AudioAdapter(nn.Module):
        quantized: bool = False,
        quant_scheme: str = None,
        cpu_offload: bool = False,
        device=torch.device("cuda"),
    ):
        super().__init__()
        self.cpu_offload = cpu_offload
@@ -262,6 +263,7 @@ class AudioAdapter(nn.Module):
            mlp_dims=mlp_dims,
            transformer_layers=projection_transformer_layers,
        )
        self.device = torch.device(device)
        # self.num_tokens = num_tokens * 4
        self.num_tokens_x4 = num_tokens * 4
        self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
@@ -300,10 +302,10 @@ class AudioAdapter(nn.Module):
    @torch.no_grad()
    def forward_audio_proj(self, audio_feat, latent_frame):
        if self.cpu_offload:
-            self.audio_proj.to("cuda")
+            self.audio_proj.to(self.device)
        x = self.audio_proj(audio_feat, latent_frame)
        x = self.rearange_audio_features(x)
-        x = x + self.audio_pe.cuda()
+        x = x + self.audio_pe.to(self.device)
        if self.cpu_offload:
            self.audio_proj.to("cpu")
        return x
@@ -5,14 +5,14 @@ from lightx2v.utils.envs import *
class SekoAudioEncoderModel:
-    def __init__(self, model_path, audio_sr, cpu_offload):
+    def __init__(self, model_path, audio_sr, cpu_offload, device):
        self.model_path = model_path
        self.audio_sr = audio_sr
        self.cpu_offload = cpu_offload
        if self.cpu_offload:
            self.device = torch.device("cpu")
        else:
-            self.device = torch.device("cuda")
+            self.device = torch.device(device)
        self.load()

    def load(self):
@@ -30,7 +30,7 @@ class SekoAudioEncoderModel:
    @torch.no_grad()
    def infer(self, audio_segment):
-        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.cuda().to(dtype=GET_DTYPE())
+        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.device).to(dtype=GET_DTYPE())
        if self.cpu_offload:
            self.audio_feature_encoder = self.audio_feature_encoder.to("cuda")
        audio_feat = self.audio_feature_encoder(audio_feat, return_dict=True).last_hidden_state
......
@@ -25,7 +25,8 @@ from lightx2v.models.input_encoders.hf.q_linear import (  # noqa E402
    Q8FQuantLinearInt8,  # noqa E402
    SglQuantLinearFp8,  # noqa E402
    TorchaoQuantLinearInt8,  # noqa E402
-    VllmQuantLinearInt8,  # noqa E402
+    VllmQuantLinearInt8,  # noqa E402,
    MluQuantLinearInt8,
)
from lightx2v.models.input_encoders.hf.wan.t5.tokenizer import HuggingfaceTokenizer  # noqa E402
from lightx2v.utils.envs import *  # noqa E402
@@ -201,6 +202,8 @@ class T5Attention(nn.Module):
            linear_cls = Q8FQuantLinearInt8
        elif quant_scheme == "fp8-q8f":
            linear_cls = Q8FQuantLinearFp8
        elif quant_scheme == "int8-tmo":
            linear_cls = MluQuantLinearInt8
        else:
            NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
    else:
@@ -272,6 +275,8 @@ class T5FeedForward(nn.Module):
            linear_cls = Q8FQuantLinearInt8
        elif quant_scheme == "fp8-q8f":
            linear_cls = Q8FQuantLinearFp8
        elif quant_scheme == "int8-tmo":
            linear_cls = MluQuantLinearInt8
        else:
            NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
    else:
@@ -741,7 +746,7 @@ class T5EncoderModel:
        self,
        text_len,
        dtype=torch.bfloat16,
-        device=torch.cuda.current_device(),
+        device=torch.device("cuda"),
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
@@ -802,8 +807,8 @@ class T5EncoderModel:
    def infer(self, texts):
        ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
-        ids = ids.cuda()
-        mask = mask.cuda()
+        ids = ids.to(self.device)
+        mask = mask.to(self.device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        with torch.no_grad():
......
@@ -10,7 +10,7 @@ from loguru import logger
# from lightx2v.attentions import attention
from lightx2v.common.ops.attn import TorchSDPAWeight
-from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearInt8
+from lightx2v.models.input_encoders.hf.q_linear import MluQuantLinearInt8, Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearInt8
from lightx2v.utils.utils import load_weights

__all__ = [
@@ -69,6 +69,8 @@ class SelfAttention(nn.Module):
            linear_cls = Q8FQuantLinearInt8
        elif quant_scheme == "fp8-q8f":
            linear_cls = Q8FQuantLinearFp8
        elif quant_scheme == "int8-tmo":
            linear_cls = MluQuantLinearInt8
        else:
            NotImplementedError(f"Unsupported CLip quant scheme: {quant_scheme}")
    else:
@@ -149,6 +151,8 @@ class AttentionBlock(nn.Module):
            linear_cls = Q8FQuantLinearInt8
        elif quant_scheme == "fp8-q8f":
            linear_cls = Q8FQuantLinearFp8
        elif quant_scheme == "int8-tmo":
            linear_cls = MluQuantLinearInt8
        else:
            NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
    else:
@@ -288,7 +292,7 @@ class VisionTransformer(nn.Module):
        b = x.size(0)

        # embeddings
-        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
+        x = self.patch_embedding(x.type(self.patch_embedding.weight.type())).flatten(2).permute(0, 2, 1)
        if self.pool_type in ("token", "token_fc"):
            x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
        if interpolation:
......
@@ -111,7 +111,7 @@ def apply_attn(block_weight, hidden_states, encoder_hidden_states, image_rotary_
    if attn_type == "torch_sdpa":
        joint_hidden_states = block_weight.attn.calculate.apply(q=joint_query, k=joint_key, v=joint_value)
-    elif attn_type in ["flash_attn3", "sage_attn2"]:
+    elif attn_type in ["flash_attn3", "sage_attn2", "mlu_flash_attn"]:
        joint_query = joint_query.squeeze(0)
        joint_key = joint_key.squeeze(0)
        joint_value = joint_value.squeeze(0)
......
@@ -28,7 +28,7 @@ class QwenImageTransformerModel:
        self.model_path = os.path.join(config["model_path"], "transformer")
        self.cpu_offload = config.get("cpu_offload", False)
        self.offload_granularity = self.config.get("offload_granularity", "block")
-        self.device = torch.device("cpu") if self.cpu_offload else torch.device("cuda")
+        self.device = torch.device("cpu") if self.cpu_offload else torch.device(self.config.get("run_device", "cuda"))
        with open(os.path.join(config["model_path"], "transformer", "config.json"), "r") as f:
            transformer_config = json.load(f)
@@ -124,8 +124,8 @@ class QwenImageTransformerModel:
    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-        if self.device.type == "cuda" and dist.is_initialized():
-            device = torch.device("cuda:{}".format(dist.get_rank()))
+        if self.device.type in ["cuda", "mlu", "npu"] and dist.is_initialized():
+            device = torch.device("{}:{}".format(self.device.type, dist.get_rank()))
        else:
            device = self.device
......
@@ -22,15 +22,15 @@ class WanAudioModel(WanModel):
    def __init__(self, model_path, config, device):
        self.config = config
-        self._load_adapter_ckpt()
        super().__init__(model_path, config, device)
+        self._load_adapter_ckpt()

    def _load_adapter_ckpt(self):
        if self.config.get("adapter_model_path", None) is None:
            if self.config.get("adapter_quantized", False):
                if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f", "fp8-vllm", "fp8-sgl"]:
                    adapter_model_name = "audio_adapter_model_fp8.safetensors"
-                elif self.config.get("adapter_quant_scheme", None) in ["int8", "int8-q8f", "int8-vllm", "int8-sgl"]:
+                elif self.config.get("adapter_quant_scheme", None) in ["int8", "int8-q8f", "int8-vllm", "int8-sgl", "int8-tmo"]:
                    adapter_model_name = "audio_adapter_model_int8.safetensors"
                elif self.config.get("adapter_quant_scheme", None) in ["mxfp4"]:
                    adapter_model_name = "audio_adapter_model_mxfp4.safetensors"
@@ -50,7 +50,7 @@ class WanAudioModel(WanModel):
        if not adapter_offload:
            if not dist.is_initialized() or not load_from_rank0:
                for key in self.adapter_weights_dict:
-                    self.adapter_weights_dict[key] = self.adapter_weights_dict[key].cuda()
+                    self.adapter_weights_dict[key] = self.adapter_weights_dict[key].to(torch.device(self.device))

    def _init_infer_class(self):
        super()._init_infer_class()
......
@@ -9,6 +9,7 @@ from ..utils import rope_params, sinusoidal_embedding_1d
class WanAudioPreInfer(WanPreInfer):
    def __init__(self, config):
        super().__init__(config)
        assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
        d = config["dim"] // config["num_heads"]
        self.config = config
@@ -20,7 +21,7 @@ class WanAudioPreInfer(WanPreInfer):
                rope_params(1024, 2 * (d // 6)),
            ],
            dim=1,
-        ).cuda()
+        ).to(self.device)
        self.freq_dim = config["freq_dim"]
        self.dim = config["dim"]
        self.rope_t_dim = d // 2 - 2 * (d // 6)
......
@@ -13,6 +13,7 @@ class WanPreInfer:
        d = config["dim"] // config["num_heads"]
        self.clean_cuda_cache = config.get("clean_cuda_cache", False)
        self.task = config["task"]
        self.device = torch.device(self.config.get("run_device", "cuda"))
        self.freqs = torch.cat(
            [
                rope_params(1024, d - 4 * (d // 6)),
@@ -20,7 +21,7 @@ class WanPreInfer:
                rope_params(1024, 2 * (d // 6)),
            ],
            dim=1,
-        ).cuda()
+        ).to(self.device)
        self.freq_dim = config["freq_dim"]
        self.dim = config["dim"]
        self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
......
@@ -74,6 +74,7 @@ class WanModel(CompiledMethodsMixin):
            "mxfp4",
            "mxfp6-mxfp8",
            "mxfp8",
            "int8-tmo",
        ]
        self.device = device
        self._init_infer_class()
@@ -137,8 +138,8 @@ class WanModel(CompiledMethodsMixin):
    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-        if self.device.type == "cuda" and dist.is_initialized():
-            device = torch.device("cuda:{}".format(dist.get_rank()))
+        if (self.device.type == "cuda" or self.device.type == "mlu") and dist.is_initialized():
+            device = torch.device("{}:{}".format(self.device.type, dist.get_rank()))
        else:
            device = self.device
......
@@ -145,9 +145,9 @@ class BaseRunner(ABC):
        if world_size > 1:
            if rank == signal_rank:
-                t = torch.tensor([stopped], dtype=torch.int32).to(device="cuda")
+                t = torch.tensor([stopped], dtype=torch.int32).to(device=self.config.get("run_device", "cuda"))
            else:
-                t = torch.zeros(1, dtype=torch.int32, device="cuda")
+                t = torch.zeros(1, dtype=torch.int32, device=self.config.get("run_device", "cuda"))
            dist.broadcast(t, src=signal_rank)
            stopped = t.item()
......