xuwx1 / LightX2V · Commit b50498fa

Add lightx2v_platform (#541)

Authored Dec 02, 2025 by Yang Yong (雍洋); committed by GitHub on Dec 02, 2025 (unverified signature). Parent: 31da6925.

Changes: 75 files in the full commit; this page shows 20 changed files with 84 additions and 264 deletions (+84, −264).
Changed files on this page:

- configs/seko_talk/mlu/seko_talk_bf16.json (+3, −4)
- configs/seko_talk/seko_talk_25_int8_dist_fp8_comm.json (+0, −1)
- lightx2v/__init__.py (+1, −0)
- lightx2v/common/ops/attn/__init__.py (+1, −1)
- lightx2v/common/ops/attn/flash_attn.py (+0, −40)
- lightx2v/common/ops/attn/sage_attn.py (+0, −27)
- lightx2v/common/ops/attn/ulysses_attn.py (+2, −17)
- lightx2v/common/ops/conv/conv3d.py (+6, −8)
- lightx2v/common/ops/embedding/embedding_weight.py (+2, −4)
- lightx2v/common/ops/mm/mm_weight.py (+25, −75)
- lightx2v/common/ops/norm/layer_norm_weight.py (+6, −8)
- lightx2v/common/ops/norm/rms_norm_weight.py (+2, −4)
- lightx2v/common/ops/tensor/tensor.py (+2, −4)
- lightx2v/infer.py (+4, −14)
- lightx2v/models/input_encoders/hf/hunyuan15/byt5/model.py (+6, −6)
- lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py (+8, −12)
- lightx2v/models/input_encoders/hf/hunyuan15/siglip/model.py (+6, −9)
- lightx2v/models/input_encoders/hf/q_linear.py (+0, −21)
- lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py (+6, −5)
- lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py (+4, −4)
configs/seko_talk/mlu/seko_talk_bf16.json

```diff
@@ -5,15 +5,14 @@
     "audio_sr": 16000,
     "target_video_length": 81,
     "resize_mode": "adaptive",
-    "self_attn_1_type": "flash_attn2",
-    "cross_attn_1_type": "flash_attn2",
-    "cross_attn_2_type": "flash_attn2",
+    "self_attn_1_type": "mlu_sage_attn",
+    "cross_attn_1_type": "mlu_sage_attn",
+    "cross_attn_2_type": "mlu_sage_attn",
     "sample_guide_scale": 1.0,
     "sample_shift": 5,
     "enable_cfg": false,
     "cpu_offload": false,
     "use_31_block": false,
-    "run_device": "mlu",
    "rope_type": "torch",
    "modulate_type": "torch"
 }
```
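The attention types in these configs are registry keys: each weight class in lightx2v/common/ops/attn registers itself under a string via @ATTN_WEIGHT_REGISTER("..."), and self_attn_1_type / cross_attn_*_type select from that registry. A minimal sketch of the decorator pattern — the Register class below is illustrative; the repo's real implementation lives in lightx2v.utils.registry_factory and may differ:

```python
# Illustrative registry sketch (assumed shape, not the repo's actual code).
class Register(dict):
    def __call__(self, key):
        def decorator(cls):
            self[key] = cls  # e.g. "mlu_sage_attn" -> MluSageAttnWeight
            return cls
        return decorator


ATTN_WEIGHT_REGISTER = Register()


@ATTN_WEIGHT_REGISTER("mlu_sage_attn")
class MluSageAttnWeight:
    ...


# A config entry like "self_attn_1_type": "mlu_sage_attn" then resolves to:
attn_cls = ATTN_WEIGHT_REGISTER["mlu_sage_attn"]
```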
configs/seko_talk/seko_talk_25_int8_dist_fp8_comm.json

```diff
@@ -4,7 +4,6 @@
     "video_duration": 5,
     "audio_sr": 16000,
     "target_video_length": 81,
-    "resize_mode": "adaptive",
     "self_attn_1_type": "sage_attn3",
     "cross_attn_1_type": "sage_attn3",
     "cross_attn_2_type": "sage_attn3",
```
lightx2v/__init__.py

```diff
@@ -2,6 +2,7 @@ __version__ = "0.1.0"
 __author__ = "LightX2V Contributors"
 __license__ = "Apache 2.0"
 
+import lightx2v_platform.set_ai_device
 from lightx2v import common, deploy, models, utils
 from lightx2v.pipeline import LightX2VPipeline
```
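The new top-level import suggests lightx2v_platform.set_ai_device resolves the accelerator once, at import time, and publishes it as the AI_DEVICE name that the other files in this commit import from lightx2v_platform.base.global_var. A hypothetical sketch of what such a module might do — the probing order is an assumption inferred from the cuda/mlu/npu branches deleted elsewhere in this commit:

```python
# Hypothetical sketch of lightx2v_platform/set_ai_device.py (assumed logic;
# only the AI_DEVICE name itself is confirmed by the diff).
import torch


def detect_ai_device() -> str:
    """Return the first available accelerator backend as a device string."""
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "mlu") and torch.mlu.is_available():  # Cambricon MLU
        return "mlu"
    if hasattr(torch, "npu") and torch.npu.is_available():  # Ascend NPU
        return "npu"
    return "cpu"


# Published once at import time; the rest of the codebase reads it via
# `from lightx2v_platform.base.global_var import AI_DEVICE`.
AI_DEVICE = detect_ai_device()
```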
lightx2v/common/ops/attn/__init__.py
View file @
b50498fa
from
.flash_attn
import
FlashAttn2Weight
,
FlashAttn3Weight
,
MluFlashAttnWeight
from
.flash_attn
import
FlashAttn2Weight
,
FlashAttn3Weight
from
.nbhd_attn
import
NbhdAttnWeight
,
NbhdAttnWeightFlashInfer
from
.nbhd_attn
import
NbhdAttnWeight
,
NbhdAttnWeightFlashInfer
from
.radial_attn
import
RadialAttnWeight
from
.radial_attn
import
RadialAttnWeight
from
.ring_attn
import
RingAttnWeight
from
.ring_attn
import
RingAttnWeight
...
...
lightx2v/common/ops/attn/flash_attn.py

```diff
-import math
 from loguru import logger
 
 try:
@@ -15,12 +13,6 @@ except ImportError:
     logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
     flash_attn_varlen_func_v3 = None
 
-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    logger.info("torch_mlu_ops not found.")
-    tmo = None
-
 from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
 from .template import AttnWeightTemplate
@@ -94,35 +86,3 @@ class FlashAttn3Weight(AttnWeightTemplate):
             max_seqlen_kv,
         ).reshape(bs * max_seqlen_q, -1)
         return x
-
-
-@ATTN_WEIGHT_REGISTER("mlu_flash_attn")
-class MluFlashAttnWeight(AttnWeightTemplate):
-    def __init__(self):
-        self.config = {}
-
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, **kws):
-        if len(q.shape) == 3:
-            bs = 1
-            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
-        elif len(q.shape) == 4:
-            bs = q.shape[0]
-        softmax_scale = 1 / math.sqrt(q.shape[-1])
-        x = tmo.flash_attention(
-            q=q,
-            k=k,
-            v=v,
-            cu_seq_lens_q=cu_seqlens_q,
-            cu_seq_lens_kv=cu_seqlens_kv,
-            max_seq_len_q=max_seqlen_q,
-            max_seq_len_kv=max_seqlen_kv,
-            softmax_scale=softmax_scale,
-            return_lse=False,
-            out_dtype=q.dtype,
-            is_causal=False,
-            out=None,
-            alibi_slope=None,
-            attn_bias=None,
-        )
-        x = x.reshape(bs * max_seqlen_q, -1)
-        return x
```
lightx2v/common/ops/attn/sage_attn.py

```diff
-import math
 import torch
 from loguru import logger
 
@@ -26,12 +24,6 @@ except ImportError:
     logger.info("sageattn3 not found, please install sageattention first")
     sageattn3_blackwell = None
 
-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    tmo = None
-    logger.info("torch_mlu_ops not found.")
-
 
 @ATTN_WEIGHT_REGISTER("sage_attn2")
 class SageAttn2Weight(AttnWeightTemplate):
@@ -89,22 +81,3 @@ class SageAttn3Weight(AttnWeightTemplate):
         x = sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2).reshape(bs * max_seqlen_q, -1)
         return x
-
-
-@ATTN_WEIGHT_REGISTER("mlu_sage_attn")
-class MluSageAttnWeight(AttnWeightTemplate):
-    def __init__(self):
-        self.config = {}
-
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, **kws):
-        if len(q.shape) == 3:
-            bs = 1
-            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
-        elif len(q.shape) == 4:
-            bs = q.shape[0]
-        softmax_scale = 1 / math.sqrt(q.shape[-1])
-        x = tmo.sage_attn(
-            q=q,
-            k=k,
-            v=v,
-            cu_seq_lens_q=None,
-            cu_seq_lens_kv=None,
-            max_seq_len_kv=max_seqlen_kv,
-            max_seq_len_q=max_seqlen_q,
-            is_causal=False,
-            compute_dtype=torch.bfloat16,
-            softmax_scale=softmax_scale,
-        )
-        x = x.reshape(bs * max_seqlen_q, -1)
-        return x
```
lightx2v/common/ops/attn/ulysses_attn.py (Chinese comments translated)

```diff
@@ -3,6 +3,7 @@ import torch.distributed as dist
 from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm
 from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
 from .template import AttnWeightTemplate
 from .utils.all2all import all2all_head2seq, all2all_seq2head
@@ -75,7 +76,6 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         img_q = all2all_seq2head(img_q, group=seq_p_group)
         img_k = all2all_seq2head(img_k, group=seq_p_group)
         img_v = all2all_seq2head(img_v, group=seq_p_group)
-        self.device_synchronize()  # make sure the device ops have finished
         # process the text q/k/v, selecting this rank's heads
         txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
@@ -88,7 +88,7 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         v = torch.cat((img_v, txt_v), dim=0)
         # initialize the cumulative sequence-length tensor
-        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=self.config.get("run_device", "cuda"))
+        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=AI_DEVICE)
         s = txt_qkv_len + img_q.shape[0]  # total length of text plus image
         s1 = s  # end position of the current sample
         cu_seqlens_qkv[1] = s1  # set the cumulative sequence length
@@ -133,23 +133,8 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         img_attn = all2all_head2seq(img_attn, group=seq_p_group)
         img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
-        self.device_synchronize()  # make sure the device ops have finished
         return img_attn
-
-    def device_synchronize(self):
-        if torch.cuda.is_available():
-            # no need to sync between comm and comp
-            # torch.cuda.synchronize()
-            self.config["run_device"] = "cuda"
-        elif hasattr(torch, "mlu") and torch.mlu.is_available():
-            torch.mlu.synchronize()
-            self.config["run_device"] = "mlu"
-        elif hasattr(torch, "npu") and torch.npu.is_available():
-            torch.npu.synchronize()
-            self.config["run_device"] = "npu"
-
 
 @ATTN_WEIGHT_REGISTER("ulysses-4090")
 class Ulysses4090AttnWeight(AttnWeightTemplate):
```
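Worth noting what the deleted device_synchronize did besides synchronizing: each call re-probed the hardware and wrote the result back into self.config["run_device"]. The commit replaces that per-call, config-mutating probe with a constant resolved once at import:

```python
# After this commit: the device is a module-level constant, not a per-call probe.
import torch

from lightx2v_platform.base.global_var import AI_DEVICE

cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=AI_DEVICE)
```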
lightx2v/common/ops/conv/conv3d.py

```diff
@@ -35,13 +35,7 @@ class Conv3dWeight(Conv3dWeightTemplate):
     def load(self, weight_dict):
         device = weight_dict[self.weight_name].device
-        if device.type in ["cuda", "mlu", "npu"]:
-            self.weight = weight_dict[self.weight_name]
-            if self.bias_name is not None:
-                self.bias = weight_dict[self.bias_name]
-            else:
-                self.bias = None
-        elif device.type == "cpu":
+        if device.type == "cpu":
             weight_shape = weight_dict[self.weight_name].shape
             weight_dtype = weight_dict[self.weight_name].dtype
             self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -57,7 +51,11 @@ class Conv3dWeight(Conv3dWeightTemplate):
                 self.pin_bias = None
             del weight_dict[self.weight_name]
         else:
-            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+            self.weight = weight_dict[self.weight_name]
+            if self.bias_name is not None:
+                self.bias = weight_dict[self.bias_name]
+            else:
+                self.bias = None
 
     def apply(self, input_tensor):
         input_tensor = torch.nn.functional.conv3d(
```
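conv3d, embedding, mm, norm, and tensor weights all get the same inversion in this commit: the explicit ["cuda", "mlu", "npu"] whitelist and its ValueError fallback disappear, "cpu" keeps the pinned-host staging path, and any other device type is now trusted and assigned directly. A small sketch of the staging idiom these load() methods use (values are illustrative; pin_memory requires an accelerator-enabled build):

```python
import torch

# A checkpoint tensor arrives on CPU ...
weight = torch.randn(4096, 4096, dtype=torch.bfloat16)

# ... is staged into page-locked (pinned) host memory ...
pin_weight = torch.empty(weight.shape, pin_memory=True, dtype=weight.dtype)
pin_weight.copy_(weight)

# ... so the later host-to-device copy can run asynchronously and overlap compute.
device_weight = pin_weight.cuda(non_blocking=True)
```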
lightx2v/common/ops/embedding/embedding_weight.py

```diff
@@ -22,16 +22,14 @@ class EmbeddingWeightTemplate(metaclass=ABCMeta):
             self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                 self.pin_weight.copy_(weight_dict[self.weight_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
 
     def to_cuda(self, non_blocking=False):
         self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
```
lightx2v/common/ops/mm/mm_weight.py

```diff
@@ -67,11 +67,6 @@ try:
 except ImportError:
     marlin_cuda_quant = None
 
-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    tmo = None
-
 
 class MMWeightTemplate(metaclass=ABCMeta):
     def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
@@ -128,14 +123,7 @@ class MMWeight(MMWeightTemplate):
                 self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name].t()
-                if self.bias_name is not None:
-                    self.bias = weight_dict[self.bias_name]
-                else:
-                    self.bias = None
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
@@ -153,7 +141,11 @@ class MMWeight(MMWeightTemplate):
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name].t()
+                if self.bias_name is not None:
+                    self.bias = weight_dict[self.bias_name]
+                else:
+                    self.bias = None
 
     def _calculate_size(self):
         if self.bias is not None:
@@ -273,10 +265,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.weight_scale_cuda_buffer = weight_dict[self.weight_scale_name].float().cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-                self.weight_scale = weight_dict[self.weight_scale_name].float()
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -288,7 +277,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
+                self.weight_scale = weight_dict[self.weight_scale_name].float()
 
         if self.bias_name is not None:
             if self.create_cuda_buffer:
@@ -296,15 +286,13 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
             else:
                 device = weight_dict[self.bias_name].device
-                if device.type in ["cuda", "mlu", "npu"]:
-                    self.bias = weight_dict[self.bias_name]
-                elif device.type == "cpu":
+                if device.type == "cpu":
                     bias_shape = weight_dict[self.bias_name].shape
                     bias_dtype = weight_dict[self.bias_name].dtype
                     self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                     self.pin_bias.copy_(weight_dict[self.bias_name])
                 else:
-                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                    self.bias = weight_dict[self.bias_name]
         else:
             self.bias = None
             self.pin_bias = None
@@ -337,10 +325,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-                self.weight_scale = weight_dict[self.weight_scale_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -352,7 +337,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
+                self.weight_scale = weight_dict[self.weight_scale_name]
 
     def load_mxfp6(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
@@ -362,10 +348,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-                self.weight_scale = weight_dict[self.weight_scale_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -377,7 +360,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
+                self.weight_scale = weight_dict[self.weight_scale_name]
 
     def load_mxfp8(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
@@ -387,10 +371,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-                self.weight_scale = weight_dict[self.weight_scale_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -402,7 +383,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
+                self.weight_scale = weight_dict[self.weight_scale_name]
 
     def load_nvfp4(self, weight_dict):
         device = weight_dict[self.weight_name].device
@@ -412,12 +394,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         weight_global_scale = weight_dict[f"{self.weight_name}_global_scale"]
         alpha = 1.0 / (input_global_scale * weight_global_scale)
-        if device.type in ["cuda", "mlu", "npu"]:
-            self.weight = weight_dict[self.weight_name]
-            self.weight_scale = weight_dict[self.weight_scale_name]
-            self.input_global_scale = input_global_scale
-            self.alpha = alpha
-        elif device.type == "cpu":
+        if device.type == "cpu":
             weight_shape = weight_dict[self.weight_name].shape
             weight_dtype = weight_dict[self.weight_name].dtype
             self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -440,7 +417,10 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             del weight_dict[self.weight_name]
         else:
-            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+            self.weight = weight_dict[self.weight_name]
+            self.weight_scale = weight_dict[self.weight_scale_name]
+            self.input_global_scale = input_global_scale
+            self.alpha = alpha
 
         if self.bias_name is not None:
             if self.create_cuda_buffer:
@@ -1178,33 +1158,3 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
         if hasattr(self, "bias") and self.bias is not None:
             output_tensor.add_(self.bias)
         return output_tensor
-
-
-@MM_WEIGHT_REGISTER("int8-tmo")
-class MMWeightWint8channelAint8channeldynamicMlu(MMWeightQuantTemplate):
-    """
-    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Mlu
-    Quant MM:
-        Weight: int8 perchannel sym
-        Act: int8 perchannel dynamic sym
-    Kernel: mlu
-    """
-
-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
-        self.load_func = self.load_int8_perchannel_sym
-        self.weight_need_transpose = False
-        self.act_quant_func = self.act_quant_int8_perchannel_sym_tmo
-
-    def act_quant_int8_perchannel_sym_tmo(self, x):
-        input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
-        return input_tensor_quant, input_tensor_scale
-
-    def apply(self, input_tensor):
-        dtype = input_tensor.dtype
-        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        output_tensor = tmo.scaled_matmul(
-            input_tensor_quant,
-            self.weight.contiguous(),
-            input_tensor_scale,
-            self.weight_scale.squeeze(-1),
-            bias=self.bias if self.bias is not None else None,
-            output_dtype=dtype,
-            use_hp_active=True,
-        )
-        return output_tensor
```
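The deleted "int8-tmo" kernel fuses exactly what its docstring spells out: int8 per-channel symmetric weights against dynamically quantized int8 per-channel activations. Written out with plain torch ops for reference — an unfused, illustrative equivalent; the real path called the fused tmo.scaled_quantize / tmo.scaled_matmul kernels:

```python
import torch

x = torch.randn(8, 64, dtype=torch.bfloat16)   # activations [tokens, in]
w = torch.randn(32, 64, dtype=torch.bfloat16)  # weight [out, in]

# Dynamic per-row activation scales; static per-output-channel weight scales.
s_x = x.abs().amax(dim=-1, keepdim=True).float() / 127.0
s_w = w.abs().amax(dim=-1, keepdim=True).float() / 127.0
q_x = torch.clamp((x.float() / s_x).round(), -128, 127).to(torch.int8)
q_w = torch.clamp((w.float() / s_w).round(), -128, 127).to(torch.int8)

# Integer matmul accumulated in int32, then rescaled back to bf16.
acc = q_x.to(torch.int32) @ q_w.t().to(torch.int32)          # [8, 32]
y = (acc.float() * (s_x * s_w.t())).to(torch.bfloat16)
```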
lightx2v/common/ops/norm/layer_norm_weight.py

```diff
@@ -32,13 +32,7 @@ class LNWeightTemplate(metaclass=ABCMeta):
         else:
             if self.weight_name is not None:
                 device = weight_dict[self.weight_name].device
-                if device.type in ["cuda", "mlu", "npu"]:
-                    self.weight = weight_dict[self.weight_name]
-                    if self.bias_name is not None:
-                        self.bias = weight_dict[self.bias_name]
-                    else:
-                        self.bias = None
-                elif device.type == "cpu":
+                if device.type == "cpu":
                     weight_shape = weight_dict[self.weight_name].shape
                     weight_dtype = weight_dict[self.weight_name].dtype
                     self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
@@ -54,7 +48,11 @@ class LNWeightTemplate(metaclass=ABCMeta):
                         self.pin_bias = None
                     del weight_dict[self.weight_name]
                 else:
-                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                    self.weight = weight_dict[self.weight_name]
+                    if self.bias_name is not None:
+                        self.bias = weight_dict[self.bias_name]
+                    else:
+                        self.bias = None
             else:
                 self.weight = None
                 self.bias = None
```
lightx2v/common/ops/norm/rms_norm_weight.py

```diff
@@ -30,16 +30,14 @@ class RMSWeightTemplate(metaclass=ABCMeta):
             self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
         else:
             device = weight_dict[self.weight_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.weight = weight_dict[self.weight_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                 self.pin_weight.copy_(weight_dict[self.weight_name])
                 del weight_dict[self.weight_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.weight = weight_dict[self.weight_name]
 
     def clear(self):
         attrs = ["weight", "pinned_weight"]
```
lightx2v/common/ops/tensor/tensor.py

```diff
@@ -29,16 +29,14 @@ class DefaultTensor:
             self.tensor_cuda_buffer = weight_dict[self.tensor_name].cuda()
         else:
             device = weight_dict[self.tensor_name].device
-            if device.type in ["cuda", "mlu", "npu"]:
-                self.tensor = weight_dict[self.tensor_name]
-            elif device.type == "cpu":
+            if device.type == "cpu":
                 tensor_shape = weight_dict[self.tensor_name].shape
                 tensor_dtype = weight_dict[self.tensor_name].dtype
                 self.pin_tensor = torch.empty(tensor_shape, pin_memory=True, dtype=tensor_dtype)
                 self.pin_tensor.copy_(weight_dict[self.tensor_name])
                 del weight_dict[self.tensor_name]
             else:
-                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                self.tensor = weight_dict[self.tensor_name]
 
     def clear(self):
         attrs = ["tensor", "pinned_tensor"]
```
lightx2v/infer.py

```diff
@@ -4,11 +4,6 @@ import torch
 import torch.distributed as dist
 from loguru import logger
 
-try:
-    from torch.distributed import ProcessGroupNCCL
-except ImportError:
-    ProcessGroupNCCL = None
-
 from lightx2v.common.ops import *
 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner  # noqa: F401
 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner  # noqa: F401
@@ -26,6 +21,8 @@ from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
 from lightx2v.utils.utils import seed_all
+from lightx2v_platform.base.global_var import AI_DEVICE
+from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER
 
 
 def init_runner(config):
@@ -105,15 +102,8 @@ def main():
     config = set_config(args)
     if config["parallel"]:
-        run_device = config.get("run_device", "cuda")
-        if "cuda" in run_device:
-            pg_options = ProcessGroupNCCL.Options()
-            pg_options.is_high_priority_stream = True
-            dist.init_process_group(backend="nccl", pg_options=pg_options)
-            torch.cuda.set_device(dist.get_rank())
-        elif "mlu" in run_device:
-            dist.init_process_group(backend="cncl")
-            torch.mlu.set_device(dist.get_rank())
+        platform_device = PLATFORM_DEVICE_REGISTER.get(AI_DEVICE, None)
+        platform_device.init_parallel_env()
         set_parallel_config(config)
     print_config(config)
```
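The NCCL/CNCL branching moves behind PLATFORM_DEVICE_REGISTER: each platform registers an object whose init_parallel_env() owns backend selection and device binding. A plausible sketch of a CUDA platform entry — only the PLATFORM_DEVICE_REGISTER and init_parallel_env names come from the diff; the class body below simply mirrors the deleted branch and is otherwise an assumption:

```python
# Hypothetical sketch of how a platform might register itself.
import torch
import torch.distributed as dist


class CudaPlatformDevice:
    @staticmethod
    def init_parallel_env():
        # Mirrors the NCCL branch deleted from lightx2v/infer.py.
        from torch.distributed import ProcessGroupNCCL

        pg_options = ProcessGroupNCCL.Options()
        pg_options.is_high_priority_stream = True
        dist.init_process_group(backend="nccl", pg_options=pg_options)
        torch.cuda.set_device(dist.get_rank())


PLATFORM_DEVICE_REGISTER = {"cuda": CudaPlatformDevice}

# lightx2v/infer.py then only needs:
#   platform_device = PLATFORM_DEVICE_REGISTER.get(AI_DEVICE, None)
#   platform_device.init_parallel_env()
```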
lightx2v/models/input_encoders/hf/hunyuan15/byt5/model.py (Chinese comment translated)

```diff
@@ -8,6 +8,8 @@ import torch.nn as nn
 from safetensors import safe_open
 from transformers import AutoTokenizer, T5ForConditionalGeneration
 
+from lightx2v_platform.base.global_var import AI_DEVICE
+
 from .format_prompt import MultilingualPromptFormat
@@ -159,14 +161,12 @@ class ByT5TextEncoder:
         self,
         config,
         device=torch.device("cpu"),
-        run_device=torch.device("cuda"),
         checkpoint_path=None,
         byt5_max_length=256,
         cpu_offload=False,
     ):
         self.cpu_offload = cpu_offload
         self.config = config
-        self.run_device = run_device
         self.byt5_max_length = byt5_max_length
         self.enable_cfg = config.get("enable_cfg", False)
         byT5_google_path = os.path.join(checkpoint_path, "text_encoder", "byt5-small")
@@ -301,12 +301,12 @@ class ByT5TextEncoder:
         negative_masks = []
         for prompt in prompt_list:
-            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, self.run_device)
+            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, AI_DEVICE)
             positive_embeddings.append(pos_emb)
             positive_masks.append(pos_mask)
             if self.enable_cfg:  # TODO: split cfg out; it suits parallelism better
-                neg_emb, neg_mask = self._process_single_byt5_prompt("", self.run_device)
+                neg_emb, neg_mask = self._process_single_byt5_prompt("", AI_DEVICE)
                 negative_embeddings.append(neg_emb)
                 negative_masks.append(neg_mask)
@@ -328,8 +328,8 @@ class ByT5TextEncoder:
     @torch.no_grad()
     def infer(self, prompts):
         if self.cpu_offload:
-            self.byt5_model = self.byt5_model.to(self.run_device)
-            self.byt5_mapper = self.byt5_mapper.to(self.run_device)
+            self.byt5_model = self.byt5_model.to(AI_DEVICE)
+            self.byt5_mapper = self.byt5_mapper.to(AI_DEVICE)
         byt5_embeddings, byt5_masks = self._prepare_byt5_embeddings(prompts)
         byt5_features = self.byt5_mapper(byt5_embeddings.to(torch.bfloat16))
         if self.cpu_offload:
```
lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py

```diff
@@ -32,6 +32,9 @@ from lightx2v.models.input_encoders.hf.q_linear import (  # noqa E402
     TorchaoQuantLinearInt8,  # noqa E402
     VllmQuantLinearInt8,  # noqa E402
 )
+from lightx2v_platform.base.global_var import AI_DEVICE  # noqa E402
+
+torch_device_module = getattr(torch, AI_DEVICE)
 
 
 def use_default(value, default):
@@ -145,12 +148,7 @@ def load_text_encoder(
         new_w_dict[key.replace("model.", "")] = weight_dict[key]
     del weight_dict
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-    elif "mlu" in str(device):
-        torch.mlu.empty_cache()
-    elif "npu" in str(device):
-        torch.npu.empty_cache()
+    torch_device_module.empty_cache()
     gc.collect()
     text_encoder.load_state_dict(new_w_dict, assign=True)
@@ -552,7 +550,6 @@ class Qwen25VL_TextEncoder:
         text_len=1000,
         dtype=torch.float16,
         device=torch.device("cpu"),
-        run_device=torch.device("cuda"),
         checkpoint_path=None,
         cpu_offload=False,
         qwen25vl_quantized=False,
@@ -561,7 +558,6 @@ class Qwen25VL_TextEncoder:
     ):
         self.text_len = text_len
         self.dtype = dtype
-        self.run_device = run_device
         self.cpu_offload = cpu_offload
         self.qwen25vl_quantized = qwen25vl_quantized
         self.qwen25vl_quant_scheme = qwen25vl_quant_scheme
@@ -590,20 +586,20 @@ class Qwen25VL_TextEncoder:
     def infer(self, texts):
         if self.cpu_offload:
-            self.text_encoder = self.text_encoder.to(self.run_device)
+            self.text_encoder = self.text_encoder.to(AI_DEVICE)
         text_inputs = self.text_encoder.text2tokens(texts, data_type="video", max_length=self.text_len)
-        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=self.run_device)
+        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=AI_DEVICE)
         if self.cpu_offload:
             self.text_encoder = self.text_encoder.to("cpu")
         prompt_embeds = prompt_outputs.hidden_state
         attention_mask = prompt_outputs.attention_mask
         if attention_mask is not None:
-            attention_mask = attention_mask.to(self.run_device)
+            attention_mask = attention_mask.to(AI_DEVICE)
             _, seq_len = attention_mask.shape
             attention_mask = attention_mask.repeat(1, self.num_videos_per_prompt)
             attention_mask = attention_mask.view(self.num_videos_per_prompt, seq_len)
-        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.run_device)
+        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=AI_DEVICE)
         seq_len = prompt_embeds.shape[1]
         # duplicate text embeddings for each generation per prompt, using mps friendly method
```
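The module-level `torch_device_module = getattr(torch, AI_DEVICE)` gives one handle on the backend-specific namespace — torch.cuda, torch.mlu, and torch.npu expose matching memory-management APIs — which is what lets the six-line empty_cache cascade collapse to a single call. The idiom in isolation (the AI_DEVICE value here is illustrative; the codebase imports it from lightx2v_platform):

```python
import torch

AI_DEVICE = "cuda"  # illustrative value

# One getattr replaces per-backend if/elif chains.
torch_device_module = getattr(torch, AI_DEVICE)
torch_device_module.empty_cache()
```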
lightx2v/models/input_encoders/hf/hunyuan15/siglip/model.py

```diff
@@ -10,6 +10,8 @@ from safetensors.torch import safe_open
 from transformers import SiglipImageProcessor, SiglipVisionModel
 from transformers.utils import ModelOutput
 
+from lightx2v_platform.base.global_var import AI_DEVICE
+
 PRECISION_TO_TYPE = {
     "fp32": torch.float32,
     "fp16": torch.float16,
@@ -95,7 +97,6 @@ class VisionEncoder(nn.Module):
         output_key: Optional[str] = None,
         logger=None,
         device=None,
-        run_device=None,
         cpu_offload=False,
     ):
         super().__init__()
@@ -121,7 +122,6 @@ class VisionEncoder(nn.Module):
         )
         self.dtype = self.model.dtype
         self.device = self.model.device
-        self.run_device = run_device
         self.processor, self.processor_path = load_image_processor(
             processor_type=self.processor_type,
@@ -172,12 +172,12 @@ class VisionEncoder(nn.Module):
             VisionEncoderModelOutput with encoded features
         """
         if self.cpu_offload:
-            self.model = self.model.to("cuda")
-            self.processor = self.processor.to("cuda")
+            self.model = self.model.to(AI_DEVICE)
+            self.processor = self.processor.to(AI_DEVICE)
         if isinstance(images, np.ndarray):
             # Preprocess images if they're numpy arrays
-            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=self.run_device, dtype=self.model.dtype)
+            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=AI_DEVICE, dtype=self.model.dtype)
         else:
             # Assume already preprocessed
             preprocessed = images
@@ -232,13 +232,11 @@ class SiglipVisionEncoder:
         self,
         config,
         device=torch.device("cpu"),
-        run_device=torch.device("cuda"),
         checkpoint_path=None,
         cpu_offload=False,
     ):
         self.config = config
         self.device = device
-        self.run_device = run_device
         self.cpu_offload = cpu_offload
         self.vision_states_dim = 1152
         vision_encoder_path = os.path.join(checkpoint_path, "vision_encoder", "siglip")
@@ -252,7 +250,6 @@ class SiglipVisionEncoder:
             output_key=None,
             logger=None,
             device=self.device,
-            run_device=self.run_device,
             cpu_offload=self.cpu_offload,
         )
@@ -270,7 +267,7 @@ class SiglipVisionEncoder:
     @torch.no_grad()
     def infer(self, vision_states):
         if self.cpu_offload:
-            self.vision_in = self.vision_in.to(self.run_device)
+            self.vision_in = self.vision_in.to(AI_DEVICE)
         vision_states = self.vision_in(vision_states)
         if self.cpu_offload:
             self.vision_in = self.vision_in.to("cpu")
```
lightx2v/models/input_encoders/hf/q_linear.py

```diff
@@ -26,11 +26,6 @@ try:
 except ImportError:
     fp8_linear = None
 
-try:
-    import torch_mlu_ops as tmo
-except ImportError:
-    tmo = None
-
 
 class VllmQuantLinearInt8(nn.Module):
     def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
@@ -315,19 +310,3 @@ class Q8FQuantLinearFp8(nn.Module):
         self.weight_scale = maybe_cast(self.weight_scale)
         self.bias = maybe_cast(self.bias)
         return self
-
-
-class MluQuantLinearInt8(VllmQuantLinearInt8):
-    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
-        super().__init__(in_features, out_features, bias, dtype)
-
-    def act_quant_func(self, x):
-        input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
-        return input_tensor_quant, input_tensor_scale
-
-    def forward(self, input_tensor):
-        input_tensor = input_tensor.squeeze(0)
-        dtype = input_tensor.dtype
-        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        output_tensor = tmo.scaled_matmul(
-            input_tensor_quant,
-            self.weight,
-            input_tensor_scale,
-            self.weight_scale.squeeze(-1),
-            output_dtype=dtype,
-        )
-        return output_tensor.unsqueeze(0)
```
lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py

```diff
@@ -5,6 +5,10 @@ import os
 import torch
 from transformers import Qwen2Tokenizer, Qwen2_5_VLForConditionalGeneration
 
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)
+
 try:
     from diffusers.image_processor import VaeImageProcessor
     from transformers import Qwen2VLProcessor
@@ -58,11 +62,10 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
         self.VAE_IMAGE_SIZE = 1024 * 1024
         self.cpu_offload = config.get("cpu_offload", False)
-        self.run_device = self.config.get("run_device", "cuda")
         if self.cpu_offload:
             self.device = torch.device("cpu")
         else:
-            self.device = torch.device(self.config.get("run_device", "cuda"))
+            self.device = torch.device(AI_DEVICE)
         self.dtype = torch.bfloat16
         self.load()
@@ -180,9 +183,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
         if self.cpu_offload:
             self.text_encoder.to(torch.device("cpu"))
-            if hasattr(torch, self.config.get("run_device", "cuda")):
-                torch_module = getattr(torch, self.config.get("run_device", "cuda"))
-                torch_module.empty_cache()
+            torch_device_module.empty_cache()
             gc.collect()
         return prompt_embeds, prompt_embeds_mask, image_info
```
lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py

```diff
@@ -9,6 +9,8 @@ import torch.nn.functional as F
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
 
+from lightx2v_platform.base.global_var import AI_DEVICE
+
 
 def linear_interpolation(features, output_len: int):
     features = features.transpose(1, 2)
@@ -252,7 +254,6 @@ class AudioAdapter(nn.Module):
         quantized: bool = False,
         quant_scheme: str = None,
         cpu_offload: bool = False,
-        run_device=torch.device("cuda"),
     ):
         super().__init__()
         self.cpu_offload = cpu_offload
@@ -263,7 +264,6 @@ class AudioAdapter(nn.Module):
             mlp_dims=mlp_dims,
             transformer_layers=projection_transformer_layers,
         )
-        self.run_device = run_device
         # self.num_tokens = num_tokens * 4
         self.num_tokens_x4 = num_tokens * 4
         self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
@@ -302,10 +302,10 @@ class AudioAdapter(nn.Module):
     @torch.no_grad()
     def forward_audio_proj(self, audio_feat, latent_frame):
         if self.cpu_offload:
-            self.audio_proj.to(self.run_device)
+            self.audio_proj.to(AI_DEVICE)
         x = self.audio_proj(audio_feat, latent_frame)
         x = self.rearange_audio_features(x)
-        x = x + self.audio_pe.to(self.run_device)
+        x = x + self.audio_pe.to(AI_DEVICE)
         if self.cpu_offload:
             self.audio_proj.to("cpu")
         return x
```