xuwx1 / LightX2V · Commits

Commit b50498fa (unverified): Add lightx2v_platform (#541)
Authored Dec 02, 2025 by Yang Yong (雍洋); committed via GitHub on Dec 02, 2025.
Parent: 31da6925

Changes: 75 files in total; this page shows 20 changed files with 84 additions and 264 deletions (+84, -264).
Changed files shown on this page:

configs/seko_talk/mlu/seko_talk_bf16.json (+3, -4)
configs/seko_talk/seko_talk_25_int8_dist_fp8_comm.json (+0, -1)
lightx2v/__init__.py (+1, -0)
lightx2v/common/ops/attn/__init__.py (+1, -1)
lightx2v/common/ops/attn/flash_attn.py (+0, -40)
lightx2v/common/ops/attn/sage_attn.py (+0, -27)
lightx2v/common/ops/attn/ulysses_attn.py (+2, -17)
lightx2v/common/ops/conv/conv3d.py (+6, -8)
lightx2v/common/ops/embedding/embedding_weight.py (+2, -4)
lightx2v/common/ops/mm/mm_weight.py (+25, -75)
lightx2v/common/ops/norm/layer_norm_weight.py (+6, -8)
lightx2v/common/ops/norm/rms_norm_weight.py (+2, -4)
lightx2v/common/ops/tensor/tensor.py (+2, -4)
lightx2v/infer.py (+4, -14)
lightx2v/models/input_encoders/hf/hunyuan15/byt5/model.py (+6, -6)
lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py (+8, -12)
lightx2v/models/input_encoders/hf/hunyuan15/siglip/model.py (+6, -9)
lightx2v/models/input_encoders/hf/q_linear.py (+0, -21)
lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py (+6, -5)
lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py (+4, -4)
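The common thread across these files: per-device branching (explicit "cuda"/"mlu"/"npu" checks, `run_device` config keys, MLU-specific kernel classes) moves out of the core tree and behind the new `lightx2v_platform` package, which exposes a detected `AI_DEVICE` string and a platform-device registry. The package itself is not among the files shown on this page, so the following is a hypothetical sketch only; its device detection plausibly mirrors the probing logic deleted from ulysses_attn.py below, and the function name is an assumption:

```python
# Hypothetical sketch (lightx2v_platform internals are not part of this diff).
# Detection mirrors the hasattr/is_available probing removed from ulysses_attn.py.
import torch

def detect_ai_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "mlu") and torch.mlu.is_available():
        return "mlu"
    if hasattr(torch, "npu") and torch.npu.is_available():
        return "npu"
    return "cpu"

AI_DEVICE = detect_ai_device()
```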
configs/seko_talk/mlu/seko_talk_bf16.json

```diff
@@ -5,15 +5,14 @@
     "audio_sr": 16000,
     "target_video_length": 81,
     "resize_mode": "adaptive",
-    "self_attn_1_type": "flash_attn2",
-    "cross_attn_1_type": "flash_attn2",
-    "cross_attn_2_type": "flash_attn2",
+    "self_attn_1_type": "mlu_sage_attn",
+    "cross_attn_1_type": "mlu_sage_attn",
+    "cross_attn_2_type": "mlu_sage_attn",
     "sample_guide_scale": 1.0,
     "sample_shift": 5,
     "enable_cfg": false,
     "cpu_offload": false,
     "use_31_block": false,
-    "run_device": "mlu",
     "rope_type": "torch",
     "modulate_type": "torch"
 }
```
configs/seko_talk/seko_talk_25_int8_dist_fp8_comm.json

```diff
@@ -4,7 +4,6 @@
     "video_duration": 5,
     "audio_sr": 16000,
     "target_video_length": 81,
     "resize_mode": "adaptive",
     "self_attn_1_type": "sage_attn3",
     "cross_attn_1_type": "sage_attn3",
     "cross_attn_2_type": "sage_attn3",
```
lightx2v/__init__.py

```diff
@@ -2,6 +2,7 @@ __version__ = "0.1.0"
 __author__ = "LightX2V Contributors"
 __license__ = "Apache 2.0"
 
+import lightx2v_platform.set_ai_device
 from lightx2v import common, deploy, models, utils
 from lightx2v.pipeline import LightX2VPipeline
```
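The new import is executed purely for its side effect: the `lightx2v_platform.set_ai_device` module body runs once, before the `common`/`models` imports below it, so the device is already selected when the ops modules are loaded. A minimal usage sketch, assuming `AI_DEVICE` is the value that import populates:

```python
# Import-for-side-effect: the module body runs once on first import,
# before anything imported after it.
import lightx2v_platform.set_ai_device  # noqa: F401

from lightx2v_platform.base.global_var import AI_DEVICE

print(AI_DEVICE)  # e.g. "cuda" on NVIDIA, "mlu" on Cambricon, "npu" on Ascend
```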
lightx2v/common/ops/attn/__init__.py

```diff
-from .flash_attn import FlashAttn2Weight, FlashAttn3Weight, MluFlashAttnWeight
+from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
 from .nbhd_attn import NbhdAttnWeight, NbhdAttnWeightFlashInfer
 from .radial_attn import RadialAttnWeight
 from .ring_attn import RingAttnWeight
```
lightx2v/common/ops/attn/flash_attn.py

```diff
-import math
-
 from loguru import logger
 
 try:
@@ -15,12 +13,6 @@ except ImportError:
     logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
     flash_attn_varlen_func_v3 = None
 
-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    logger.info("torch_mlu_ops not found.")
-    tmo = None
-
 from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
 
 from .template import AttnWeightTemplate
@@ -94,35 +86,3 @@ class FlashAttn3Weight(AttnWeightTemplate):
             max_seqlen_kv,
         ).reshape(bs * max_seqlen_q, -1)
         return x
-
-
-@ATTN_WEIGHT_REGISTER("mlu_flash_attn")
-class MluFlashAttnWeight(AttnWeightTemplate):
-    def __init__(self):
-        self.config = {}
-
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, **kws):
-        if len(q.shape) == 3:
-            bs = 1
-            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
-        elif len(q.shape) == 4:
-            bs = q.shape[0]
-        softmax_scale = 1 / math.sqrt(q.shape[-1])
-        x = tmo.flash_attention(
-            q=q,
-            k=k,
-            v=v,
-            cu_seq_lens_q=cu_seqlens_q,
-            cu_seq_lens_kv=cu_seqlens_kv,
-            max_seq_len_q=max_seqlen_q,
-            max_seq_len_kv=max_seqlen_kv,
-            softmax_scale=softmax_scale,
-            return_lse=False,
-            out_dtype=q.dtype,
-            is_causal=False,
-            out=None,
-            alibi_slope=None,
-            attn_bias=None,
-        )
-        x = x.reshape(bs * max_seqlen_q, -1)
-        return x
```
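`MluFlashAttnWeight` disappears from the core tree, yet the MLU config above still selects attention backends by registry key, so the class is presumably re-registered from inside `lightx2v_platform` (that side is not shown in this diff). The decorator/lookup contract that makes such a relocation transparent can be sketched as follows; the `Register` internals are an assumption inferred from how `ATTN_WEIGHT_REGISTER` and `PLATFORM_DEVICE_REGISTER` are used in these files:

```python
# Sketch of the registry contract assumed by this refactor; the real
# implementation lives in lightx2v.utils.registry_factory and may differ.
class Register(dict):
    def __call__(self, name):
        def _wrap(cls):
            self[name] = cls  # config strings resolve to classes via dict lookup
            return cls
        return _wrap

ATTN_WEIGHT_REGISTER = Register()

@ATTN_WEIGHT_REGISTER("mlu_flash_attn")
class MluFlashAttnWeight:  # expected to be re-registered by the platform package
    pass

assert ATTN_WEIGHT_REGISTER.get("mlu_flash_attn") is MluFlashAttnWeight
```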
lightx2v/common/ops/attn/sage_attn.py

```diff
-import math
-
 import torch
 from loguru import logger
 
@@ -26,12 +24,6 @@ except ImportError:
     logger.info("sageattn3 not found, please install sageattention first")
     sageattn3_blackwell = None
 
-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    tmo = None
-    logger.info("torch_mlu_ops not found.")
-
 
 @ATTN_WEIGHT_REGISTER("sage_attn2")
 class SageAttn2Weight(AttnWeightTemplate):
@@ -89,22 +81,3 @@ class SageAttn3Weight(AttnWeightTemplate):
         x = sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2).reshape(bs * max_seqlen_q, -1)
         return x
-
-
-@ATTN_WEIGHT_REGISTER("mlu_sage_attn")
-class MluSageAttnWeight(AttnWeightTemplate):
-    def __init__(self):
-        self.config = {}
-
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, **kws):
-        if len(q.shape) == 3:
-            bs = 1
-            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
-        elif len(q.shape) == 4:
-            bs = q.shape[0]
-        softmax_scale = 1 / math.sqrt(q.shape[-1])
-        x = tmo.sage_attn(
-            q=q, k=k, v=v,
-            cu_seq_lens_q=None, cu_seq_lens_kv=None,
-            max_seq_len_kv=max_seqlen_kv, max_seq_len_q=max_seqlen_q,
-            is_causal=False, compute_dtype=torch.bfloat16, softmax_scale=softmax_scale,
-        )
-        x = x.reshape(bs * max_seqlen_q, -1)
-        return x
```
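Both deleted MLU wrappers implement the same shape contract as the surviving backends: a rank-3 `[tokens, heads, head_dim]` input is promoted to a batch of one, and the output is flattened back to `[bs * max_seqlen_q, -1]`. A device-independent sketch of just that contract:

```python
# Pure-torch sketch of the qkv shape normalization shared by the attention
# wrappers in this file; no MLU kernels involved.
import torch

def normalize_qkv(q, k, v):
    """Accept [s, h, d] or [b, s, h, d]; return batched tensors and batch size."""
    if q.dim() == 3:
        return q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), 1
    return q, k, v, q.shape[0]
```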
lightx2v/common/ops/attn/ulysses_attn.py

```diff
@@ -3,6 +3,7 @@ import torch.distributed as dist
 from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm
 from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
 
 from .template import AttnWeightTemplate
 from .utils.all2all import all2all_head2seq, all2all_seq2head
@@ -75,7 +76,6 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         img_q = all2all_seq2head(img_q, group=seq_p_group)
         img_k = all2all_seq2head(img_k, group=seq_p_group)
         img_v = all2all_seq2head(img_v, group=seq_p_group)
-        self.device_synchronize()  # make sure device-side ops have finished
         # for the text query/key/value, select the heads owned by the current rank
         txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
@@ -88,7 +88,7 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         v = torch.cat((img_v, txt_v), dim=0)
 
         # initialize the cumulative sequence-length tensor
-        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=self.config.get("run_device", "cuda"))
+        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=AI_DEVICE)
         s = txt_qkv_len + img_q.shape[0]  # total length of text plus image tokens
         s1 = s  # end position of the current sample
         cu_seqlens_qkv[1] = s1  # set the cumulative sequence length
@@ -133,23 +133,8 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         img_attn = all2all_head2seq(img_attn, group=seq_p_group)
         img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
-        self.device_synchronize()  # make sure device-side ops have finished
         return img_attn
-
-    def device_synchronize(self):
-        if torch.cuda.is_available():
-            # no need to sync between comm and comp
-            # torch.cuda.synchronize()
-            self.config["run_device"] = "cuda"
-        elif hasattr(torch, "mlu") and torch.mlu.is_available():
-            torch.mlu.synchronize()
-            self.config["run_device"] = "mlu"
-        elif hasattr(torch, "npu") and torch.npu.is_available():
-            torch.npu.synchronize()
-            self.config["run_device"] = "npu"
 
 
 @ATTN_WEIGHT_REGISTER("ulysses-4090")
 class Ulysses4090AttnWeight(AttnWeightTemplate):
```
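With the synchronization helper gone, `AI_DEVICE` is consumed as an ordinary torch device string, so a single allocation path serves every backend. A quick check, assuming only that `lightx2v_platform` is importable:

```python
import torch
from lightx2v_platform.base.global_var import AI_DEVICE

# One factory call covers cuda/mlu/npu/cpu alike.
cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=AI_DEVICE)
assert cu_seqlens_qkv.device.type == torch.device(AI_DEVICE).type
```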
lightx2v/common/ops/conv/conv3d.py

```diff
@@ -35,13 +35,7 @@ class Conv3dWeight(Conv3dWeightTemplate):
     def load(self, weight_dict):
         device = weight_dict[self.weight_name].device
-        if device.type in ["cuda", "mlu", "npu"]:
-            self.weight = weight_dict[self.weight_name]
-            if self.bias_name is not None:
-                self.bias = weight_dict[self.bias_name]
-            else:
-                self.bias = None
-        elif device.type == "cpu":
+        if device.type == "cpu":
             weight_shape = weight_dict[self.weight_name].shape
             weight_dtype = weight_dict[self.weight_name].dtype
             self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -57,7 +51,11 @@ class Conv3dWeight(Conv3dWeightTemplate):
                 self.pin_bias = None
             del weight_dict[self.weight_name]
         else:
-            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+            self.weight = weight_dict[self.weight_name]
+            if self.bias_name is not None:
+                self.bias = weight_dict[self.bias_name]
+            else:
+                self.bias = None
 
     def apply(self, input_tensor):
         input_tensor = torch.nn.functional.conv3d(
```
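The control flow is inverted here and in the weight classes that follow: instead of whitelisting `["cuda", "mlu", "npu"]` and raising for anything else, only `cpu` is special-cased (staged into pinned host memory for later non-blocking copies) and every other device type is trusted as the active accelerator, so adding a backend no longer requires touching these files. The resulting pattern, reduced to a minimal self-contained sketch with simplified names:

```python
import torch

class WeightSlot:
    """Minimal sketch of the new load pattern; names are simplified."""

    def __init__(self, name):
        self.name = name

    def load(self, weight_dict):
        t = weight_dict[self.name]
        if t.device.type == "cpu":
            # stage in pinned host memory for later non-blocking H2D copies
            self.pin_weight = torch.empty(t.shape, pin_memory=True, dtype=t.dtype)
            self.pin_weight.copy_(t)
            del weight_dict[self.name]
        else:
            # any non-CPU device (cuda/mlu/npu/...) is used in place
            self.weight = t
```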
lightx2v/common/ops/embedding/embedding_weight.py

```diff
@@ -22,16 +22,14 @@ class EmbeddingWeightTemplate(metaclass=ABCMeta):
             self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                 self.pin_weight.copy_(weight_dict[self.weight_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
 
     def to_cuda(self, non_blocking=False):
         self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
```
lightx2v/common/ops/mm/mm_weight.py

```diff
@@ -67,11 +67,6 @@ try:
 except ImportError:
     marlin_cuda_quant = None
 
-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    tmo = None
-
 
 class MMWeightTemplate(metaclass=ABCMeta):
     def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
@@ -128,14 +123,7 @@ class MMWeight(MMWeightTemplate):
             self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name].t()
-                if self.bias_name is not None:
-                    self.bias = weight_dict[self.bias_name]
-                else:
-                    self.bias = None
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
@@ -153,7 +141,11 @@ class MMWeight(MMWeightTemplate):
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name].t()
+                if self.bias_name is not None:
+                    self.bias = weight_dict[self.bias_name]
+                else:
+                    self.bias = None
 
     def _calculate_size(self):
         if self.bias is not None:
@@ -273,10 +265,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight_scale_cuda_buffer = weight_dict[self.weight_scale_name].float().cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-                self.weight_scale = weight_dict[self.weight_scale_name].float()
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -288,7 +277,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
+                self.weight_scale = weight_dict[self.weight_scale_name].float()
 
         if self.bias_name is not None:
             if self.create_cuda_buffer:
@@ -296,15 +286,13 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
             else:
                 device = weight_dict[self.bias_name].device
-                if device.type in ["cuda", "mlu", "npu"]:
-                    self.bias = weight_dict[self.bias_name]
-                elif device.type == "cpu":
+                if device.type == "cpu":
                     bias_shape = weight_dict[self.bias_name].shape
                     bias_dtype = weight_dict[self.bias_name].dtype
                     self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                     self.pin_bias.copy_(weight_dict[self.bias_name])
                 else:
-                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                    self.bias = weight_dict[self.bias_name]
         else:
             self.bias = None
             self.pin_bias = None
@@ -337,10 +325,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-                self.weight_scale = weight_dict[self.weight_scale_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -352,7 +337,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
+                self.weight_scale = weight_dict[self.weight_scale_name]
 
     def load_mxfp6(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
@@ -362,10 +348,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-                self.weight_scale = weight_dict[self.weight_scale_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -377,7 +360,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
+                self.weight_scale = weight_dict[self.weight_scale_name]
 
     def load_mxfp8(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
@@ -387,10 +371,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-                self.weight_scale = weight_dict[self.weight_scale_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -402,7 +383,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
+                self.weight_scale = weight_dict[self.weight_scale_name]
 
     def load_nvfp4(self, weight_dict):
         device = weight_dict[self.weight_name].device
@@ -412,12 +394,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         weight_global_scale = weight_dict[f"{self.weight_name}_global_scale"]
         alpha = 1.0 / (input_global_scale * weight_global_scale)
-        if device.type in ["cuda", "mlu", "npu"]:
-            self.weight = weight_dict[self.weight_name]
-            self.weight_scale = weight_dict[self.weight_scale_name]
-            self.input_global_scale = input_global_scale
-            self.alpha = alpha
-        elif device.type == "cpu":
+        if device.type == "cpu":
             weight_shape = weight_dict[self.weight_name].shape
             weight_dtype = weight_dict[self.weight_name].dtype
             self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -440,7 +417,10 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             del weight_dict[self.weight_name]
         else:
-            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+            self.weight = weight_dict[self.weight_name]
+            self.weight_scale = weight_dict[self.weight_scale_name]
+            self.input_global_scale = input_global_scale
+            self.alpha = alpha
 
         if self.bias_name is not None:
             if self.create_cuda_buffer:
@@ -1178,33 +1158,3 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
         if hasattr(self, "bias") and self.bias is not None:
             output_tensor.add_(self.bias)
         return output_tensor
-
-
-@MM_WEIGHT_REGISTER("int8-tmo")
-class MMWeightWint8channelAint8channeldynamicMlu(MMWeightQuantTemplate):
-    """
-    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Mlu
-
-    Quant MM:
-        Weight: int8 perchannel sym
-        Act: int8 perchannel dynamic sym
-        Kernel: mlu
-    """
-
-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
-        self.load_func = self.load_int8_perchannel_sym
-        self.weight_need_transpose = False
-        self.act_quant_func = self.act_quant_int8_perchannel_sym_tmo
-
-    def act_quant_int8_perchannel_sym_tmo(self, x):
-        input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
-        return input_tensor_quant, input_tensor_scale
-
-    def apply(self, input_tensor):
-        dtype = input_tensor.dtype
-        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        output_tensor = tmo.scaled_matmul(
-            input_tensor_quant,
-            self.weight.contiguous(),
-            input_tensor_scale,
-            self.weight_scale.squeeze(-1),
-            bias=self.bias if self.bias is not None else None,
-            output_dtype=dtype,
-            use_hp_active=True,
-        )
-        return output_tensor
```
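The deleted `int8-tmo` kernel paired dynamic per-channel symmetric activation quantization with `tmo.scaled_matmul`. For reference, the same arithmetic in plain torch; this is a sketch of the math only, not the `torch_mlu_ops` API, and far slower than the fused kernel:

```python
import torch

def int8_perchannel_matmul(x, w_int8, w_scale, bias=None):
    """x: [n, in] float; w_int8: [out, in] int8; w_scale: [out] float."""
    # dynamic per-channel symmetric quantization of the activations
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp((x / x_scale).round(), -128, 127)
    # int8 x int8 matmul, emulated in float for portability (exact at int8 ranges)
    acc = x_q.float() @ w_int8.float().t()
    out = acc * x_scale * w_scale.view(1, -1)
    if bias is not None:
        out = out + bias
    return out.to(x.dtype)
```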
lightx2v/common/ops/norm/layer_norm_weight.py

```diff
@@ -32,13 +32,7 @@ class LNWeightTemplate(metaclass=ABCMeta):
         else:
             if self.weight_name is not None:
                 device = weight_dict[self.weight_name].device
-                if device.type in ["cuda", "mlu", "npu"]:
-                    self.weight = weight_dict[self.weight_name]
-                    if self.bias_name is not None:
-                        self.bias = weight_dict[self.bias_name]
-                    else:
-                        self.bias = None
-                elif device.type == "cpu":
+                if device.type == "cpu":
                     weight_shape = weight_dict[self.weight_name].shape
                     weight_dtype = weight_dict[self.weight_name].dtype
                     self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -54,7 +48,11 @@ class LNWeightTemplate(metaclass=ABCMeta):
                         self.pin_bias = None
                     del weight_dict[self.weight_name]
                 else:
-                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                    self.weight = weight_dict[self.weight_name]
+                    if self.bias_name is not None:
+                        self.bias = weight_dict[self.bias_name]
+                    else:
+                        self.bias = None
             else:
                 self.weight = None
                 self.bias = None
```
lightx2v/common/ops/norm/rms_norm_weight.py

```diff
@@ -30,16 +30,14 @@ class RMSWeightTemplate(metaclass=ABCMeta):
             self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                 self.pin_weight.copy_(weight_dict[self.weight_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
 
     def clear(self):
         attrs = ["weight", "pinned_weight"]
```
lightx2v/common/ops/tensor/tensor.py

```diff
@@ -29,16 +29,14 @@ class DefaultTensor:
            self.tensor_cuda_buffer = weight_dict[self.tensor_name].cuda()
         else:
             device = weight_dict[self.tensor_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.tensor = weight_dict[self.tensor_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 tensor_shape = weight_dict[self.tensor_name].shape
                 tensor_dtype = weight_dict[self.tensor_name].dtype
                 self.pin_tensor = torch.empty(tensor_shape, pin_memory=True, dtype=tensor_dtype)
                 self.pin_tensor.copy_(weight_dict[self.tensor_name])
                 del weight_dict[self.tensor_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.tensor = weight_dict[self.tensor_name]
 
     def clear(self):
         attrs = ["tensor", "pinned_tensor"]
```
lightx2v/infer.py

```diff
@@ -4,11 +4,6 @@ import torch
 import torch.distributed as dist
 from loguru import logger
 
-try:
-    from torch.distributed import ProcessGroupNCCL
-except ImportError:
-    ProcessGroupNCCL = None
-
 from lightx2v.common.ops import *
 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner  # noqa: F401
 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner  # noqa: F401
@@ -26,6 +21,8 @@ from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
 from lightx2v.utils.utils import seed_all
+from lightx2v_platform.base.global_var import AI_DEVICE
+from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER
 
 
 def init_runner(config):
@@ -105,15 +102,8 @@ def main():
     config = set_config(args)
     if config["parallel"]:
-        run_device = config.get("run_device", "cuda")
-        if "cuda" in run_device:
-            pg_options = ProcessGroupNCCL.Options()
-            pg_options.is_high_priority_stream = True
-            dist.init_process_group(backend="nccl", pg_options=pg_options)
-            torch.cuda.set_device(dist.get_rank())
-        elif "mlu" in run_device:
-            dist.init_process_group(backend="cncl")
-            torch.mlu.set_device(dist.get_rank())
+        platform_device = PLATFORM_DEVICE_REGISTER.get(AI_DEVICE, None)
+        platform_device.init_parallel_env()
         set_parallel_config(config)
     print_config(config)
```
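The NCCL/CNCL branching in `main()` collapses into a per-platform hook. The diff only shows the call side; a hypothetical backend satisfying it might look like the sketch below, assuming `PLATFORM_DEVICE_REGISTER` offers the same decorator interface as the other registries and that `.get()` returns an object with a callable `init_parallel_env`:

```python
# Hypothetical platform backend (not in this diff); the real ones presumably
# live in lightx2v_platform.
import torch
import torch.distributed as dist

from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER

@PLATFORM_DEVICE_REGISTER("cuda")
class CudaPlatform:
    @staticmethod
    def init_parallel_env():
        # mirrors the removed inline code, including the high-priority stream
        pg_options = torch.distributed.ProcessGroupNCCL.Options()
        pg_options.is_high_priority_stream = True
        dist.init_process_group(backend="nccl", pg_options=pg_options)
        torch.cuda.set_device(dist.get_rank())
```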
lightx2v/models/input_encoders/hf/hunyuan15/byt5/model.py

```diff
@@ -8,6 +8,8 @@ import torch.nn as nn
 from safetensors import safe_open
 from transformers import AutoTokenizer, T5ForConditionalGeneration
 
+from lightx2v_platform.base.global_var import AI_DEVICE
+
 from .format_prompt import MultilingualPromptFormat
@@ -159,14 +161,12 @@ class ByT5TextEncoder:
         self,
         config,
         device=torch.device("cpu"),
-        run_device=torch.device("cuda"),
         checkpoint_path=None,
         byt5_max_length=256,
         cpu_offload=False,
     ):
         self.cpu_offload = cpu_offload
         self.config = config
-        self.run_device = run_device
         self.byt5_max_length = byt5_max_length
         self.enable_cfg = config.get("enable_cfg", False)
         byT5_google_path = os.path.join(checkpoint_path, "text_encoder", "byt5-small")
@@ -301,12 +301,12 @@ class ByT5TextEncoder:
         negative_masks = []
 
         for prompt in prompt_list:
-            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, self.run_device)
+            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, AI_DEVICE)
             positive_embeddings.append(pos_emb)
             positive_masks.append(pos_mask)
 
             if self.enable_cfg:
                 # TODO: split cfg out; better suited to parallel execution
-                neg_emb, neg_mask = self._process_single_byt5_prompt("", self.run_device)
+                neg_emb, neg_mask = self._process_single_byt5_prompt("", AI_DEVICE)
                 negative_embeddings.append(neg_emb)
                 negative_masks.append(neg_mask)
@@ -328,8 +328,8 @@ class ByT5TextEncoder:
     @torch.no_grad()
     def infer(self, prompts):
         if self.cpu_offload:
-            self.byt5_model = self.byt5_model.to(self.run_device)
-            self.byt5_mapper = self.byt5_mapper.to(self.run_device)
+            self.byt5_model = self.byt5_model.to(AI_DEVICE)
+            self.byt5_mapper = self.byt5_mapper.to(AI_DEVICE)
         byt5_embeddings, byt5_masks = self._prepare_byt5_embeddings(prompts)
         byt5_features = self.byt5_mapper(byt5_embeddings.to(torch.bfloat16))
         if self.cpu_offload:
```
lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py

```diff
@@ -32,6 +32,9 @@ from lightx2v.models.input_encoders.hf.q_linear import (  # noqa E402
     TorchaoQuantLinearInt8,  # noqa E402
     VllmQuantLinearInt8,  # noqa E402
 )
+from lightx2v_platform.base.global_var import AI_DEVICE  # noqa E402
+
+torch_device_module = getattr(torch, AI_DEVICE)
 
 
 def use_default(value, default):
@@ -145,12 +148,7 @@ def load_text_encoder(
         new_w_dict[key.replace("model.", "")] = weight_dict[key]
     del weight_dict
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-    elif "mlu" in str(device):
-        torch.mlu.empty_cache()
-    elif "npu" in str(device):
-        torch.npu.empty_cache()
+    torch_device_module.empty_cache()
     gc.collect()
     text_encoder.load_state_dict(new_w_dict, assign=True)
@@ -552,7 +550,6 @@ class Qwen25VL_TextEncoder:
         text_len=1000,
         dtype=torch.float16,
         device=torch.device("cpu"),
-        run_device=torch.device("cuda"),
         checkpoint_path=None,
         cpu_offload=False,
         qwen25vl_quantized=False,
@@ -561,7 +558,6 @@ class Qwen25VL_TextEncoder:
     ):
         self.text_len = text_len
         self.dtype = dtype
-        self.run_device = run_device
         self.cpu_offload = cpu_offload
         self.qwen25vl_quantized = qwen25vl_quantized
         self.qwen25vl_quant_scheme = qwen25vl_quant_scheme
@@ -590,20 +586,20 @@ class Qwen25VL_TextEncoder:
     def infer(self, texts):
         if self.cpu_offload:
-            self.text_encoder = self.text_encoder.to(self.run_device)
+            self.text_encoder = self.text_encoder.to(AI_DEVICE)
         text_inputs = self.text_encoder.text2tokens(texts, data_type="video", max_length=self.text_len)
-        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=self.run_device)
+        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=AI_DEVICE)
         if self.cpu_offload:
             self.text_encoder = self.text_encoder.to("cpu")
         prompt_embeds = prompt_outputs.hidden_state
         attention_mask = prompt_outputs.attention_mask
         if attention_mask is not None:
-            attention_mask = attention_mask.to(self.run_device)
+            attention_mask = attention_mask.to(AI_DEVICE)
             _, seq_len = attention_mask.shape
             attention_mask = attention_mask.repeat(1, self.num_videos_per_prompt)
             attention_mask = attention_mask.view(self.num_videos_per_prompt, seq_len)
-        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.run_device)
+        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=AI_DEVICE)
         seq_len = prompt_embeds.shape[1]
         # duplicate text embeddings for each generation per prompt, using mps friendly method
```
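`getattr(torch, AI_DEVICE)` resolves the backend's torch module (`torch.cuda`, `torch.mlu`, or `torch.npu`) once at import time, replacing the three-way `empty_cache` chain; the same one-liner appears again in qwen25_vlforconditionalgeneration.py below. Usage sketch, assuming the matching backend extension is installed:

```python
import torch
from lightx2v_platform.base.global_var import AI_DEVICE

# torch.cuda / torch.mlu / torch.npu, chosen once at import time
torch_device_module = getattr(torch, AI_DEVICE)
torch_device_module.empty_cache()
```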
lightx2v/models/input_encoders/hf/hunyuan15/siglip/model.py

```diff
@@ -10,6 +10,8 @@ from safetensors.torch import safe_open
 from transformers import SiglipImageProcessor, SiglipVisionModel
 from transformers.utils import ModelOutput
 
+from lightx2v_platform.base.global_var import AI_DEVICE
+
 PRECISION_TO_TYPE = {
     "fp32": torch.float32,
     "fp16": torch.float16,
@@ -95,7 +97,6 @@ class VisionEncoder(nn.Module):
         output_key: Optional[str] = None,
         logger=None,
         device=None,
-        run_device=None,
         cpu_offload=False,
     ):
         super().__init__()
@@ -121,7 +122,6 @@ class VisionEncoder(nn.Module):
         )
         self.dtype = self.model.dtype
         self.device = self.model.device
-        self.run_device = run_device
         self.processor, self.processor_path = load_image_processor(
             processor_type=self.processor_type,
@@ -172,12 +172,12 @@ class VisionEncoder(nn.Module):
             VisionEncoderModelOutput with encoded features
         """
         if self.cpu_offload:
-            self.model = self.model.to("cuda")
-            self.processor = self.processor.to("cuda")
+            self.model = self.model.to(AI_DEVICE)
+            self.processor = self.processor.to(AI_DEVICE)
 
         if isinstance(images, np.ndarray):
             # Preprocess images if they're numpy arrays
-            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=self.run_device, dtype=self.model.dtype)
+            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=AI_DEVICE, dtype=self.model.dtype)
         else:
             # Assume already preprocessed
             preprocessed = images
@@ -232,13 +232,11 @@ class SiglipVisionEncoder:
         self,
         config,
         device=torch.device("cpu"),
-        run_device=torch.device("cuda"),
         checkpoint_path=None,
         cpu_offload=False,
     ):
         self.config = config
         self.device = device
-        self.run_device = run_device
         self.cpu_offload = cpu_offload
         self.vision_states_dim = 1152
         vision_encoder_path = os.path.join(checkpoint_path, "vision_encoder", "siglip")
@@ -252,7 +250,6 @@ class SiglipVisionEncoder:
             output_key=None,
             logger=None,
             device=self.device,
-            run_device=self.run_device,
             cpu_offload=self.cpu_offload,
         )
@@ -270,7 +267,7 @@ class SiglipVisionEncoder:
     @torch.no_grad()
     def infer(self, vision_states):
         if self.cpu_offload:
-            self.vision_in = self.vision_in.to(self.run_device)
+            self.vision_in = self.vision_in.to(AI_DEVICE)
         vision_states = self.vision_in(vision_states)
         if self.cpu_offload:
             self.vision_in = self.vision_in.to("cpu")
```
lightx2v/models/input_encoders/hf/q_linear.py

```diff
@@ -26,11 +26,6 @@ try:
 except ImportError:
     fp8_linear = None
 
-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    tmo = None
-
 
 class VllmQuantLinearInt8(nn.Module):
     def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
@@ -315,19 +310,3 @@ class Q8FQuantLinearFp8(nn.Module):
         self.weight_scale = maybe_cast(self.weight_scale)
         self.bias = maybe_cast(self.bias)
         return self
-
-
-class MluQuantLinearInt8(VllmQuantLinearInt8):
-    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
-        super().__init__(in_features, out_features, bias, dtype)
-
-    def act_quant_func(self, x):
-        input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
-        return input_tensor_quant, input_tensor_scale
-
-    def forward(self, input_tensor):
-        input_tensor = input_tensor.squeeze(0)
-        dtype = input_tensor.dtype
-        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        output_tensor = tmo.scaled_matmul(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale.squeeze(-1), output_dtype=dtype)
-        return output_tensor.unsqueeze(0)
```
lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py

```diff
@@ -5,6 +5,10 @@ import os
 import torch
 from transformers import Qwen2Tokenizer, Qwen2_5_VLForConditionalGeneration
 
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)
+
 try:
     from diffusers.image_processor import VaeImageProcessor
     from transformers import Qwen2VLProcessor
@@ -58,11 +62,10 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
         self.VAE_IMAGE_SIZE = 1024 * 1024
         self.cpu_offload = config.get("cpu_offload", False)
-        self.run_device = self.config.get("run_device", "cuda")
         if self.cpu_offload:
             self.device = torch.device("cpu")
         else:
-            self.device = torch.device(self.config.get("run_device", "cuda"))
+            self.device = torch.device(AI_DEVICE)
         self.dtype = torch.bfloat16
         self.load()
@@ -180,9 +183,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
         if self.cpu_offload:
             self.text_encoder.to(torch.device("cpu"))
-            if hasattr(torch, self.config.get("run_device", "cuda")):
-                torch_module = getattr(torch, self.config.get("run_device", "cuda"))
-                torch_module.empty_cache()
+            torch_device_module.empty_cache()
             gc.collect()
         return prompt_embeds, prompt_embeds_mask, image_info
```
lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py

```diff
@@ -9,6 +9,8 @@ import torch.nn.functional as F
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
 
+from lightx2v_platform.base.global_var import AI_DEVICE
+
 
 def linear_interpolation(features, output_len: int):
     features = features.transpose(1, 2)
@@ -252,7 +254,6 @@ class AudioAdapter(nn.Module):
         quantized: bool = False,
         quant_scheme: str = None,
         cpu_offload: bool = False,
-        run_device=torch.device("cuda"),
     ):
         super().__init__()
         self.cpu_offload = cpu_offload
@@ -263,7 +264,6 @@ class AudioAdapter(nn.Module):
             mlp_dims=mlp_dims,
             transformer_layers=projection_transformer_layers,
         )
-        self.run_device = run_device
         # self.num_tokens = num_tokens * 4
         self.num_tokens_x4 = num_tokens * 4
         self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
@@ -302,10 +302,10 @@ class AudioAdapter(nn.Module):
     @torch.no_grad()
    def forward_audio_proj(self, audio_feat, latent_frame):
         if self.cpu_offload:
-            self.audio_proj.to(self.run_device)
+            self.audio_proj.to(AI_DEVICE)
         x = self.audio_proj(audio_feat, latent_frame)
         x = self.rearange_audio_features(x)
-        x = x + self.audio_pe.to(self.run_device)
+        x = x + self.audio_pe.to(AI_DEVICE)
         if self.cpu_offload:
             self.audio_proj.to("cpu")
         return x
```
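All of the encoder changes in this commit follow the same `cpu_offload` round trip, now targeted at `AI_DEVICE` instead of a configured `run_device`. A generic sketch of that pattern; the helper name is ours, not the repo's:

```python
import torch
from lightx2v_platform.base.global_var import AI_DEVICE

def run_offloaded(module: torch.nn.Module, *args):
    """Sketch of the cpu_offload round trip used throughout these encoders."""
    module.to(AI_DEVICE)      # move in just for this forward pass
    try:
        return module(*args)
    finally:
        module.to("cpu")      # release accelerator memory right away
```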