xuwx1 / LightX2V / Commits / b50498fa

Commit b50498fa (unverified)
Add lightx2v_platform (#541)
Authored Dec 02, 2025 by Yang Yong (雍洋); committed via GitHub on Dec 02, 2025
Parent: 31da6925
Changes: 75 files in the full commit

Showing 20 changed files with 84 additions and 264 deletions (+84 -264)
configs/seko_talk/mlu/seko_talk_bf16.json (+3 -4)
configs/seko_talk/seko_talk_25_int8_dist_fp8_comm.json (+0 -1)
lightx2v/__init__.py (+1 -0)
lightx2v/common/ops/attn/__init__.py (+1 -1)
lightx2v/common/ops/attn/flash_attn.py (+0 -40)
lightx2v/common/ops/attn/sage_attn.py (+0 -27)
lightx2v/common/ops/attn/ulysses_attn.py (+2 -17)
lightx2v/common/ops/conv/conv3d.py (+6 -8)
lightx2v/common/ops/embedding/embedding_weight.py (+2 -4)
lightx2v/common/ops/mm/mm_weight.py (+25 -75)
lightx2v/common/ops/norm/layer_norm_weight.py (+6 -8)
lightx2v/common/ops/norm/rms_norm_weight.py (+2 -4)
lightx2v/common/ops/tensor/tensor.py (+2 -4)
lightx2v/infer.py (+4 -14)
lightx2v/models/input_encoders/hf/hunyuan15/byt5/model.py (+6 -6)
lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py (+8 -12)
lightx2v/models/input_encoders/hf/hunyuan15/siglip/model.py (+6 -9)
lightx2v/models/input_encoders/hf/q_linear.py (+0 -21)
lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py (+6 -5)
lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py (+4 -4)
configs/seko_talk/mlu/seko_talk_bf16.json

@@ -5,15 +5,14 @@
     "audio_sr": 16000,
     "target_video_length": 81,
     "resize_mode": "adaptive",
-    "self_attn_1_type": "flash_attn2",
-    "cross_attn_1_type": "flash_attn2",
-    "cross_attn_2_type": "flash_attn2",
+    "self_attn_1_type": "mlu_sage_attn",
+    "cross_attn_1_type": "mlu_sage_attn",
+    "cross_attn_2_type": "mlu_sage_attn",
     "sample_guide_scale": 1.0,
     "sample_shift": 5,
     "enable_cfg": false,
     "cpu_offload": false,
     "use_31_block": false,
-    "run_device": "mlu",
     "rope_type": "torch",
     "modulate_type": "torch"
 }
configs/seko_talk/seko_talk_25_int8_dist_fp8_comm.json

@@ -4,7 +4,6 @@
     "video_duration": 5,
     "audio_sr": 16000,
     "target_video_length": 81,
-    "resize_mode": "adaptive",
     "self_attn_1_type": "sage_attn3",
     "cross_attn_1_type": "sage_attn3",
     "cross_attn_2_type": "sage_attn3",
lightx2v/__init__.py

@@ -2,6 +2,7 @@ __version__ = "0.1.0"
 __author__ = "LightX2V Contributors"
 __license__ = "Apache 2.0"

+import lightx2v_platform.set_ai_device
 from lightx2v import common, deploy, models, utils
 from lightx2v.pipeline import LightX2VPipeline
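The bare `import lightx2v_platform.set_ai_device` at package-import time suggests the new platform package picks the accelerator backend once, before any other lightx2v module asks for it, and records the result in the `AI_DEVICE` global that the rest of this commit reads. The lightx2v_platform package itself is not among the 20 files shown on this page, so the following is only a minimal sketch of what such a module might do; `AI_DEVICE`, `set_ai_device` and the `lightx2v_platform.base.global_var` path are taken from this diff, everything else is an assumption.

# Hypothetical sketch of lightx2v_platform/set_ai_device.py (not part of this page's diff)
import torch

from lightx2v_platform.base import global_var


def _detect_backend() -> str:
    # Probe the torch backends in a fixed order and return the first one that
    # reports availability, falling back to plain CPU.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "mlu") and torch.mlu.is_available():
        return "mlu"
    if hasattr(torch, "npu") and torch.npu.is_available():
        return "npu"
    return "cpu"


# Importing this module fixes AI_DEVICE for the whole process, which is why
# lightx2v/__init__.py imports it before anything else.
global_var.AI_DEVICE = _detect_backend()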
lightx2v/common/ops/attn/__init__.py

-from .flash_attn import FlashAttn2Weight, FlashAttn3Weight, MluFlashAttnWeight
+from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
 from .nbhd_attn import NbhdAttnWeight, NbhdAttnWeightFlashInfer
 from .radial_attn import RadialAttnWeight
 from .ring_attn import RingAttnWeight
lightx2v/common/ops/attn/flash_attn.py

-import math
-
 from loguru import logger

 try:

@@ -15,12 +13,6 @@ except ImportError:
     logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
     flash_attn_varlen_func_v3 = None

-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    logger.info("torch_mlu_ops not found.")
-    tmo = None
-
 from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

 from .template import AttnWeightTemplate

@@ -94,35 +86,3 @@ class FlashAttn3Weight(AttnWeightTemplate):
             max_seqlen_kv,
         ).reshape(bs * max_seqlen_q, -1)
         return x
-
-
-@ATTN_WEIGHT_REGISTER("mlu_flash_attn")
-class MluFlashAttnWeight(AttnWeightTemplate):
-    def __init__(self):
-        self.config = {}
-
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, **kws):
-        if len(q.shape) == 3:
-            bs = 1
-            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
-        elif len(q.shape) == 4:
-            bs = q.shape[0]
-        softmax_scale = 1 / math.sqrt(q.shape[-1])
-        x = tmo.flash_attention(
-            q=q,
-            k=k,
-            v=v,
-            cu_seq_lens_q=cu_seqlens_q,
-            cu_seq_lens_kv=cu_seqlens_kv,
-            max_seq_len_q=max_seqlen_q,
-            max_seq_len_kv=max_seqlen_kv,
-            softmax_scale=softmax_scale,
-            return_lse=False,
-            out_dtype=q.dtype,
-            is_causal=False,
-            out=None,
-            alibi_slope=None,
-            attn_bias=None,
-        )
-        x = x.reshape(bs * max_seqlen_q, -1)
-        return x
lightx2v/common/ops/attn/sage_attn.py

-import math
-
 import torch
 from loguru import logger

@@ -26,12 +24,6 @@ except ImportError:
     logger.info("sageattn3 not found, please install sageattention first")
     sageattn3_blackwell = None

-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    tmo = None
-    logger.info("torch_mlu_ops not found.")
-

 @ATTN_WEIGHT_REGISTER("sage_attn2")
 class SageAttn2Weight(AttnWeightTemplate):

@@ -89,22 +81,3 @@ class SageAttn3Weight(AttnWeightTemplate):
         x = sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2).reshape(bs * max_seqlen_q, -1)
         return x
-
-
-@ATTN_WEIGHT_REGISTER("mlu_sage_attn")
-class MluSageAttnWeight(AttnWeightTemplate):
-    def __init__(self):
-        self.config = {}
-
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, **kws):
-        if len(q.shape) == 3:
-            bs = 1
-            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
-        elif len(q.shape) == 4:
-            bs = q.shape[0]
-        softmax_scale = 1 / math.sqrt(q.shape[-1])
-        x = tmo.sage_attn(
-            q=q, k=k, v=v,
-            cu_seq_lens_q=None, cu_seq_lens_kv=None,
-            max_seq_len_kv=max_seqlen_kv, max_seq_len_q=max_seqlen_q,
-            is_causal=False, compute_dtype=torch.bfloat16, softmax_scale=softmax_scale,
-        )
-        x = x.reshape(bs * max_seqlen_q, -1)
-        return x
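Note that the MLU kernels deleted here and in flash_attn.py (MluFlashAttnWeight, MluSageAttnWeight) are exactly the attention types the updated MLU config now requests ("mlu_sage_attn" in configs/seko_talk/mlu/seko_talk_bf16.json), so they presumably re-register from the new lightx2v_platform package, which is outside the 20 files shown on this page. Because registration goes through the ATTN_WEIGHT_REGISTER decorator rather than through this package's __init__.py, an out-of-tree backend only needs to import the registry and the template. A hedged skeleton of such a registration; the module location and everything not named in this diff are assumptions:

# Hypothetical platform-side module; the apply() body would be the MLU code deleted above.
from lightx2v.common.ops.attn.template import AttnWeightTemplate
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER


@ATTN_WEIGHT_REGISTER("mlu_sage_attn")
class MluSageAttnWeight(AttnWeightTemplate):
    """Registers the torch_mlu_ops SageAttention kernel under the name the MLU
    config selects via "self_attn_1_type": "mlu_sage_attn"."""

    def __init__(self):
        self.config = {}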
lightx2v/common/ops/attn/ulysses_attn.py

@@ -3,6 +3,7 @@ import torch.distributed as dist
 from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm
 from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE

 from .template import AttnWeightTemplate
 from .utils.all2all import all2all_head2seq, all2all_seq2head

@@ -75,7 +76,6 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         img_q = all2all_seq2head(img_q, group=seq_p_group)
         img_k = all2all_seq2head(img_k, group=seq_p_group)
         img_v = all2all_seq2head(img_v, group=seq_p_group)
-        self.device_synchronize()  # make sure CUDA ops have finished

         # process the text Q/K/V, selecting this rank's heads
         txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]

@@ -88,7 +88,7 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         v = torch.cat((img_v, txt_v), dim=0)

         # initialize the cumulative sequence-length tensor
-        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=self.config.get("run_device", "cuda"))
+        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=AI_DEVICE)
         s = txt_qkv_len + img_q.shape[0]  # total length of text plus image
         s1 = s  # end position of the current sample
         cu_seqlens_qkv[1] = s1  # set the cumulative sequence length

@@ -133,23 +133,8 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         img_attn = all2all_head2seq(img_attn, group=seq_p_group)
         img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
-        self.device_synchronize()  # make sure CUDA ops have finished

         return img_attn

-    def device_synchronize(self,):
-        if torch.cuda.is_available():
-            # no need to sync between comm and comp
-            # torch.cuda.synchronize()
-            self.config["run_device"] = "cuda"
-        elif hasattr(torch, "mlu") and torch.mlu.is_available():
-            torch.mlu.synchronize()
-            self.config["run_device"] = "mlu"
-        elif hasattr(torch, "npu") and torch.npu.is_available():
-            torch.npu.synchronize()
-            self.config["run_device"] = "npu"
-

 @ATTN_WEIGHT_REGISTER("ulysses-4090")
 class Ulysses4090AttnWeight(AttnWeightTemplate):
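The removed device_synchronize helper mixed two jobs: probing which backend is live and writing the answer back into config["run_device"]. With AI_DEVICE fixed once at import time, tensor placement becomes a constant (as in the cu_seqlens_qkv change above), and an explicit cross-backend sync, should one ever be needed again, can be expressed generically. A small hedged sketch under that assumption, reusing the getattr(torch, AI_DEVICE) pattern this commit introduces in the qwen25 encoders further down:

import torch

from lightx2v_platform.base.global_var import AI_DEVICE


def device_synchronize() -> None:
    # Resolve the backend module (torch.cuda, torch.mlu, torch.npu, ...) from the
    # process-wide AI_DEVICE string and synchronize it if it exposes that call.
    backend = getattr(torch, AI_DEVICE, None)
    if backend is not None and hasattr(backend, "synchronize"):
        backend.synchronize()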
lightx2v/common/ops/conv/conv3d.py

@@ -35,13 +35,7 @@ class Conv3dWeight(Conv3dWeightTemplate):
     def load(self, weight_dict):
         device = weight_dict[self.weight_name].device
-        if device.type in ["cuda", "mlu", "npu"]:
-            self.weight = weight_dict[self.weight_name]
-            if self.bias_name is not None:
-                self.bias = weight_dict[self.bias_name]
-            else:
-                self.bias = None
-        elif device.type == "cpu":
+        if device.type == "cpu":
             weight_shape = weight_dict[self.weight_name].shape
             weight_dtype = weight_dict[self.weight_name].dtype
             self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)

@@ -57,7+51,11 @@ class Conv3dWeight(Conv3dWeightTemplate):
                 self.pin_bias = None
             del weight_dict[self.weight_name]
         else:
-            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+            self.weight = weight_dict[self.weight_name]
+            if self.bias_name is not None:
+                self.bias = weight_dict[self.bias_name]
+            else:
+                self.bias = None

     def apply(self, input_tensor):
         input_tensor = torch.nn.functional.conv3d(
lightx2v/common/ops/embedding/embedding_weight.py

@@ -22,16 +22,14 @@ class EmbeddingWeightTemplate(metaclass=ABCMeta):
             self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                 self.pin_weight.copy_(weight_dict[self.weight_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]

     def to_cuda(self, non_blocking=False):
         self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
lightx2v/common/ops/mm/mm_weight.py

@@ -67,11 +67,6 @@ try:
 except ImportError:
     marlin_cuda_quant = None

-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    tmo = None
-

 class MMWeightTemplate(metaclass=ABCMeta):
     def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):

@@ -128,14 +123,7 @@ class MMWeight(MMWeightTemplate):
                 self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name].t()
-                if self.bias_name is not None:
-                    self.bias = weight_dict[self.bias_name]
-                else:
-                    self.bias = None
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype

@@ -153,7 +141,11 @@ class MMWeight(MMWeightTemplate):
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name].t()
+                if self.bias_name is not None:
+                    self.bias = weight_dict[self.bias_name]
+                else:
+                    self.bias = None

     def _calculate_size(self):
         if self.bias is not None:

@@ -273,10 +265,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.weight_scale_cuda_buffer = weight_dict[self.weight_scale_name].float().cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-                self.weight_scale = weight_dict[self.weight_scale_name].float()
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)

@@ -288,7 +277,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
+                self.weight_scale = weight_dict[self.weight_scale_name].float()

         if self.bias_name is not None:
             if self.create_cuda_buffer:

@@ -296,15 +286,13 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
             else:
                 device = weight_dict[self.bias_name].device
-                if device.type in ["cuda", "mlu", "npu"]:
-                    self.bias = weight_dict[self.bias_name]
-                elif device.type == "cpu":
+                if device.type == "cpu":
                     bias_shape = weight_dict[self.bias_name].shape
                     bias_dtype = weight_dict[self.bias_name].dtype
                     self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                     self.pin_bias.copy_(weight_dict[self.bias_name])
                 else:
-                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                    self.bias = weight_dict[self.bias_name]
         else:
             self.bias = None
             self.pin_bias = None

The same inversion repeats in the remaining loaders of MMWeightQuantTemplate: hunks @@ -337,10 +325,7 @@ with @@ -352,7 +337,8 @@ (the loader ahead of load_mxfp6), @@ -362,10 +348,7 @@ with @@ -377,7 +360,8 @@ (load_mxfp6), and @@ -387,10 +371,7 @@ with @@ -402,7 +383,8 @@ (load_mxfp8) each drop the `if device.type in ["cuda", "mlu", "npu"]` branch and the trailing `raise ValueError(...)`, keeping `self.weight = weight_dict[self.weight_name]` and `self.weight_scale = weight_dict[self.weight_scale_name]` in the new else branch instead. load_nvfp4 gets the same treatment plus its extra scale bookkeeping:

@@ -412,12 +394,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         weight_global_scale = weight_dict[f"{self.weight_name}_global_scale"]
         alpha = 1.0 / (input_global_scale * weight_global_scale)
-        if device.type in ["cuda", "mlu", "npu"]:
-            self.weight = weight_dict[self.weight_name]
-            self.weight_scale = weight_dict[self.weight_scale_name]
-            self.input_global_scale = input_global_scale
-            self.alpha = alpha
-        elif device.type == "cpu":
+        if device.type == "cpu":
             weight_shape = weight_dict[self.weight_name].shape
             weight_dtype = weight_dict[self.weight_name].dtype
             self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)

@@ -440,7 +417,10 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             del weight_dict[self.weight_name]
         else:
-            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+            self.weight = weight_dict[self.weight_name]
+            self.weight_scale = weight_dict[self.weight_scale_name]
+            self.input_global_scale = input_global_scale
+            self.alpha = alpha

         if self.bias_name is not None:
             if self.create_cuda_buffer:

@@ -1178,33 +1158,3 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
         if hasattr(self, "bias") and self.bias is not None:
             output_tensor.add_(self.bias)
         return output_tensor
-
-
-@MM_WEIGHT_REGISTER("int8-tmo")
-class MMWeightWint8channelAint8channeldynamicMlu(MMWeightQuantTemplate):
-    """
-    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Mlu
-    Quant MM:
-        Weight: int8 perchannel sym
-        Act: int8 perchannel dynamic sym
-        Kernel: mlu
-    """
-
-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
-        self.load_func = self.load_int8_perchannel_sym
-        self.weight_need_transpose = False
-        self.act_quant_func = self.act_quant_int8_perchannel_sym_tmo
-
-    def act_quant_int8_perchannel_sym_tmo(self, x):
-        input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
-        return input_tensor_quant, input_tensor_scale
-
-    def apply(self, input_tensor):
-        dtype = input_tensor.dtype
-        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        output_tensor = tmo.scaled_matmul(
-            input_tensor_quant,
-            self.weight.contiguous(),
-            input_tensor_scale,
-            self.weight_scale.squeeze(-1),
-            bias=self.bias if self.bias is not None else None,
-            output_dtype=dtype,
-            use_hp_active=True,
-        )
-        return output_tensor
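Every load-path hunk in this file, and the matching ones in conv3d.py, embedding_weight.py, layer_norm_weight.py, rms_norm_weight.py and tensor.py, makes the same move: instead of whitelisting "cuda"/"mlu"/"npu" and raising on anything else, the loaders now special-case only CPU tensors (staged into pinned host memory for a later non-blocking upload) and keep tensors on any other device as-is, so adding a backend no longer requires touching these classes. A condensed sketch of the resulting shape, using only names that appear in the diff:

import torch


def load(self, weight_dict):
    device = weight_dict[self.weight_name].device
    if device.type == "cpu":
        # CPU checkpoints are copied into page-locked host memory so that the
        # later .cuda(non_blocking=True) upload can overlap with compute.
        w = weight_dict[self.weight_name]
        self.pin_weight = torch.empty(w.shape, pin_memory=True, dtype=w.dtype)
        self.pin_weight.copy_(w)
        del weight_dict[self.weight_name]
    else:
        # Already on an accelerator (cuda, mlu, npu, or any other torch backend):
        # use the tensor where it lives; no device whitelist, no ValueError.
        self.weight = weight_dict[self.weight_name]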
lightx2v/common/ops/norm/layer_norm_weight.py

@@ -32,13 +32,7 @@ class LNWeightTemplate(metaclass=ABCMeta):
         else:
             if self.weight_name is not None:
                 device = weight_dict[self.weight_name].device
-                if device.type in ["cuda", "mlu", "npu"]:
-                    self.weight = weight_dict[self.weight_name]
-                    if self.bias_name is not None:
-                        self.bias = weight_dict[self.bias_name]
-                    else:
-                        self.bias = None
-                elif device.type == "cpu":
+                if device.type == "cpu":
                     weight_shape = weight_dict[self.weight_name].shape
                     weight_dtype = weight_dict[self.weight_name].dtype
                     self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)

@@ -54,7 +48,11 @@ class LNWeightTemplate(metaclass=ABCMeta):
                         self.pin_bias = None
                     del weight_dict[self.weight_name]
                 else:
-                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                    self.weight = weight_dict[self.weight_name]
+                    if self.bias_name is not None:
+                        self.bias = weight_dict[self.bias_name]
+                    else:
+                        self.bias = None
             else:
                 self.weight = None
                 self.bias = None
lightx2v/common/ops/norm/rms_norm_weight.py

@@ -30,16 +30,14 @@ class RMSWeightTemplate(metaclass=ABCMeta):
             self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                 self.pin_weight.copy_(weight_dict[self.weight_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]

     def clear(self):
         attrs = ["weight", "pinned_weight"]
lightx2v/common/ops/tensor/tensor.py

@@ -29,16 +29,14 @@ class DefaultTensor:
             self.tensor_cuda_buffer = weight_dict[self.tensor_name].cuda()
         else:
             device = weight_dict[self.tensor_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.tensor = weight_dict[self.tensor_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 tensor_shape = weight_dict[self.tensor_name].shape
                 tensor_dtype = weight_dict[self.tensor_name].dtype
                 self.pin_tensor = torch.empty(tensor_shape, pin_memory=True, dtype=tensor_dtype)
                 self.pin_tensor.copy_(weight_dict[self.tensor_name])
                 del weight_dict[self.tensor_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.tensor = weight_dict[self.tensor_name]

     def clear(self):
         attrs = ["tensor", "pinned_tensor"]
lightx2v/infer.py

@@ -4,11 +4,6 @@ import torch
 import torch.distributed as dist
 from loguru import logger

-try:
-    from torch.distributed import ProcessGroupNCCL
-except ImportError:
-    ProcessGroupNCCL = None
-
 from lightx2v.common.ops import *
 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner  # noqa: F401
 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner  # noqa: F401

@@ -26,6 +21,8 @@ from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
 from lightx2v.utils.utils import seed_all
+from lightx2v_platform.base.global_var import AI_DEVICE
+from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER


 def init_runner(config):

@@ -105,15 +102,8 @@ def main():
     config = set_config(args)
     if config["parallel"]:
-        run_device = config.get("run_device", "cuda")
-        if "cuda" in run_device:
-            pg_options = ProcessGroupNCCL.Options()
-            pg_options.is_high_priority_stream = True
-            dist.init_process_group(backend="nccl", pg_options=pg_options)
-            torch.cuda.set_device(dist.get_rank())
-        elif "mlu" in run_device:
-            dist.init_process_group(backend="cncl")
-            torch.mlu.set_device(dist.get_rank())
+        platform_device = PLATFORM_DEVICE_REGISTER.get(AI_DEVICE, None)
+        platform_device.init_parallel_env()
         set_parallel_config(config)
     print_config(config)
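main() no longer hard-codes the NCCL/CNCL setup; it fetches whatever is registered for AI_DEVICE in PLATFORM_DEVICE_REGISTER and calls init_parallel_env() on it. The per-backend entries live in lightx2v_platform and are not part of this page, so the following is only a sketch of what the CUDA entry might look like if it simply absorbed the code removed above; the decorator-style registration mirrors the repo's other registries, and whether the registry stores a class or an instance is an assumption (the staticmethod keeps the call working either way).

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroupNCCL

from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER


@PLATFORM_DEVICE_REGISTER("cuda")
class CudaPlatformDevice:
    # Hypothetical platform entry: same behaviour as the old inline "cuda" branch in main().
    @staticmethod
    def init_parallel_env():
        pg_options = ProcessGroupNCCL.Options()
        pg_options.is_high_priority_stream = True
        dist.init_process_group(backend="nccl", pg_options=pg_options)
        torch.cuda.set_device(dist.get_rank())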
lightx2v/models/input_encoders/hf/hunyuan15/byt5/model.py

@@ -8,6 +8,8 @@ import torch.nn as nn
 from safetensors import safe_open
 from transformers import AutoTokenizer, T5ForConditionalGeneration

+from lightx2v_platform.base.global_var import AI_DEVICE
+
 from .format_prompt import MultilingualPromptFormat

@@ -159,14 +161,12 @@ class ByT5TextEncoder:
         self,
         config,
         device=torch.device("cpu"),
-        run_device=torch.device("cuda"),
         checkpoint_path=None,
         byt5_max_length=256,
         cpu_offload=False,
     ):
         self.cpu_offload = cpu_offload
         self.config = config
-        self.run_device = run_device
         self.byt5_max_length = byt5_max_length
         self.enable_cfg = config.get("enable_cfg", False)
         byT5_google_path = os.path.join(checkpoint_path, "text_encoder", "byt5-small")

@@ -301,12 +301,12 @@ class ByT5TextEncoder:
         negative_masks = []

         for prompt in prompt_list:
-            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, self.run_device)
+            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, AI_DEVICE)
             positive_embeddings.append(pos_emb)
             positive_masks.append(pos_mask)

             if self.enable_cfg:  # TODO: split cfg out; better suited to running in parallel
-                neg_emb, neg_mask = self._process_single_byt5_prompt("", self.run_device)
+                neg_emb, neg_mask = self._process_single_byt5_prompt("", AI_DEVICE)
                 negative_embeddings.append(neg_emb)
                 negative_masks.append(neg_mask)

@@ -328,8 +328,8 @@ class ByT5TextEncoder:
     @torch.no_grad()
     def infer(self, prompts):
         if self.cpu_offload:
-            self.byt5_model = self.byt5_model.to(self.run_device)
-            self.byt5_mapper = self.byt5_mapper.to(self.run_device)
+            self.byt5_model = self.byt5_model.to(AI_DEVICE)
+            self.byt5_mapper = self.byt5_mapper.to(AI_DEVICE)
         byt5_embeddings, byt5_masks = self._prepare_byt5_embeddings(prompts)
         byt5_features = self.byt5_mapper(byt5_embeddings.to(torch.bfloat16))
         if self.cpu_offload:
lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py

@@ -32,6 +32,9 @@ from lightx2v.models.input_encoders.hf.q_linear import (  # noqa E402
     TorchaoQuantLinearInt8,  # noqa E402
     VllmQuantLinearInt8,  # noqa E402
 )
+from lightx2v_platform.base.global_var import AI_DEVICE  # noqa E402
+
+torch_device_module = getattr(torch, AI_DEVICE)


 def use_default(value, default):

@@ -145,12 +148,7 @@ def load_text_encoder(
         new_w_dict[key.replace("model.", "")] = weight_dict[key]
     del weight_dict
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-    elif "mlu" in str(device):
-        torch.mlu.empty_cache()
-    elif "npu" in str(device):
-        torch.npu.empty_cache()
+    torch_device_module.empty_cache()
     gc.collect()
     text_encoder.load_state_dict(new_w_dict, assign=True)

@@ -552,7 +550,6 @@ class Qwen25VL_TextEncoder:
         text_len=1000,
         dtype=torch.float16,
         device=torch.device("cpu"),
-        run_device=torch.device("cuda"),
         checkpoint_path=None,
         cpu_offload=False,
         qwen25vl_quantized=False,

@@ -561,7 +558,6 @@ class Qwen25VL_TextEncoder:
     ):
         self.text_len = text_len
         self.dtype = dtype
-        self.run_device = run_device
         self.cpu_offload = cpu_offload
         self.qwen25vl_quantized = qwen25vl_quantized
         self.qwen25vl_quant_scheme = qwen25vl_quant_scheme

@@ -590,20 +586,20 @@ class Qwen25VL_TextEncoder:
     def infer(self, texts):
         if self.cpu_offload:
-            self.text_encoder = self.text_encoder.to(self.run_device)
+            self.text_encoder = self.text_encoder.to(AI_DEVICE)
         text_inputs = self.text_encoder.text2tokens(texts, data_type="video", max_length=self.text_len)
-        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=self.run_device)
+        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=AI_DEVICE)
         if self.cpu_offload:
             self.text_encoder = self.text_encoder.to("cpu")
         prompt_embeds = prompt_outputs.hidden_state
         attention_mask = prompt_outputs.attention_mask
         if attention_mask is not None:
-            attention_mask = attention_mask.to(self.run_device)
+            attention_mask = attention_mask.to(AI_DEVICE)
             _, seq_len = attention_mask.shape
             attention_mask = attention_mask.repeat(1, self.num_videos_per_prompt)
             attention_mask = attention_mask.view(self.num_videos_per_prompt, seq_len)
-        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.run_device)
+        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=AI_DEVICE)
         seq_len = prompt_embeds.shape[1]
         # duplicate text embeddings for each generation per prompt, using mps friendly method
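torch_device_module = getattr(torch, AI_DEVICE) gives a single handle onto the active backend's torch submodule, which is what lets the cuda/mlu/npu cache-clearing ladder above collapse into one call. A short usage sketch (assumes AI_DEVICE names an accelerator backend, as the removed code also did):

import gc

import torch

from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)  # e.g. torch.cuda, torch.mlu or torch.npu


def free_device_cache():
    # One attribute lookup replaces the removed if/elif chain: each accelerator
    # backend module exposes empty_cache().
    torch_device_module.empty_cache()
    gc.collect()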
lightx2v/models/input_encoders/hf/hunyuan15/siglip/model.py

@@ -10,6 +10,8 @@ from safetensors.torch import safe_open
 from transformers import SiglipImageProcessor, SiglipVisionModel
 from transformers.utils import ModelOutput

+from lightx2v_platform.base.global_var import AI_DEVICE
+
 PRECISION_TO_TYPE = {
     "fp32": torch.float32,
     "fp16": torch.float16,

@@ -95,7 +97,6 @@ class VisionEncoder(nn.Module):
         output_key: Optional[str] = None,
         logger=None,
         device=None,
-        run_device=None,
         cpu_offload=False,
     ):
         super().__init__()

@@ -121,7 +122,6 @@ class VisionEncoder(nn.Module):
         )
         self.dtype = self.model.dtype
         self.device = self.model.device
-        self.run_device = run_device
         self.processor, self.processor_path = load_image_processor(
             processor_type=self.processor_type,

@@ -172,12 +172,12 @@ class VisionEncoder(nn.Module):
             VisionEncoderModelOutput with encoded features
         """
         if self.cpu_offload:
-            self.model = self.model.to("cuda")
-            self.processor = self.processor.to("cuda")
+            self.model = self.model.to(AI_DEVICE)
+            self.processor = self.processor.to(AI_DEVICE)

         if isinstance(images, np.ndarray):
             # Preprocess images if they're numpy arrays
-            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=self.run_device, dtype=self.model.dtype)
+            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=AI_DEVICE, dtype=self.model.dtype)
         else:
             # Assume already preprocessed
             preprocessed = images

@@ -232,13 +232,11 @@ class SiglipVisionEncoder:
         self,
         config,
         device=torch.device("cpu"),
-        run_device=torch.device("cuda"),
         checkpoint_path=None,
         cpu_offload=False,
     ):
         self.config = config
         self.device = device
-        self.run_device = run_device
         self.cpu_offload = cpu_offload
         self.vision_states_dim = 1152
         vision_encoder_path = os.path.join(checkpoint_path, "vision_encoder", "siglip")

@@ -252,7 +250,6 @@ class SiglipVisionEncoder:
             output_key=None,
             logger=None,
             device=self.device,
-            run_device=self.run_device,
             cpu_offload=self.cpu_offload,
         )

@@ -270,7 +267,7 @@ class SiglipVisionEncoder:
     @torch.no_grad()
     def infer(self, vision_states):
         if self.cpu_offload:
-            self.vision_in = self.vision_in.to(self.run_device)
+            self.vision_in = self.vision_in.to(AI_DEVICE)
         vision_states = self.vision_in(vision_states)
         if self.cpu_offload:
             self.vision_in = self.vision_in.to("cpu")
lightx2v/models/input_encoders/hf/q_linear.py

@@ -26,11 +26,6 @@ try:
 except ImportError:
     fp8_linear = None

-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    tmo = None
-

 class VllmQuantLinearInt8(nn.Module):
     def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):

@@ -315,19 +310,3 @@ class Q8FQuantLinearFp8(nn.Module):
         self.weight_scale = maybe_cast(self.weight_scale)
         self.bias = maybe_cast(self.bias)
         return self
-
-
-class MluQuantLinearInt8(VllmQuantLinearInt8):
-    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
-        super().__init__(in_features, out_features, bias, dtype)
-
-    def act_quant_func(self, x):
-        input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
-        return input_tensor_quant, input_tensor_scale
-
-    def forward(self, input_tensor):
-        input_tensor = input_tensor.squeeze(0)
-        dtype = input_tensor.dtype
-        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        output_tensor = tmo.scaled_matmul(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale.squeeze(-1), output_dtype=dtype)
-        return output_tensor.unsqueeze(0)
lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py

@@ -5,6 +5,10 @@ import os
 import torch
 from transformers import Qwen2Tokenizer, Qwen2_5_VLForConditionalGeneration

+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)
+
 try:
     from diffusers.image_processor import VaeImageProcessor
     from transformers import Qwen2VLProcessor

@@ -58,11 +62,10 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
         self.VAE_IMAGE_SIZE = 1024 * 1024
         self.cpu_offload = config.get("cpu_offload", False)
-        self.run_device = self.config.get("run_device", "cuda")
         if self.cpu_offload:
             self.device = torch.device("cpu")
         else:
-            self.device = torch.device(self.config.get("run_device", "cuda"))
+            self.device = torch.device(AI_DEVICE)
         self.dtype = torch.bfloat16
         self.load()

@@ -180,9 +183,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
         if self.cpu_offload:
             self.text_encoder.to(torch.device("cpu"))
-            if hasattr(torch, self.config.get("run_device", "cuda")):
-                torch_module = getattr(torch, self.config.get("run_device", "cuda"))
-                torch_module.empty_cache()
+            torch_device_module.empty_cache()
             gc.collect()
         return prompt_embeds, prompt_embeds_mask, image_info
lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py

@@ -9,6 +9,8 @@ import torch.nn.functional as F
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange

+from lightx2v_platform.base.global_var import AI_DEVICE
+

 def linear_interpolation(features, output_len: int):
     features = features.transpose(1, 2)

@@ -252,7 +254,6 @@ class AudioAdapter(nn.Module):
         quantized: bool = False,
         quant_scheme: str = None,
         cpu_offload: bool = False,
-        run_device=torch.device("cuda"),
     ):
         super().__init__()
         self.cpu_offload = cpu_offload

@@ -263,7 +264,6 @@ class AudioAdapter(nn.Module):
             mlp_dims=mlp_dims,
             transformer_layers=projection_transformer_layers,
         )
-        self.run_device = run_device
         # self.num_tokens = num_tokens * 4
         self.num_tokens_x4 = num_tokens * 4
         self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)

@@ -302,10 +302,10 @@ class AudioAdapter(nn.Module):
     @torch.no_grad()
     def forward_audio_proj(self, audio_feat, latent_frame):
         if self.cpu_offload:
-            self.audio_proj.to(self.run_device)
+            self.audio_proj.to(AI_DEVICE)
         x = self.audio_proj(audio_feat, latent_frame)
         x = self.rearange_audio_features(x)
-        x = x + self.audio_pe.to(self.run_device)
+        x = x + self.audio_pe.to(AI_DEVICE)
         if self.cpu_offload:
             self.audio_proj.to("cpu")
         return x