Unverified Commit b50498fa, authored by Yang Yong (雍洋) and committed by GitHub

Add lightx2v_platform (#541)

parent 31da6925
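The new `lightx2v_platform` package itself is not included in this diff; only its call sites are. Judging from those call sites, `lightx2v_platform.set_ai_device` probes the available torch backend once at import time and publishes it as the `AI_DEVICE` string. A minimal sketch of that idea, assuming the cuda → mlu → npu probe order used by the removed `device_synchronize` helper further down (module layout and function name are guesses):

# Hypothetical sketch of lightx2v_platform/set_ai_device.py; the real module
# is not part of this diff. Probe order mirrors the removed per-call checks.
import torch

from lightx2v_platform.base import global_var


def _detect_ai_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "mlu") and torch.mlu.is_available():
        return "mlu"
    if hasattr(torch, "npu") and torch.npu.is_available():
        return "npu"
    return "cpu"


# Importing this module once (lightx2v/__init__.py does so below) pins the
# device string for the whole process.
global_var.AI_DEVICE = _detect_ai_device()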
@@ -5,15 +5,14 @@
     "audio_sr": 16000,
     "target_video_length": 81,
     "resize_mode": "adaptive",
-    "self_attn_1_type": "flash_attn2",
-    "cross_attn_1_type": "flash_attn2",
-    "cross_attn_2_type": "flash_attn2",
+    "self_attn_1_type": "mlu_sage_attn",
+    "cross_attn_1_type": "mlu_sage_attn",
+    "cross_attn_2_type": "mlu_sage_attn",
     "sample_guide_scale": 1.0,
     "sample_shift": 5,
     "enable_cfg": false,
     "cpu_offload": false,
     "use_31_block": false,
-    "run_device": "mlu",
     "rope_type": "torch",
     "modulate_type": "torch"
 }
@@ -4,7 +4,6 @@
     "video_duration": 5,
     "audio_sr": 16000,
     "target_video_length": 81,
-    "resize_mode": "adaptive",
     "self_attn_1_type": "sage_attn3",
     "cross_attn_1_type": "sage_attn3",
     "cross_attn_2_type": "sage_attn3",
...
@@ -2,6 +2,7 @@ __version__ = "0.1.0"
 __author__ = "LightX2V Contributors"
 __license__ = "Apache 2.0"

+import lightx2v_platform.set_ai_device
 from lightx2v import common, deploy, models, utils
 from lightx2v.pipeline import LightX2VPipeline
...
-from .flash_attn import FlashAttn2Weight, FlashAttn3Weight, MluFlashAttnWeight
+from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
 from .nbhd_attn import NbhdAttnWeight, NbhdAttnWeightFlashInfer
 from .radial_attn import RadialAttnWeight
 from .ring_attn import RingAttnWeight
...
-import math
-
 from loguru import logger

 try:
@@ -15,12 +13,6 @@ except ImportError:
     logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
     flash_attn_varlen_func_v3 = None

-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    logger.info("torch_mlu_ops not found.")
-    tmo = None
-
 from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

 from .template import AttnWeightTemplate
@@ -94,35 +86,3 @@ class FlashAttn3Weight(AttnWeightTemplate):
             max_seqlen_kv,
         ).reshape(bs * max_seqlen_q, -1)
         return x
-
-
-@ATTN_WEIGHT_REGISTER("mlu_flash_attn")
-class MluFlashAttnWeight(AttnWeightTemplate):
-    def __init__(self):
-        self.config = {}
-
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, **kws):
-        if len(q.shape) == 3:
-            bs = 1
-            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
-        elif len(q.shape) == 4:
-            bs = q.shape[0]
-        softmax_scale = 1 / math.sqrt(q.shape[-1])
-        x = tmo.flash_attention(
-            q=q,
-            k=k,
-            v=v,
-            cu_seq_lens_q=cu_seqlens_q,
-            cu_seq_lens_kv=cu_seqlens_kv,
-            max_seq_len_q=max_seqlen_q,
-            max_seq_len_kv=max_seqlen_kv,
-            softmax_scale=softmax_scale,
-            return_lse=False,
-            out_dtype=q.dtype,
-            is_causal=False,
-            out=None,
-            alibi_slope=None,
-            attn_bias=None,
-        )
-        x = x.reshape(bs * max_seqlen_q, -1)
-        return x
-import math
-
 import torch
 from loguru import logger
@@ -26,12 +24,6 @@ except ImportError:
     logger.info("sageattn3 not found, please install sageattention first")
     sageattn3_blackwell = None

-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    tmo = None
-    logger.info("torch_mlu_ops not found.")
-

 @ATTN_WEIGHT_REGISTER("sage_attn2")
 class SageAttn2Weight(AttnWeightTemplate):
@@ -89,22 +81,3 @@ class SageAttn3Weight(AttnWeightTemplate):
         x = sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2).reshape(bs * max_seqlen_q, -1)
         return x
-
-
-@ATTN_WEIGHT_REGISTER("mlu_sage_attn")
-class MluSageAttnWeight(AttnWeightTemplate):
-    def __init__(self):
-        self.config = {}
-
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, **kws):
-        if len(q.shape) == 3:
-            bs = 1
-            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
-        elif len(q.shape) == 4:
-            bs = q.shape[0]
-        softmax_scale = 1 / math.sqrt(q.shape[-1])
-        x = tmo.sage_attn(
-            q=q, k=k, v=v, cu_seq_lens_q=None, cu_seq_lens_kv=None, max_seq_len_kv=max_seqlen_kv, max_seq_len_q=max_seqlen_q, is_causal=False, compute_dtype=torch.bfloat16, softmax_scale=softmax_scale
-        )
-        x = x.reshape(bs * max_seqlen_q, -1)
-        return x
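The deleted `mlu_flash_attn`/`mlu_sage_attn` classes are evidently relocated rather than dropped: the first config hunk above now selects `mlu_sage_attn`, so the platform package must still register that key. A hedged sketch of what the platform-side registration could look like (the template's import path is an assumption; the `tmo.sage_attn` call signature is copied from the removed code):

# Hypothetical platform-side module; not shown in this diff.
import math

import torch
import torch_mlu_ops as tmo  # Cambricon MLU kernel library

from lightx2v.common.ops.attn.template import AttnWeightTemplate  # assumed path
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER


@ATTN_WEIGHT_REGISTER("mlu_sage_attn")
class MluSageAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, max_seqlen_q=None, max_seqlen_kv=None, **kws):
        # Same shape handling as the removed in-tree class.
        bs = q.shape[0] if q.dim() == 4 else 1
        if q.dim() == 3:
            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        softmax_scale = 1 / math.sqrt(q.shape[-1])
        x = tmo.sage_attn(
            q=q, k=k, v=v, cu_seq_lens_q=None, cu_seq_lens_kv=None,
            max_seq_len_q=max_seqlen_q, max_seq_len_kv=max_seqlen_kv,
            is_causal=False, compute_dtype=torch.bfloat16, softmax_scale=softmax_scale,
        )
        return x.reshape(bs * max_seqlen_q, -1)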
@@ -3,6 +3,7 @@ import torch.distributed as dist

 from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm
 from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE

 from .template import AttnWeightTemplate
 from .utils.all2all import all2all_head2seq, all2all_seq2head
@@ -75,7 +76,6 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         img_q = all2all_seq2head(img_q, group=seq_p_group)
         img_k = all2all_seq2head(img_k, group=seq_p_group)
         img_v = all2all_seq2head(img_v, group=seq_p_group)
-        self.device_synchronize()  # make sure device-side ops have finished

         # Process the text q/k/v, selecting the heads owned by the current rank
         txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
@@ -88,7 +88,7 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         v = torch.cat((img_v, txt_v), dim=0)

         # Initialize the cumulative-sequence-length tensor
-        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=self.config.get("run_device", "cuda"))
+        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=AI_DEVICE)
         s = txt_qkv_len + img_q.shape[0]  # total length of the text and image tokens
         s1 = s  # end position of the current sample
         cu_seqlens_qkv[1] = s1  # record the cumulative sequence length
@@ -133,23 +133,8 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         img_attn = all2all_head2seq(img_attn, group=seq_p_group)
         img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
-        self.device_synchronize()  # make sure device-side ops have finished

         return img_attn

-    def device_synchronize(
-        self,
-    ):
-        if torch.cuda.is_available():
-            # no need to sync between comm and comp
-            # torch.cuda.synchronize()
-            self.config["run_device"] = "cuda"
-        elif hasattr(torch, "mlu") and torch.mlu.is_available():
-            torch.mlu.synchronize()
-            self.config["run_device"] = "mlu"
-        elif hasattr(torch, "npu") and torch.npu.is_available():
-            torch.npu.synchronize()
-            self.config["run_device"] = "npu"
-

 @ATTN_WEIGHT_REGISTER("ulysses-4090")
 class Ulysses4090AttnWeight(AttnWeightTemplate):
...
@@ -35,13 +35,7 @@ class Conv3dWeight(Conv3dWeightTemplate):
     def load(self, weight_dict):
         device = weight_dict[self.weight_name].device
-        if device.type in ["cuda", "mlu", "npu"]:
-            self.weight = weight_dict[self.weight_name]
-            if self.bias_name is not None:
-                self.bias = weight_dict[self.bias_name]
-            else:
-                self.bias = None
-        elif device.type == "cpu":
+        if device.type == "cpu":
             weight_shape = weight_dict[self.weight_name].shape
             weight_dtype = weight_dict[self.weight_name].dtype
             self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -57,7 +51,11 @@ class Conv3dWeight(Conv3dWeightTemplate):
                 self.pin_bias = None
             del weight_dict[self.weight_name]
         else:
-            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+            self.weight = weight_dict[self.weight_name]
+            if self.bias_name is not None:
+                self.bias = weight_dict[self.bias_name]
+            else:
+                self.bias = None

     def apply(self, input_tensor):
         input_tensor = torch.nn.functional.conv3d(
...
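All of the rewritten `load` methods in this commit share the same inverted branch: CPU tensors are staged into pinned host memory for asynchronous host-to-device copies, while a tensor already resident on any accelerator is adopted as-is, which is what removes the need for an explicit cuda/mlu/npu whitelist. A standalone sketch of the pattern (helper name is illustrative):

import torch


def stage_weight(weight: torch.Tensor) -> torch.Tensor:
    # Mirrors the load() pattern above: page-locked staging for CPU tensors so
    # a later .to(device, non_blocking=True) can overlap with compute; anything
    # already on an accelerator (cuda/mlu/npu/...) is kept untouched.
    if weight.device.type == "cpu":
        pinned = torch.empty(weight.shape, dtype=weight.dtype, pin_memory=True)
        pinned.copy_(weight)
        return pinned
    return weight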
@@ -22,16 +22,14 @@ class EmbeddingWeightTemplate(metaclass=ABCMeta):
             self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                 self.pin_weight.copy_(weight_dict[self.weight_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]

     def to_cuda(self, non_blocking=False):
         self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
...
@@ -67,11 +67,6 @@ try:
 except ImportError:
     marlin_cuda_quant = None

-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    tmo = None
-

 class MMWeightTemplate(metaclass=ABCMeta):
     def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
@@ -128,14 +123,7 @@ class MMWeight(MMWeightTemplate):
             self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name].t()
-                if self.bias_name is not None:
-                    self.bias = weight_dict[self.bias_name]
-                else:
-                    self.bias = None
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
@@ -153,7 +141,11 @@ class MMWeight(MMWeightTemplate):
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name].t()
+                if self.bias_name is not None:
+                    self.bias = weight_dict[self.bias_name]
+                else:
+                    self.bias = None

     def _calculate_size(self):
         if self.bias is not None:
@@ -273,10 +265,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight_scale_cuda_buffer = weight_dict[self.weight_scale_name].float().cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-                self.weight_scale = weight_dict[self.weight_scale_name].float()
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -288,7 +277,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
+                self.weight_scale = weight_dict[self.weight_scale_name].float()

         if self.bias_name is not None:
             if self.create_cuda_buffer:
@@ -296,15 +286,13 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
             else:
                 device = weight_dict[self.bias_name].device
-                if device.type in ["cuda", "mlu", "npu"]:
-                    self.bias = weight_dict[self.bias_name]
-                elif device.type == "cpu":
+                if device.type == "cpu":
                     bias_shape = weight_dict[self.bias_name].shape
                     bias_dtype = weight_dict[self.bias_name].dtype
                     self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                     self.pin_bias.copy_(weight_dict[self.bias_name])
                 else:
-                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                    self.bias = weight_dict[self.bias_name]
         else:
             self.bias = None
             self.pin_bias = None
@@ -337,10 +325,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-                self.weight_scale = weight_dict[self.weight_scale_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -352,7 +337,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
+                self.weight_scale = weight_dict[self.weight_scale_name]

     def load_mxfp6(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
@@ -362,10 +348,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-                self.weight_scale = weight_dict[self.weight_scale_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -377,7 +360,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
+                self.weight_scale = weight_dict[self.weight_scale_name]
     def load_mxfp8(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
@@ -387,10 +371,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-                self.weight_scale = weight_dict[self.weight_scale_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -402,7 +383,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
+                self.weight_scale = weight_dict[self.weight_scale_name]
     def load_nvfp4(self, weight_dict):
         device = weight_dict[self.weight_name].device
@@ -412,12 +394,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         weight_global_scale = weight_dict[f"{self.weight_name}_global_scale"]
         alpha = 1.0 / (input_global_scale * weight_global_scale)

-        if device.type in ["cuda", "mlu", "npu"]:
-            self.weight = weight_dict[self.weight_name]
-            self.weight_scale = weight_dict[self.weight_scale_name]
-            self.input_global_scale = input_global_scale
-            self.alpha = alpha
-        elif device.type == "cpu":
+        if device.type == "cpu":
             weight_shape = weight_dict[self.weight_name].shape
             weight_dtype = weight_dict[self.weight_name].dtype
             self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -440,7 +417,10 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             del weight_dict[self.weight_name]
         else:
-            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+            self.weight = weight_dict[self.weight_name]
+            self.weight_scale = weight_dict[self.weight_scale_name]
+            self.input_global_scale = input_global_scale
+            self.alpha = alpha

         if self.bias_name is not None:
             if self.create_cuda_buffer:
@@ -1178,33 +1158,3 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
         if hasattr(self, "bias") and self.bias is not None:
             output_tensor.add_(self.bias)
         return output_tensor
-
-
-@MM_WEIGHT_REGISTER("int8-tmo")
-class MMWeightWint8channelAint8channeldynamicMlu(MMWeightQuantTemplate):
-    """
-    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Mlu
-    Quant MM:
-        Weight: int8 perchannel sym
-        Act: int8 perchannel dynamic sym
-        Kernel: mlu
-    """
-
-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
-        self.load_func = self.load_int8_perchannel_sym
-        self.weight_need_transpose = False
-        self.act_quant_func = self.act_quant_int8_perchannel_sym_tmo
-
-    def act_quant_int8_perchannel_sym_tmo(self, x):
-        input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
-        return input_tensor_quant, input_tensor_scale
-
-    def apply(self, input_tensor):
-        dtype = input_tensor.dtype
-        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        output_tensor = tmo.scaled_matmul(
-            input_tensor_quant, self.weight.contiguous(), input_tensor_scale, self.weight_scale.squeeze(-1), bias=self.bias if self.bias is not None else None, output_dtype=dtype, use_hp_active=True
-        )
-        return output_tensor
@@ -32,13 +32,7 @@ class LNWeightTemplate(metaclass=ABCMeta):
         else:
             if self.weight_name is not None:
                 device = weight_dict[self.weight_name].device
-                if device.type in ["cuda", "mlu", "npu"]:
-                    self.weight = weight_dict[self.weight_name]
-                    if self.bias_name is not None:
-                        self.bias = weight_dict[self.bias_name]
-                    else:
-                        self.bias = None
-                elif device.type == "cpu":
+                if device.type == "cpu":
                     weight_shape = weight_dict[self.weight_name].shape
                     weight_dtype = weight_dict[self.weight_name].dtype
                     self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -54,7 +48,11 @@ class LNWeightTemplate(metaclass=ABCMeta):
                         self.pin_bias = None
                     del weight_dict[self.weight_name]
                 else:
-                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                    self.weight = weight_dict[self.weight_name]
+                    if self.bias_name is not None:
+                        self.bias = weight_dict[self.bias_name]
+                    else:
+                        self.bias = None
             else:
                 self.weight = None
                 self.bias = None
...
@@ -30,16 +30,14 @@ class RMSWeightTemplate(metaclass=ABCMeta):
             self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                 self.pin_weight.copy_(weight_dict[self.weight_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]

     def clear(self):
         attrs = ["weight", "pinned_weight"]
...
@@ -29,16 +29,14 @@ class DefaultTensor:
             self.tensor_cuda_buffer = weight_dict[self.tensor_name].cuda()
         else:
             device = weight_dict[self.tensor_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.tensor = weight_dict[self.tensor_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 tensor_shape = weight_dict[self.tensor_name].shape
                 tensor_dtype = weight_dict[self.tensor_name].dtype
                 self.pin_tensor = torch.empty(tensor_shape, pin_memory=True, dtype=tensor_dtype)
                 self.pin_tensor.copy_(weight_dict[self.tensor_name])
                 del weight_dict[self.tensor_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.tensor = weight_dict[self.tensor_name]

     def clear(self):
         attrs = ["tensor", "pinned_tensor"]
...
@@ -4,11 +4,6 @@ import torch
 import torch.distributed as dist
 from loguru import logger

-try:
-    from torch.distributed import ProcessGroupNCCL
-except ImportError:
-    ProcessGroupNCCL = None
-
 from lightx2v.common.ops import *
 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner  # noqa: F401
 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner  # noqa: F401
@@ -26,6 +21,8 @@ from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
 from lightx2v.utils.utils import seed_all
+from lightx2v_platform.base.global_var import AI_DEVICE
+from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER


 def init_runner(config):
@@ -105,15 +102,8 @@ def main():
     config = set_config(args)

     if config["parallel"]:
-        run_device = config.get("run_device", "cuda")
-        if "cuda" in run_device:
-            pg_options = ProcessGroupNCCL.Options()
-            pg_options.is_high_priority_stream = True
-            dist.init_process_group(backend="nccl", pg_options=pg_options)
-            torch.cuda.set_device(dist.get_rank())
-        elif "mlu" in run_device:
-            dist.init_process_group(backend="cncl")
-            torch.mlu.set_device(dist.get_rank())
+        platform_device = PLATFORM_DEVICE_REGISTER.get(AI_DEVICE, None)
+        platform_device.init_parallel_env()
         set_parallel_config(config)
     print_config(config)
...
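`PLATFORM_DEVICE_REGISTER` lives in the new package and is not visible in this diff. A plausible sketch of registered platform entries, reconstructed from the inline branches that `init_parallel_env()` replaces (NCCL with a high-priority stream on CUDA, CNCL on MLU); the class and decorator names are assumptions:

# Hypothetical lightx2v_platform device entries; only the call site above is known.
import torch
import torch.distributed as dist

PLATFORM_DEVICE_REGISTER = {}


def register_platform(name):
    def decorator(cls):
        PLATFORM_DEVICE_REGISTER[name] = cls()
        return cls

    return decorator


@register_platform("cuda")
class CudaPlatform:
    def init_parallel_env(self):
        # Mirrors the removed inline branch: NCCL with a high-priority stream.
        from torch.distributed import ProcessGroupNCCL

        pg_options = ProcessGroupNCCL.Options()
        pg_options.is_high_priority_stream = True
        dist.init_process_group(backend="nccl", pg_options=pg_options)
        torch.cuda.set_device(dist.get_rank())


@register_platform("mlu")
class MluPlatform:
    def init_parallel_env(self):
        # Mirrors the removed inline branch for Cambricon MLU.
        dist.init_process_group(backend="cncl")
        torch.mlu.set_device(dist.get_rank())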
@@ -8,6 +8,8 @@ import torch.nn as nn
 from safetensors import safe_open
 from transformers import AutoTokenizer, T5ForConditionalGeneration

+from lightx2v_platform.base.global_var import AI_DEVICE
+
 from .format_prompt import MultilingualPromptFormat
@@ -159,14 +161,12 @@ class ByT5TextEncoder:
         self,
         config,
         device=torch.device("cpu"),
-        run_device=torch.device("cuda"),
         checkpoint_path=None,
         byt5_max_length=256,
         cpu_offload=False,
     ):
         self.cpu_offload = cpu_offload
         self.config = config
-        self.run_device = run_device
         self.byt5_max_length = byt5_max_length
         self.enable_cfg = config.get("enable_cfg", False)
         byT5_google_path = os.path.join(checkpoint_path, "text_encoder", "byt5-small")
@@ -301,12 +301,12 @@ class ByT5TextEncoder:
         negative_masks = []

         for prompt in prompt_list:
-            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, self.run_device)
+            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, AI_DEVICE)
             positive_embeddings.append(pos_emb)
             positive_masks.append(pos_mask)

             if self.enable_cfg:  # TODO: split cfg out; better suited to parallel execution
-                neg_emb, neg_mask = self._process_single_byt5_prompt("", self.run_device)
+                neg_emb, neg_mask = self._process_single_byt5_prompt("", AI_DEVICE)
                 negative_embeddings.append(neg_emb)
                 negative_masks.append(neg_mask)
@@ -328,8 +328,8 @@ class ByT5TextEncoder:
     @torch.no_grad()
     def infer(self, prompts):
         if self.cpu_offload:
-            self.byt5_model = self.byt5_model.to(self.run_device)
-            self.byt5_mapper = self.byt5_mapper.to(self.run_device)
+            self.byt5_model = self.byt5_model.to(AI_DEVICE)
+            self.byt5_mapper = self.byt5_mapper.to(AI_DEVICE)
         byt5_embeddings, byt5_masks = self._prepare_byt5_embeddings(prompts)
         byt5_features = self.byt5_mapper(byt5_embeddings.to(torch.bfloat16))
         if self.cpu_offload:
...
@@ -32,6 +32,9 @@ from lightx2v.models.input_encoders.hf.q_linear import (  # noqa E402
     TorchaoQuantLinearInt8,  # noqa E402
     VllmQuantLinearInt8,  # noqa E402
 )
+from lightx2v_platform.base.global_var import AI_DEVICE  # noqa E402
+
+torch_device_module = getattr(torch, AI_DEVICE)


 def use_default(value, default):
@@ -145,12 +148,7 @@ def load_text_encoder(
             new_w_dict[key.replace("model.", "")] = weight_dict[key]

     del weight_dict
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-    elif "mlu" in str(device):
-        torch.mlu.empty_cache()
-    elif "npu" in str(device):
-        torch.npu.empty_cache()
+    torch_device_module.empty_cache()
     gc.collect()

     text_encoder.load_state_dict(new_w_dict, assign=True)
@@ -552,7 +550,6 @@ class Qwen25VL_TextEncoder:
         text_len=1000,
         dtype=torch.float16,
         device=torch.device("cpu"),
-        run_device=torch.device("cuda"),
         checkpoint_path=None,
         cpu_offload=False,
         qwen25vl_quantized=False,
@@ -561,7 +558,6 @@ class Qwen25VL_TextEncoder:
     ):
         self.text_len = text_len
         self.dtype = dtype
-        self.run_device = run_device
         self.cpu_offload = cpu_offload
         self.qwen25vl_quantized = qwen25vl_quantized
         self.qwen25vl_quant_scheme = qwen25vl_quant_scheme
@@ -590,20 +586,20 @@ class Qwen25VL_TextEncoder:
     def infer(self, texts):
         if self.cpu_offload:
-            self.text_encoder = self.text_encoder.to(self.run_device)
+            self.text_encoder = self.text_encoder.to(AI_DEVICE)
         text_inputs = self.text_encoder.text2tokens(texts, data_type="video", max_length=self.text_len)
-        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=self.run_device)
+        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=AI_DEVICE)
         if self.cpu_offload:
             self.text_encoder = self.text_encoder.to("cpu")

         prompt_embeds = prompt_outputs.hidden_state
         attention_mask = prompt_outputs.attention_mask
         if attention_mask is not None:
-            attention_mask = attention_mask.to(self.run_device)
+            attention_mask = attention_mask.to(AI_DEVICE)
             _, seq_len = attention_mask.shape
             attention_mask = attention_mask.repeat(1, self.num_videos_per_prompt)
             attention_mask = attention_mask.view(self.num_videos_per_prompt, seq_len)

-        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.run_device)
+        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=AI_DEVICE)
         seq_len = prompt_embeds.shape[1]
         # duplicate text embeddings for each generation per prompt, using mps friendly method
...
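`getattr(torch, AI_DEVICE)` collapses the old per-backend `empty_cache` chain into a single generic lookup; it relies on each supported backend exposing a `torch.<device>` module with the same surface (`empty_cache`, `synchronize`, ...). A quick illustration:

import torch

from lightx2v_platform.base.global_var import AI_DEVICE

# torch.cuda, torch.mlu and torch.npu all expose empty_cache()/synchronize(),
# so one getattr replaces the old if/elif chain. This assumes AI_DEVICE names
# an accelerator on these paths; torch.cpu has no empty_cache().
torch_device_module = getattr(torch, AI_DEVICE)
torch_device_module.empty_cache()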
@@ -10,6 +10,8 @@ from safetensors.torch import safe_open
 from transformers import SiglipImageProcessor, SiglipVisionModel
 from transformers.utils import ModelOutput

+from lightx2v_platform.base.global_var import AI_DEVICE
+
 PRECISION_TO_TYPE = {
     "fp32": torch.float32,
     "fp16": torch.float16,
@@ -95,7 +97,6 @@ class VisionEncoder(nn.Module):
         output_key: Optional[str] = None,
         logger=None,
         device=None,
-        run_device=None,
         cpu_offload=False,
     ):
         super().__init__()
@@ -121,7 +122,6 @@ class VisionEncoder(nn.Module):
         )
         self.dtype = self.model.dtype
         self.device = self.model.device
-        self.run_device = run_device

         self.processor, self.processor_path = load_image_processor(
             processor_type=self.processor_type,
@@ -172,12 +172,12 @@ class VisionEncoder(nn.Module):
             VisionEncoderModelOutput with encoded features
         """
         if self.cpu_offload:
-            self.model = self.model.to("cuda")
-            self.processor = self.processor.to("cuda")
+            self.model = self.model.to(AI_DEVICE)
+            self.processor = self.processor.to(AI_DEVICE)

         if isinstance(images, np.ndarray):
             # Preprocess images if they're numpy arrays
-            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=self.run_device, dtype=self.model.dtype)
+            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=AI_DEVICE, dtype=self.model.dtype)
         else:
             # Assume already preprocessed
             preprocessed = images
@@ -232,13 +232,11 @@ class SiglipVisionEncoder:
         self,
         config,
         device=torch.device("cpu"),
-        run_device=torch.device("cuda"),
         checkpoint_path=None,
         cpu_offload=False,
     ):
         self.config = config
         self.device = device
-        self.run_device = run_device
         self.cpu_offload = cpu_offload
         self.vision_states_dim = 1152
         vision_encoder_path = os.path.join(checkpoint_path, "vision_encoder", "siglip")
@@ -252,7 +250,6 @@ class SiglipVisionEncoder:
             output_key=None,
             logger=None,
             device=self.device,
-            run_device=self.run_device,
             cpu_offload=self.cpu_offload,
         )
@@ -270,7 +267,7 @@ class SiglipVisionEncoder:
     @torch.no_grad()
     def infer(self, vision_states):
         if self.cpu_offload:
-            self.vision_in = self.vision_in.to(self.run_device)
+            self.vision_in = self.vision_in.to(AI_DEVICE)
         vision_states = self.vision_in(vision_states)
         if self.cpu_offload:
             self.vision_in = self.vision_in.to("cpu")
...
@@ -26,11 +26,6 @@ try:
 except ImportError:
     fp8_linear = None

-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    tmo = None
-

 class VllmQuantLinearInt8(nn.Module):
     def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
@@ -315,19 +310,3 @@ class Q8FQuantLinearFp8(nn.Module):
         self.weight_scale = maybe_cast(self.weight_scale)
         self.bias = maybe_cast(self.bias)
         return self
-
-
-class MluQuantLinearInt8(VllmQuantLinearInt8):
-    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
-        super().__init__(in_features, out_features, bias, dtype)
-
-    def act_quant_func(self, x):
-        input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
-        return input_tensor_quant, input_tensor_scale
-
-    def forward(self, input_tensor):
-        input_tensor = input_tensor.squeeze(0)
-        dtype = input_tensor.dtype
-        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        output_tensor = tmo.scaled_matmul(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale.squeeze(-1), output_dtype=dtype)
-        return output_tensor.unsqueeze(0)
@@ -5,6 +5,10 @@ import os
 import torch
 from transformers import Qwen2Tokenizer, Qwen2_5_VLForConditionalGeneration

+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)
+
 try:
     from diffusers.image_processor import VaeImageProcessor
     from transformers import Qwen2VLProcessor
@@ -58,11 +62,10 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
         self.VAE_IMAGE_SIZE = 1024 * 1024
         self.cpu_offload = config.get("cpu_offload", False)
-        self.run_device = self.config.get("run_device", "cuda")
         if self.cpu_offload:
             self.device = torch.device("cpu")
         else:
-            self.device = torch.device(self.config.get("run_device", "cuda"))
+            self.device = torch.device(AI_DEVICE)
         self.dtype = torch.bfloat16
         self.load()
@@ -180,9 +183,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
         if self.cpu_offload:
             self.text_encoder.to(torch.device("cpu"))
-            if hasattr(torch, self.config.get("run_device", "cuda")):
-                torch_module = getattr(torch, self.config.get("run_device", "cuda"))
-                torch_module.empty_cache()
+            torch_device_module.empty_cache()
             gc.collect()

         return prompt_embeds, prompt_embeds_mask, image_info
@@ -9,6 +9,8 @@ import torch.nn.functional as F
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange

+from lightx2v_platform.base.global_var import AI_DEVICE
+

 def linear_interpolation(features, output_len: int):
     features = features.transpose(1, 2)
@@ -252,7 +254,6 @@ class AudioAdapter(nn.Module):
         quantized: bool = False,
         quant_scheme: str = None,
         cpu_offload: bool = False,
-        run_device=torch.device("cuda"),
     ):
         super().__init__()
         self.cpu_offload = cpu_offload
@@ -263,7 +264,6 @@ class AudioAdapter(nn.Module):
             mlp_dims=mlp_dims,
             transformer_layers=projection_transformer_layers,
         )
-        self.run_device = run_device
         # self.num_tokens = num_tokens * 4
         self.num_tokens_x4 = num_tokens * 4
         self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
@@ -302,10 +302,10 @@ class AudioAdapter(nn.Module):
     @torch.no_grad()
     def forward_audio_proj(self, audio_feat, latent_frame):
         if self.cpu_offload:
-            self.audio_proj.to(self.run_device)
+            self.audio_proj.to(AI_DEVICE)
         x = self.audio_proj(audio_feat, latent_frame)
         x = self.rearange_audio_features(x)
-        x = x + self.audio_pe.to(self.run_device)
+        x = x + self.audio_pe.to(AI_DEVICE)
         if self.cpu_offload:
             self.audio_proj.to("cpu")
         return x
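The cpu_offload round-trip that recurs across the encoders and the audio adapter now targets `AI_DEVICE` uniformly; condensed into a generic helper (name and signature are illustrative, not part of the commit):

import torch

from lightx2v_platform.base.global_var import AI_DEVICE


@torch.no_grad()
def run_offloaded(module: torch.nn.Module, x: torch.Tensor, cpu_offload: bool):
    # Generic form of the pattern above: keep weights on the host and move
    # them to the active accelerator only for the duration of the call.
    if cpu_offload:
        module.to(AI_DEVICE)
    out = module(x.to(AI_DEVICE))
    if cpu_offload:
        module.to("cpu")
    return out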