Unverified Commit b50498fa authored by Yang Yong (雍洋), committed by GitHub

Add lightx2v_platform (#541)

parent 31da6925
......@@ -5,15 +5,14 @@
"audio_sr": 16000,
"target_video_length": 81,
"resize_mode": "adaptive",
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
"self_attn_1_type": "mlu_sage_attn",
"cross_attn_1_type": "mlu_sage_attn",
"cross_attn_2_type": "mlu_sage_attn",
"sample_guide_scale": 1.0,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false,
"use_31_block": false,
"run_device": "mlu",
"rope_type": "torch",
"modulate_type": "torch"
}
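
The string-valued attention keys above ("self_attn_1_type", "cross_attn_1_type", …) name backends that are looked up in the ATTN_WEIGHT_REGISTER appearing later in this diff. Below is a minimal sketch, assuming a plain dict registry; the decorator and placeholder class are illustrative only, not the real LightX2V implementations.

```python
# Illustrative sketch (assumption): resolving a config key such as
# "self_attn_1_type" against a name->class registry. The registry and the
# placeholder backend below are stand-ins, not the repository's classes.
ATTN_REGISTRY = {}

def register_attn(name):
    def wrap(cls):
        ATTN_REGISTRY[name] = cls
        return cls
    return wrap

@register_attn("mlu_sage_attn")
class PlaceholderSageAttn:
    def apply(self, q, k, v, **kwargs):
        raise NotImplementedError("placeholder backend for illustration only")

config = {"self_attn_1_type": "mlu_sage_attn"}
attn = ATTN_REGISTRY[config["self_attn_1_type"]]()
```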
......@@ -4,7 +4,6 @@
"video_duration": 5,
"audio_sr": 16000,
"target_video_length": 81,
"resize_mode": "adaptive",
"self_attn_1_type": "sage_attn3",
"cross_attn_1_type": "sage_attn3",
"cross_attn_2_type": "sage_attn3",
......
......@@ -2,6 +2,7 @@ __version__ = "0.1.0"
__author__ = "LightX2V Contributors"
__license__ = "Apache 2.0"
import lightx2v_platform.set_ai_device
from lightx2v import common, deploy, models, utils
from lightx2v.pipeline import LightX2VPipeline
......
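
The new `import lightx2v_platform.set_ai_device` line runs before the rest of the package is imported, which suggests the active device is selected as an import side effect. A hypothetical sketch of what such a module might do, assuming it probes torch backends in a fixed order; the real detection logic is not shown in this diff.

```python
# Hypothetical sketch of a set_ai_device-style module: probe the available
# torch backends once at import time and record the winner in a global.
# The probing order and the AI_DEVICE name here are assumptions.
import torch

def _detect_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "mlu") and torch.mlu.is_available():
        return "mlu"
    if hasattr(torch, "npu") and torch.npu.is_available():
        return "npu"
    return "cpu"

AI_DEVICE = _detect_device()
```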
from .flash_attn import FlashAttn2Weight, FlashAttn3Weight, MluFlashAttnWeight
from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
from .nbhd_attn import NbhdAttnWeight, NbhdAttnWeightFlashInfer
from .radial_attn import RadialAttnWeight
from .ring_attn import RingAttnWeight
......
import math
from loguru import logger
try:
......@@ -15,12 +13,6 @@ except ImportError:
logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
flash_attn_varlen_func_v3 = None
try:
import torch_mlu_ops as tmo
except ImportError:
logger.info("torch_mlu_ops not found.")
tmo = None
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
......@@ -94,35 +86,3 @@ class FlashAttn3Weight(AttnWeightTemplate):
max_seqlen_kv,
).reshape(bs * max_seqlen_q, -1)
return x
@ATTN_WEIGHT_REGISTER("mlu_flash_attn")
class MluFlashAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, **kws):
if len(q.shape) == 3:
bs = 1
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
elif len(q.shape) == 4:
bs = q.shape[0]
softmax_scale = 1 / math.sqrt(q.shape[-1])
x = tmo.flash_attention(
q=q,
k=k,
v=v,
cu_seq_lens_q=cu_seqlens_q,
cu_seq_lens_kv=cu_seqlens_kv,
max_seq_len_q=max_seqlen_q,
max_seq_len_kv=max_seqlen_kv,
softmax_scale=softmax_scale,
return_lse=False,
out_dtype=q.dtype,
is_causal=False,
out=None,
alibi_slope=None,
attn_bias=None,
)
x = x.reshape(bs * max_seqlen_q, -1)
return x
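
All of the varlen attention wrappers in this file take `cu_seqlens_q`/`cu_seqlens_kv` together with `max_seqlen_q`/`max_seqlen_kv`. A small sketch of the cumulative-sequence-length convention these arguments follow; the helper name is illustrative and not part of the repository.

```python
# Sketch of the cu_seqlens convention used by the varlen attention calls:
# a prefix sum over per-sample sequence lengths, starting at 0, stored as
# int32 on the compute device. The helper name is illustrative.
import torch

def build_cu_seqlens(seq_lens, device="cpu"):
    cu = torch.zeros(len(seq_lens) + 1, dtype=torch.int32, device=device)
    cu[1:] = torch.cumsum(torch.tensor(seq_lens, dtype=torch.int32), dim=0)
    return cu

# e.g. two samples of length 3 and 5 -> tensor([0, 3, 8], dtype=torch.int32)
print(build_cu_seqlens([3, 5]))
```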
import math
import torch
from loguru import logger
......@@ -26,12 +24,6 @@ except ImportError:
logger.info("sageattn3 not found, please install sageattention first")
sageattn3_blackwell = None
try:
import torch_mlu_ops as tmo
except ImportError:
tmo = None
logger.info("torch_mlu_ops not found.")
@ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate):
......@@ -89,22 +81,3 @@ class SageAttn3Weight(AttnWeightTemplate):
x = sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2).reshape(bs * max_seqlen_q, -1)
return x
@ATTN_WEIGHT_REGISTER("mlu_sage_attn")
class MluSageAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, **kws):
if len(q.shape) == 3:
bs = 1
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
elif len(q.shape) == 4:
bs = q.shape[0]
softmax_scale = 1 / math.sqrt(q.shape[-1])
x = tmo.sage_attn(
q=q, k=k, v=v, cu_seq_lens_q=None, cu_seq_lens_kv=None, max_seq_len_kv=max_seqlen_kv, max_seq_len_q=max_seqlen_q, is_causal=False, compute_dtype=torch.bfloat16, softmax_scale=softmax_scale
)
x = x.reshape(bs * max_seqlen_q, -1)
return x
......@@ -3,6 +3,7 @@ import torch.distributed as dist
from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
from .template import AttnWeightTemplate
from .utils.all2all import all2all_head2seq, all2all_seq2head
......@@ -75,7 +76,6 @@ class UlyssesAttnWeight(AttnWeightTemplate):
img_q = all2all_seq2head(img_q, group=seq_p_group)
img_k = all2all_seq2head(img_k, group=seq_p_group)
img_v = all2all_seq2head(img_v, group=seq_p_group)
self.device_synchronize() # ensure CUDA operations are complete
# Process the text queries, keys and values, selecting the heads owned by the current rank
txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
......@@ -88,7 +88,7 @@ class UlyssesAttnWeight(AttnWeightTemplate):
v = torch.cat((img_v, txt_v), dim=0)
# Initialize the cumulative sequence-length tensor
cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=self.config.get("run_device", "cuda"))
cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=AI_DEVICE)
s = txt_qkv_len + img_q.shape[0] # total length of the text and image tokens
s1 = s # end position of the current sample
cu_seqlens_qkv[1] = s1 # set the cumulative sequence length
......@@ -133,23 +133,8 @@ class UlyssesAttnWeight(AttnWeightTemplate):
img_attn = all2all_head2seq(img_attn, group=seq_p_group)
img_attn = img_attn.reshape(shard_seqlen, -1) # reshape to [shard_seqlen, -1]
self.device_synchronize() # ensure CUDA operations are complete
return img_attn
    def device_synchronize(self):
        if torch.cuda.is_available():
            # no need to sync between comm and comp
            # torch.cuda.synchronize()
            self.config["run_device"] = "cuda"
        elif hasattr(torch, "mlu") and torch.mlu.is_available():
            torch.mlu.synchronize()
            self.config["run_device"] = "mlu"
        elif hasattr(torch, "npu") and torch.npu.is_available():
            torch.npu.synchronize()
            self.config["run_device"] = "npu"
@ATTN_WEIGHT_REGISTER("ulysses-4090")
class Ulysses4090AttnWeight(AttnWeightTemplate):
......
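
The Ulysses path above switches between sequence-parallel and head-parallel layouts with all2all, and slices the text q/k/v down to the heads owned by the current rank. A toy sketch of that head-shard slicing, with the world size and tensor shapes chosen purely for illustration.

```python
# Illustrative sketch of the head-shard slicing used above: each rank keeps
# heads [cur_rank * shard_heads, (cur_rank + 1) * shard_heads). World size and
# shapes below are assumptions for demonstration only.
import torch

seq_len, num_heads, head_dim = 16, 8, 64
world_size, cur_rank = 4, 1
shard_heads = num_heads // world_size

txt_q = torch.randn(seq_len, num_heads, head_dim)
local_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
print(local_q.shape)  # torch.Size([16, 2, 64])
```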
......@@ -35,13 +35,7 @@ class Conv3dWeight(Conv3dWeightTemplate):
def load(self, weight_dict):
device = weight_dict[self.weight_name].device
if device.type in ["cuda", "mlu", "npu"]:
self.weight = weight_dict[self.weight_name]
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
else:
self.bias = None
elif device.type == "cpu":
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
......@@ -57,7 +51,11 @@ class Conv3dWeight(Conv3dWeightTemplate):
self.pin_bias = None
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
self.weight = weight_dict[self.weight_name]
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
else:
self.bias = None
def apply(self, input_tensor):
input_tensor = torch.nn.functional.conv3d(
......
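
The simplified `load()` above keeps non-CPU tensors in place and stages CPU tensors in pinned host memory so a later `to_cuda(non_blocking=True)` can overlap the copy. A minimal sketch of that staging pattern; the tensor names are illustrative.

```python
# Minimal sketch of the pinned-memory staging pattern used by the load() paths:
# copy a host tensor into page-locked memory once, then move it to the device
# asynchronously when needed.
import torch

cpu_weight = torch.randn(1024, 1024)
if torch.cuda.is_available():
    # Pinned (page-locked) staging buffer; requires a working accelerator runtime.
    pin_weight = torch.empty(cpu_weight.shape, dtype=cpu_weight.dtype, pin_memory=True)
    pin_weight.copy_(cpu_weight)
    # non_blocking=True can overlap the host-to-device copy because the source is pinned.
    gpu_weight = pin_weight.cuda(non_blocking=True)
```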
......@@ -22,16 +22,14 @@ class EmbeddingWeightTemplate(metaclass=ABCMeta):
self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
else:
device = weight_dict[self.weight_name].device
if device.type in ["cuda", "mlu", "npu"]:
self.weight = weight_dict[self.weight_name]
elif device.type == "cpu":
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
self.weight = weight_dict[self.weight_name]
def to_cuda(self, non_blocking=False):
self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
......
......@@ -67,11 +67,6 @@ try:
except ImportError:
marlin_cuda_quant = None
try:
import torch_mlu_ops as tmo
except ImportError:
tmo = None
class MMWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
......@@ -128,14 +123,7 @@ class MMWeight(MMWeightTemplate):
self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
else:
device = weight_dict[self.weight_name].device
if device.type in ["cuda", "mlu", "npu"]:
self.weight = weight_dict[self.weight_name].t()
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
else:
self.bias = None
elif device.type == "cpu":
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
......@@ -153,7 +141,11 @@ class MMWeight(MMWeightTemplate):
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
self.weight = weight_dict[self.weight_name].t()
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
else:
self.bias = None
def _calculate_size(self):
if self.bias is not None:
......@@ -273,10 +265,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.weight_scale_cuda_buffer = weight_dict[self.weight_scale_name].float().cuda()
else:
device = weight_dict[self.weight_name].device
if device.type in ["cuda", "mlu", "npu"]:
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name].float()
elif device.type == "cpu":
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
......@@ -288,7 +277,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name].float()
if self.bias_name is not None:
if self.create_cuda_buffer:
......@@ -296,15 +286,13 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
else:
device = weight_dict[self.bias_name].device
if device.type in ["cuda", "mlu", "npu"]:
self.bias = weight_dict[self.bias_name]
elif device.type == "cpu":
if device.type == "cpu":
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
self.pin_bias.copy_(weight_dict[self.bias_name])
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
self.bias = weight_dict[self.bias_name]
else:
self.bias = None
self.pin_bias = None
......@@ -337,10 +325,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
else:
device = weight_dict[self.weight_name].device
if device.type in ["cuda", "mlu", "npu"]:
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
elif device.type == "cpu":
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
......@@ -352,7 +337,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
def load_mxfp6(self, weight_dict):
if self.config.get("weight_auto_quant", False):
......@@ -362,10 +348,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
else:
device = weight_dict[self.weight_name].device
if device.type in ["cuda", "mlu", "npu"]:
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
elif device.type == "cpu":
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
......@@ -377,7 +360,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
def load_mxfp8(self, weight_dict):
if self.config.get("weight_auto_quant", False):
......@@ -387,10 +371,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
else:
device = weight_dict[self.weight_name].device
if device.type in ["cuda", "mlu", "npu"]:
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
elif device.type == "cpu":
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
......@@ -402,7 +383,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
def load_nvfp4(self, weight_dict):
device = weight_dict[self.weight_name].device
......@@ -412,12 +394,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
weight_global_scale = weight_dict[f"{self.weight_name}_global_scale"]
alpha = 1.0 / (input_global_scale * weight_global_scale)
if device.type in ["cuda", "mlu", "npu"]:
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
self.input_global_scale = input_global_scale
self.alpha = alpha
elif device.type == "cpu":
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
......@@ -440,7 +417,10 @@ class MMWeightQuantTemplate(MMWeightTemplate):
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
self.input_global_scale = input_global_scale
self.alpha = alpha
if self.bias_name is not None:
if self.create_cuda_buffer:
......@@ -1178,33 +1158,3 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
if hasattr(self, "bias") and self.bias is not None:
output_tensor.add_(self.bias)
return output_tensor
@MM_WEIGHT_REGISTER("int8-tmo")
class MMWeightWint8channelAint8channeldynamicMlu(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Mlu
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: mlu
"""
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.load_func = self.load_int8_perchannel_sym
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_int8_perchannel_sym_tmo
def act_quant_int8_perchannel_sym_tmo(self, x):
input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
return input_tensor_quant, input_tensor_scale
def apply(self, input_tensor):
dtype = input_tensor.dtype
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = tmo.scaled_matmul(
input_tensor_quant, self.weight.contiguous(), input_tensor_scale, self.weight_scale.squeeze(-1), bias=self.bias if self.bias is not None else None, output_dtype=dtype, use_hp_active=True
)
return output_tensor
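
The removed `int8-tmo` weight class delegates quantization to `tmo.scaled_quantize`/`tmo.scaled_matmul`. As a backend-agnostic reference, here is a sketch of the symmetric per-channel dynamic int8 activation quantization its docstring describes; this only approximates the numerics and is not the MLU kernel.

```python
# Backend-agnostic sketch of symmetric per-channel dynamic int8 activation
# quantization, matching the scheme named in the class docstring above
# (weight int8 per-channel sym, activation int8 per-channel dynamic sym).
import torch

def quant_int8_perchannel_sym(x: torch.Tensor):
    # one scale per row, chosen so the largest magnitude maps to 127
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale

x = torch.randn(4, 16, dtype=torch.float32)
xq, xs = quant_int8_perchannel_sym(x)
x_dequant = xq.float() * xs  # approximate reconstruction of x
```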
......@@ -32,13 +32,7 @@ class LNWeightTemplate(metaclass=ABCMeta):
else:
if self.weight_name is not None:
device = weight_dict[self.weight_name].device
if device.type in ["cuda", "mlu", "npu"]:
self.weight = weight_dict[self.weight_name]
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
else:
self.bias = None
elif device.type == "cpu":
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
......@@ -54,7 +48,11 @@ class LNWeightTemplate(metaclass=ABCMeta):
self.pin_bias = None
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
self.weight = weight_dict[self.weight_name]
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
else:
self.bias = None
else:
self.weight = None
self.bias = None
......
......@@ -30,16 +30,14 @@ class RMSWeightTemplate(metaclass=ABCMeta):
self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
else:
device = weight_dict[self.weight_name].device
if device.type in ["cuda", "mlu", "npu"]:
self.weight = weight_dict[self.weight_name]
elif device.type == "cpu":
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
self.weight = weight_dict[self.weight_name]
def clear(self):
attrs = ["weight", "pinned_weight"]
......
......@@ -29,16 +29,14 @@ class DefaultTensor:
self.tensor_cuda_buffer = weight_dict[self.tensor_name].cuda()
else:
device = weight_dict[self.tensor_name].device
if device.type in ["cuda", "mlu", "npu"]:
self.tensor = weight_dict[self.tensor_name]
elif device.type == "cpu":
if device.type == "cpu":
tensor_shape = weight_dict[self.tensor_name].shape
tensor_dtype = weight_dict[self.tensor_name].dtype
self.pin_tensor = torch.empty(tensor_shape, pin_memory=True, dtype=tensor_dtype)
self.pin_tensor.copy_(weight_dict[self.tensor_name])
del weight_dict[self.tensor_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
self.tensor = weight_dict[self.tensor_name]
def clear(self):
attrs = ["tensor", "pinned_tensor"]
......
......@@ -4,11 +4,6 @@ import torch
import torch.distributed as dist
from loguru import logger
try:
from torch.distributed import ProcessGroupNCCL
except ImportError:
ProcessGroupNCCL = None
from lightx2v.common.ops import *
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner # noqa: F401
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401
......@@ -26,6 +21,8 @@ from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
from lightx2v.utils.utils import seed_all
from lightx2v_platform.base.global_var import AI_DEVICE
from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER
def init_runner(config):
......@@ -105,15 +102,8 @@ def main():
config = set_config(args)
if config["parallel"]:
run_device = config.get("run_device", "cuda")
if "cuda" in run_device:
pg_options = ProcessGroupNCCL.Options()
pg_options.is_high_priority_stream = True
dist.init_process_group(backend="nccl", pg_options=pg_options)
torch.cuda.set_device(dist.get_rank())
elif "mlu" in run_device:
dist.init_process_group(backend="cncl")
torch.mlu.set_device(dist.get_rank())
platform_device = PLATFORM_DEVICE_REGISTER.get(AI_DEVICE, None)
platform_device.init_parallel_env()
set_parallel_config(config)
print_config(config)
......
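
The per-backend process-group setup in `main()` is replaced by a lookup in `PLATFORM_DEVICE_REGISTER` keyed by `AI_DEVICE`, followed by `init_parallel_env()`. A hypothetical sketch of what such a registry could look like, reconstructed from the removed branches; the class names, registration API, and backend choices below are assumptions, not the actual lightx2v_platform code.

```python
# Hypothetical sketch of a platform-device registry whose entries know how to
# initialize the distributed environment for their backend. Names and backend
# strings are assumptions based on the removed cuda/mlu branches.
import torch
import torch.distributed as dist

PLATFORM_DEVICE_REGISTER = {}

def register_platform(name):
    def wrap(cls):
        PLATFORM_DEVICE_REGISTER[name] = cls
        return cls
    return wrap

@register_platform("cuda")
class CudaPlatform:
    @staticmethod
    def init_parallel_env():
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank())

@register_platform("mlu")
class MluPlatform:
    @staticmethod
    def init_parallel_env():
        dist.init_process_group(backend="cncl")
        torch.mlu.set_device(dist.get_rank())
```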
......@@ -8,6 +8,8 @@ import torch.nn as nn
from safetensors import safe_open
from transformers import AutoTokenizer, T5ForConditionalGeneration
from lightx2v_platform.base.global_var import AI_DEVICE
from .format_prompt import MultilingualPromptFormat
......@@ -159,14 +161,12 @@ class ByT5TextEncoder:
self,
config,
device=torch.device("cpu"),
run_device=torch.device("cuda"),
checkpoint_path=None,
byt5_max_length=256,
cpu_offload=False,
):
self.cpu_offload = cpu_offload
self.config = config
self.run_device = run_device
self.byt5_max_length = byt5_max_length
self.enable_cfg = config.get("enable_cfg", False)
byT5_google_path = os.path.join(checkpoint_path, "text_encoder", "byt5-small")
......@@ -301,12 +301,12 @@ class ByT5TextEncoder:
negative_masks = []
for prompt in prompt_list:
pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, self.run_device)
pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, AI_DEVICE)
positive_embeddings.append(pos_emb)
positive_masks.append(pos_mask)
if self.enable_cfg: # TODO: split CFG out; better suited to parallel execution
neg_emb, neg_mask = self._process_single_byt5_prompt("", self.run_device)
neg_emb, neg_mask = self._process_single_byt5_prompt("", AI_DEVICE)
negative_embeddings.append(neg_emb)
negative_masks.append(neg_mask)
......@@ -328,8 +328,8 @@ class ByT5TextEncoder:
@torch.no_grad()
def infer(self, prompts):
if self.cpu_offload:
self.byt5_model = self.byt5_model.to(self.run_device)
self.byt5_mapper = self.byt5_mapper.to(self.run_device)
self.byt5_model = self.byt5_model.to(AI_DEVICE)
self.byt5_mapper = self.byt5_mapper.to(AI_DEVICE)
byt5_embeddings, byt5_masks = self._prepare_byt5_embeddings(prompts)
byt5_features = self.byt5_mapper(byt5_embeddings.to(torch.bfloat16))
if self.cpu_offload:
......
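
The ByT5 encoder (and the other encoders changed below) follow the same cpu_offload pattern: keep the module on the host, move it to `AI_DEVICE` only around inference, then move it back. A minimal, self-contained sketch of that pattern; the local `AI_DEVICE` fallback here is illustrative.

```python
# Sketch of the cpu_offload pattern applied throughout this diff: move the
# module to the active device around inference, then return it to the host.
# AI_DEVICE below is a stand-in, not the real lightx2v_platform global.
import torch
import torch.nn as nn

AI_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # illustrative fallback

model = nn.Linear(8, 8)
x = torch.randn(2, 8)

cpu_offload = True
if cpu_offload:
    model = model.to(AI_DEVICE)
with torch.no_grad():
    y = model(x.to(AI_DEVICE))
if cpu_offload:
    model = model.to("cpu")
```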
......@@ -32,6 +32,9 @@ from lightx2v.models.input_encoders.hf.q_linear import ( # noqa E402
TorchaoQuantLinearInt8, # noqa E402
VllmQuantLinearInt8, # noqa E402
)
from lightx2v_platform.base.global_var import AI_DEVICE # noqa E402
torch_device_module = getattr(torch, AI_DEVICE)
def use_default(value, default):
......@@ -145,12 +148,7 @@ def load_text_encoder(
new_w_dict[key.replace("model.", "")] = weight_dict[key]
del weight_dict
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif "mlu" in str(device):
torch.mlu.empty_cache()
elif "npu" in str(device):
torch.npu.empty_cache()
torch_device_module.empty_cache()
gc.collect()
text_encoder.load_state_dict(new_w_dict, assign=True)
......@@ -552,7 +550,6 @@ class Qwen25VL_TextEncoder:
text_len=1000,
dtype=torch.float16,
device=torch.device("cpu"),
run_device=torch.device("cuda"),
checkpoint_path=None,
cpu_offload=False,
qwen25vl_quantized=False,
......@@ -561,7 +558,6 @@ class Qwen25VL_TextEncoder:
):
self.text_len = text_len
self.dtype = dtype
self.run_device = run_device
self.cpu_offload = cpu_offload
self.qwen25vl_quantized = qwen25vl_quantized
self.qwen25vl_quant_scheme = qwen25vl_quant_scheme
......@@ -590,20 +586,20 @@ class Qwen25VL_TextEncoder:
def infer(self, texts):
if self.cpu_offload:
self.text_encoder = self.text_encoder.to(self.run_device)
self.text_encoder = self.text_encoder.to(AI_DEVICE)
text_inputs = self.text_encoder.text2tokens(texts, data_type="video", max_length=self.text_len)
prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=self.run_device)
prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=AI_DEVICE)
if self.cpu_offload:
self.text_encoder = self.text_encoder.to("cpu")
prompt_embeds = prompt_outputs.hidden_state
attention_mask = prompt_outputs.attention_mask
if attention_mask is not None:
attention_mask = attention_mask.to(self.run_device)
attention_mask = attention_mask.to(AI_DEVICE)
_, seq_len = attention_mask.shape
attention_mask = attention_mask.repeat(1, self.num_videos_per_prompt)
attention_mask = attention_mask.view(self.num_videos_per_prompt, seq_len)
prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.run_device)
prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=AI_DEVICE)
seq_len = prompt_embeds.shape[1]
# duplicate text embeddings for each generation per prompt, using mps friendly method
......
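
`torch_device_module = getattr(torch, AI_DEVICE)` replaces the separate cuda/mlu/npu branches for cache clearing. A short sketch of that getattr-based dispatch, with defensive guards added for illustration; the hard-coded device string is an assumption.

```python
# Sketch of the getattr-based cache clearing used above: resolve the device
# module (torch.cuda, torch.mlu, torch.npu) by name and call empty_cache()
# if present. The hasattr guards are defensive additions, not part of the diff.
import gc
import torch

AI_DEVICE = "cuda"  # illustrative; the real value comes from lightx2v_platform

torch_device_module = getattr(torch, AI_DEVICE, None)
if torch_device_module is not None and hasattr(torch_device_module, "empty_cache"):
    torch_device_module.empty_cache()
gc.collect()
```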
......@@ -10,6 +10,8 @@ from safetensors.torch import safe_open
from transformers import SiglipImageProcessor, SiglipVisionModel
from transformers.utils import ModelOutput
from lightx2v_platform.base.global_var import AI_DEVICE
PRECISION_TO_TYPE = {
"fp32": torch.float32,
"fp16": torch.float16,
......@@ -95,7 +97,6 @@ class VisionEncoder(nn.Module):
output_key: Optional[str] = None,
logger=None,
device=None,
run_device=None,
cpu_offload=False,
):
super().__init__()
......@@ -121,7 +122,6 @@ class VisionEncoder(nn.Module):
)
self.dtype = self.model.dtype
self.device = self.model.device
self.run_device = run_device
self.processor, self.processor_path = load_image_processor(
processor_type=self.processor_type,
......@@ -172,12 +172,12 @@ class VisionEncoder(nn.Module):
VisionEncoderModelOutput with encoded features
"""
if self.cpu_offload:
self.model = self.model.to("cuda")
self.processor = self.processor.to("cuda")
self.model = self.model.to(AI_DEVICE)
self.processor = self.processor.to(AI_DEVICE)
if isinstance(images, np.ndarray):
# Preprocess images if they're numpy arrays
preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=self.run_device, dtype=self.model.dtype)
preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=AI_DEVICE, dtype=self.model.dtype)
else:
# Assume already preprocessed
preprocessed = images
......@@ -232,13 +232,11 @@ class SiglipVisionEncoder:
self,
config,
device=torch.device("cpu"),
run_device=torch.device("cuda"),
checkpoint_path=None,
cpu_offload=False,
):
self.config = config
self.device = device
self.run_device = run_device
self.cpu_offload = cpu_offload
self.vision_states_dim = 1152
vision_encoder_path = os.path.join(checkpoint_path, "vision_encoder", "siglip")
......@@ -252,7 +250,6 @@ class SiglipVisionEncoder:
output_key=None,
logger=None,
device=self.device,
run_device=self.run_device,
cpu_offload=self.cpu_offload,
)
......@@ -270,7 +267,7 @@ class SiglipVisionEncoder:
@torch.no_grad()
def infer(self, vision_states):
if self.cpu_offload:
self.vision_in = self.vision_in.to(self.run_device)
self.vision_in = self.vision_in.to(AI_DEVICE)
vision_states = self.vision_in(vision_states)
if self.cpu_offload:
self.vision_in = self.vision_in.to("cpu")
......
......@@ -26,11 +26,6 @@ try:
except ImportError:
fp8_linear = None
try:
import torch_mlu_ops as tmo
except ImportError:
tmo = None
class VllmQuantLinearInt8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
......@@ -315,19 +310,3 @@ class Q8FQuantLinearFp8(nn.Module):
self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias)
return self
class MluQuantLinearInt8(VllmQuantLinearInt8):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
        super().__init__(in_features, out_features, bias, dtype)

    def act_quant_func(self, x):
        input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
        return input_tensor_quant, input_tensor_scale

    def forward(self, input_tensor):
        input_tensor = input_tensor.squeeze(0)
        dtype = input_tensor.dtype
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = tmo.scaled_matmul(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale.squeeze(-1), output_dtype=dtype)
        return output_tensor.unsqueeze(0)
......@@ -5,6 +5,10 @@ import os
import torch
from transformers import Qwen2Tokenizer, Qwen2_5_VLForConditionalGeneration
from lightx2v_platform.base.global_var import AI_DEVICE
torch_device_module = getattr(torch, AI_DEVICE)
try:
from diffusers.image_processor import VaeImageProcessor
from transformers import Qwen2VLProcessor
......@@ -58,11 +62,10 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
self.VAE_IMAGE_SIZE = 1024 * 1024
self.cpu_offload = config.get("cpu_offload", False)
self.run_device = self.config.get("run_device", "cuda")
if self.cpu_offload:
self.device = torch.device("cpu")
else:
self.device = torch.device(self.config.get("run_device", "cuda"))
self.device = torch.device(AI_DEVICE)
self.dtype = torch.bfloat16
self.load()
......@@ -180,9 +183,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
if self.cpu_offload:
self.text_encoder.to(torch.device("cpu"))
if hasattr(torch, self.config.get("run_device", "cuda")):
torch_module = getattr(torch, self.config.get("run_device", "cuda"))
torch_module.empty_cache()
torch_device_module.empty_cache()
gc.collect()
return prompt_embeds, prompt_embeds_mask, image_info
......@@ -9,6 +9,8 @@ import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from lightx2v_platform.base.global_var import AI_DEVICE
def linear_interpolation(features, output_len: int):
features = features.transpose(1, 2)
......@@ -252,7 +254,6 @@ class AudioAdapter(nn.Module):
quantized: bool = False,
quant_scheme: str = None,
cpu_offload: bool = False,
run_device=torch.device("cuda"),
):
super().__init__()
self.cpu_offload = cpu_offload
......@@ -263,7 +264,6 @@ class AudioAdapter(nn.Module):
mlp_dims=mlp_dims,
transformer_layers=projection_transformer_layers,
)
self.run_device = run_device
# self.num_tokens = num_tokens * 4
self.num_tokens_x4 = num_tokens * 4
self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
......@@ -302,10 +302,10 @@ class AudioAdapter(nn.Module):
@torch.no_grad()
def forward_audio_proj(self, audio_feat, latent_frame):
if self.cpu_offload:
self.audio_proj.to(self.run_device)
self.audio_proj.to(AI_DEVICE)
x = self.audio_proj(audio_feat, latent_frame)
x = self.rearange_audio_features(x)
x = x + self.audio_pe.to(self.run_device)
x = x + self.audio_pe.to(AI_DEVICE)
if self.cpu_offload:
self.audio_proj.to("cpu")
return x