xuwx1 / LightX2V · Commits
Commit 3e215bad
authored Aug 06, 2025 by gushiqiao

Support bf16/fp16 inference and mixed-precision inference with fp32 for some layers

parent e684202c
Changes: 38 in this commit. Showing 20 changed files on this page with 110 additions and 87 deletions (+110 -87); the remaining files are on page 2.
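Note: throughout the diffs below, hardcoded torch.bfloat16 is replaced by the GET_DTYPE() and GET_SENSITIVE_DTYPE() helpers imported from lightx2v.utils.envs. That module is not shown on this page; modeled on the new call sites (which pass the results to .to() and compare the two values), a minimal sketch of what such environment-driven helpers could look like — the environment-variable names and defaults here are assumptions, not confirmed by this commit:

    # Hypothetical sketch only; the real lightx2v.utils.envs may differ.
    import os
    import torch

    _DTYPE_MAP = {"BF16": torch.bfloat16, "FP16": torch.float16, "FP32": torch.float32}

    def GET_DTYPE():
        # Runtime dtype used for most weights and activations.
        return _DTYPE_MAP[os.environ.get("DTYPE", "BF16")]

    def GET_SENSITIVE_DTYPE():
        # Dtype for numerically sensitive layers (norms, modulation, embeddings).
        # Setting it to FP32 enables the mixed-precision branches below; leaving
        # it equal to GET_DTYPE() keeps uniform-precision inference.
        return _DTYPE_MAP[os.environ.get("SENSITIVE_DTYPE", os.environ.get("DTYPE", "BF16"))]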
lightx2v/common/ops/attn/ring_attn.py                            +3 -2
lightx2v/common/ops/attn/sage_attn.py                            +1 -1
lightx2v/common/ops/mm/mm_weight.py                              +10 -9
lightx2v/common/ops/norm/layer_norm_weight.py                    +9 -6
lightx2v/common/ops/norm/rms_norm_weight.py                      +14 -13
lightx2v/common/ops/tensor/tensor.py                             +4 -2
lightx2v/models/input_encoders/hf/t5/model.py                    +3 -2
lightx2v/models/networks/cogvideox/model.py                      +2 -1
lightx2v/models/networks/hunyuan/infer/pre_infer.py              +6 -4
lightx2v/models/networks/hunyuan/infer/utils_fp32.py             +4 -2
lightx2v/models/networks/hunyuan/model.py                        +1 -1
lightx2v/models/networks/wan/audio_adapter.py                    +3 -1
lightx2v/models/networks/wan/audio_model.py                      +2 -2
lightx2v/models/networks/wan/causvid_model.py                    +6 -4
lightx2v/models/networks/wan/distill_model.py                    +5 -5
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py  +1 -1
lightx2v/models/networks/wan/infer/post_infer.py                 +6 -4
lightx2v/models/networks/wan/infer/pre_infer.py                  +6 -4
lightx2v/models/networks/wan/infer/transformer_infer.py          +17 -15
lightx2v/models/networks/wan/infer/utils.py                      +7 -8
lightx2v/common/ops/attn/ring_attn.py

@@ -3,6 +3,7 @@ import torch.distributed as dist
 import torch.nn.functional as F
 from loguru import logger
 
+from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
 
 from .template import AttnWeightTemplate
@@ -114,7 +115,7 @@ class RingAttnWeight(AttnWeightTemplate):
             k = next_k
             v = next_v
 
-        attn1 = out.to(torch.bfloat16).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)
+        attn1 = out.to(GET_DTYPE()).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)
 
         if txt_mask_len > 0:
             attn2, *_ = flash_attn.flash_attn_interface._flash_attn_forward(
@@ -131,7 +132,7 @@ class RingAttnWeight(AttnWeightTemplate):
                 return_softmax=False,
             )
-            attn2 = attn2.to(torch.bfloat16).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
+            attn2 = attn2.to(GET_DTYPE()).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
             attn1 = torch.cat([attn1, attn2], dim=0)
         return attn1
lightx2v/common/ops/attn/sage_attn.py

@@ -52,7 +52,7 @@ class SageAttn2Weight(AttnWeightTemplate):
             )
             x = torch.cat((x1, x2), dim=1)
             x = x.view(max_seqlen_q, -1)
-        elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "wan2.1_audio"]:
+        elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "wan2.1_audio", "wan2.2"]:
             x = sageattn(
                 q.unsqueeze(0),
                 k.unsqueeze(0),
lightx2v/common/ops/mm/mm_weight.py

@@ -129,6 +129,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         self.act_quant_func = None
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
+        self.infer_dtype = GET_DTYPE()
 
     # =========================
     # weight load functions
@@ -139,12 +140,12 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight = self.lazy_load_file.get_tensor(self.weight_name).pin_memory()
             self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float().pin_memory()
             if self.bias_name is not None:
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16).pin_memory()
+                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype).pin_memory()
         else:
             self.weight = self.lazy_load_file.get_tensor(self.weight_name)
             self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float()
             if self.bias_name is not None:
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16)
+                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
 
         if self.weight_need_transpose:
             self.weight = self.weight.t()
@@ -394,7 +395,7 @@ class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
             self.bias.float(),
             input_tensor_scale,
             self.weight_scale,
-            out_dtype=torch.bfloat16,
+            out_dtype=self.infer_dtype,
         )
         return output_tensor.squeeze(0)
@@ -425,7 +426,7 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
             input_tensor_scale,
             self.weight_scale,
             fuse_gelu=False,
-            out_dtype=torch.bfloat16,
+            out_dtype=self.infer_dtype,
         )
         return output_tensor.squeeze(0)
@@ -449,7 +450,7 @@ class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemm(MMWeightQuantTempla
     Act Scale: torch.Size([1024, 16]), torch.float32
     Weight : torch.Size([4096, 2048]), torch.float8_e4m3fn
     Weight Scale: torch.Size([32, 16]), torch.float32
-    Out : torch.Size([1024, 4096]), torch.bfloat16
+    Out : torch.Size([1024, 4096]), self.infer_dtype
     """
 
     def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
@@ -568,7 +569,7 @@ class MMWeightWfp8channelAfp8channeldynamicSglActVllm(MMWeightQuantTemplate):
             self.weight,
             input_tensor_scale,
             self.weight_scale,
-            torch.bfloat16,
+            self.infer_dtype,
             bias=self.bias,
         )
         return output_tensor
@@ -598,7 +599,7 @@ class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
             self.weight,
             input_tensor_scale,
             self.weight_scale,
-            torch.bfloat16,
+            self.infer_dtype,
             bias=self.bias,
         )
         return output_tensor
@@ -633,7 +634,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
             self.weight,
             input_tensor_scale,
             self.weight_scale,
-            torch.bfloat16,
+            self.infer_dtype,
             self.bias,
         )
         return output_tensor
@@ -659,7 +660,7 @@
     def apply(self, input_tensor):
         input_tensor = input_tensor
         input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=torch.bfloat16)
+        output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=self.infer_dtype)
         if self.bias is not None:
             output_tensor = output_tensor + self.bias
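Two patterns recur in this file: GET_DTYPE() is read once in __init__ and cached as self.infer_dtype, and every quantized matmul receives that cached value as its output dtype instead of a hardcoded torch.bfloat16, so an fp16 run produces fp16 outputs. A pure-torch stand-in illustrating the output_dtype knob (the real kernels and their signatures may differ):

    import torch

    def int8_per_token_matmul_ref(x_q, x_scale, w_q, w_scale, output_dtype=torch.bfloat16):
        # x_q: [M, K] int8 activations, x_scale: [M, 1] fp32 per-token scales
        # w_q: [K, N] int8 weights,     w_scale: [1, N] fp32 per-channel scales
        acc = x_q.float() @ w_q.float()  # emulates the kernel's int32 accumulator
        return (acc * x_scale * w_scale).to(output_dtype)  # single, configurable cast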
lightx2v/common/ops/norm/layer_norm_weight.py

@@ -14,6 +14,8 @@ class LNWeightTemplate(metaclass=ABCMeta):
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
         self.config = {}
+        self.infer_dtype = GET_DTYPE()
+        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
 
     def load(self, weight_dict):
         if not self.lazy_load:
@@ -85,29 +87,30 @@ class LNWeight(LNWeightTemplate):
     def load_from_disk(self):
         if self.weight_name is not None:
             if not torch._dynamo.is_compiling():
-                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
+                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
             else:
-                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
+                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
         else:
             self.weight = None
         if self.bias_name is not None:
             if not torch._dynamo.is_compiling():
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16).pin_memory()
+                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE()).pin_memory()
             else:
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16)
+                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE())
         else:
             self.bias = None
 
     def apply(self, input_tensor):
-        if GET_DTYPE() != "BF16":
+        if self.sensitive_layer_dtype != self.infer_dtype:
             input_tensor = torch.nn.functional.layer_norm(
                 input_tensor.float(),
                 (input_tensor.shape[-1],),
                 self.weight,
                 self.bias,
                 self.eps,
-            ).to(torch.bfloat16)
+            ).to(self.infer_dtype)
         else:
             input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
         return input_tensor
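The rewritten apply gates on self.sensitive_layer_dtype != self.infer_dtype rather than on the runtime dtype not being BF16. The underlying pattern, factored into a standalone helper (a sketch, not code from this repo):

    import torch

    def layer_norm_mixed(x, weight, bias, eps, infer_dtype, sensitive_dtype):
        if sensitive_dtype != infer_dtype:
            # run the numerically sensitive reduction in the sensitive (typically
            # fp32) dtype, then return to the runtime dtype
            y = torch.nn.functional.layer_norm(
                x.to(sensitive_dtype),
                (x.shape[-1],),
                weight.to(sensitive_dtype) if weight is not None else None,
                bias.to(sensitive_dtype) if bias is not None else None,
                eps,
            )
            return y.to(infer_dtype)
        # uniform precision: stay in the runtime dtype throughout
        return torch.nn.functional.layer_norm(x, (x.shape[-1],), weight, bias, eps)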
lightx2v/common/ops/norm/rms_norm_weight.py

@@ -17,6 +17,8 @@ class RMSWeightTemplate(metaclass=ABCMeta):
         self.eps = eps
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
+        self.infer_dtype = GET_DTYPE()
+        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
         self.config = {}
 
     def load(self, weight_dict):
@@ -64,17 +66,17 @@ class RMSWeight(RMSWeightTemplate):
     def load_from_disk(self):
         if not torch._dynamo.is_compiling():
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
+            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
         else:
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
+            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
 
     def apply(self, input_tensor):
-        if GET_DTYPE() == "BF16":
+        if GET_SENSITIVE_DTYPE() != GET_DTYPE():
             input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
             input_tensor = input_tensor * self.weight
         else:
             input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps)
-            input_tensor = (input_tensor * self.weight).to(torch.bfloat16)
+            input_tensor = (input_tensor * self.weight).to(GET_DTYPE())
         return input_tensor
@@ -97,24 +99,23 @@ class RMSWeightSgl(RMSWeight):
     def load_from_disk(self):
         if not torch._dynamo.is_compiling():
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
+            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
         else:
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
+            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
 
     def apply(self, input_tensor):
-        use_bf16 = GET_DTYPE() == "BF16"
-        if sgl_kernel is not None and use_bf16:
+        if sgl_kernel is not None and self.sensitive_layer_dtype == self.infer_dtype:
             input_tensor = input_tensor.contiguous()
             orig_shape = input_tensor.shape
             input_tensor = input_tensor.view(-1, orig_shape[-1])
             input_tensor = sgl_kernel.rmsnorm(input_tensor, self.weight, self.eps).view(orig_shape)
         else:
-            # sgl_kernel is not available or dtype!=torch.bfloat16, fallback to default implementation
-            if use_bf16:
+            # sgl_kernel is not available or dtype!=torch.bfloat16/float16, fallback to default implementation
+            if self.sensitive_layer_dtype != self.infer_dtype:
+                input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps).to(self.infer_dtype)
+                input_tensor = (input_tensor * self.weight).to(self.infer_dtype)
+            else:
                 input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
                 input_tensor = input_tensor * self.weight
-            else:
-                input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps).type_as(input_tensor)
-                input_tensor = (input_tensor * self.weight).type_as(input_tensor)
         return input_tensor
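The .float() upcast in the fallback path is what makes fp16 inference viable here: squaring overflows half precision once |x| exceeds about 256, whereas bf16 shares fp32's exponent range and tolerates the in-dtype reduction. A small, runnable demonstration of the failure mode:

    import torch

    x = torch.full((1, 8), 300.0, dtype=torch.float16)
    # x.pow(2) = 90000 overflows fp16 (max ~65504) to inf, the mean is inf,
    # rsqrt(inf) is 0, and the "normalized" output collapses to zeros:
    naive = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    # accumulating the reduction in fp32 recovers the expected unit scale:
    stable = (x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)).to(torch.float16)
    print(naive)   # zeros
    print(stable)  # ones (300 / rms(300) == 1.0)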
lightx2v/common/ops/tensor/tensor.py

@@ -10,12 +10,14 @@ class DefaultTensor:
         self.tensor_name = tensor_name
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
+        self.infer_dtype = GET_DTYPE()
+        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
 
     def load_from_disk(self):
         if not torch._dynamo.is_compiling():
-            self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(torch.bfloat16).pin_memory()
+            self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype).pin_memory()
         else:
-            self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(torch.bfloat16)
+            self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
 
     def load(self, weight_dict):
         if not self.lazy_load:
lightx2v/models/input_encoders/hf/t5/model.py

@@ -9,6 +9,7 @@ import torch.nn.functional as F
 from loguru import logger
 
 from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
+from lightx2v.utils.envs import *
 
 from .tokenizer import HuggingfaceTokenizer
@@ -131,7 +132,7 @@ class T5Attention(nn.Module):
         if hasattr(self, "cpu_offload") and self.cpu_offload:
             del attn_bias
-        attn = F.softmax(attn.float(), dim=-1).to(torch.bfloat16)
+        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
         x = torch.einsum("bnij,bjnc->binc", attn, v)
 
         if hasattr(self, "cpu_offload") and self.cpu_offload:
@@ -356,7 +357,7 @@ class T5Encoder(nn.Module):
             optimize_memory_usage()
         x = self.dropout(x)
-        return x.to(torch.bfloat16)
+        return x.to(GET_DTYPE())
 
 
 class T5Decoder(nn.Module):
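.type_as(attn) ties the post-softmax cast to whatever dtype the attention scores arrived in, so the same line now serves both bf16 and fp16 while the softmax itself still runs in fp32:

    import torch
    import torch.nn.functional as F

    attn = torch.randn(2, 8, 16, 16, dtype=torch.float16)
    attn = F.softmax(attn.float(), dim=-1).type_as(attn)  # softmax in fp32, cast back
    assert attn.dtype == torch.float16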
lightx2v/models/networks/cogvideox/model.py

@@ -12,6 +12,7 @@ from lightx2v.models.networks.cogvideox.infer.transformer_infer import Cogvideox
 from lightx2v.models.networks.cogvideox.weights.post_weights import CogvideoxPostWeights
 from lightx2v.models.networks.cogvideox.weights.pre_weights import CogvideoxPreWeights
 from lightx2v.models.networks.cogvideox.weights.transformers_weights import CogvideoxTransformerWeights
+from lightx2v.utils.envs import *
 
 
 class CogvideoxModel:
@@ -33,7 +34,7 @@ class CogvideoxModel:
     def _load_safetensor_to_dict(self, file_path):
         with safe_open(file_path, framework="pt") as f:
-            tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
+            tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()).cuda() for key in f.keys()}
         return tensor_dict
 
     def _load_ckpt(self):
lightx2v/models/networks/hunyuan/infer/pre_infer.py

@@ -3,6 +3,8 @@ import math
 import torch
 from einops import rearrange
 
+from lightx2v.utils.envs import *
+
 
 class HunyuanPreInfer:
     def __init__(self, config):
@@ -64,7 +66,7 @@ class HunyuanPreInfer:
     def infer_time_in(self, weights, t):
         freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
         args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
-        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
         out = weights.time_in_mlp_0.apply(embedding)
         out = torch.nn.functional.silu(out)
         out = weights.time_in_mlp_2.apply(out)
@@ -78,12 +80,12 @@ class HunyuanPreInfer:
     def infer_text_in(self, weights, text_states, text_mask, t):
         freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
         args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
-        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
         out = weights.txt_in_t_embedder_mlp_0.apply(embedding)
         out = torch.nn.functional.silu(out)
         timestep_aware_representations = weights.txt_in_t_embedder_mlp_2.apply(out)
 
-        mask_float = text_mask.float().unsqueeze(-1).to(torch.bfloat16)  # [b, s1, 1]
+        mask_float = text_mask.float().unsqueeze(-1).to(GET_DTYPE())  # [b, s1, 1]
         context_aware_representations = (text_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
         context_aware_representations = context_aware_representations
@@ -148,7 +150,7 @@ class HunyuanPreInfer:
     def infer_guidance_in(self, weights, guidance):
         freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=guidance.device)
         args = guidance.float() * freqs[None]
-        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
         out = weights.guidance_in_mlp_0.apply(embedding)
         out = torch.nn.functional.silu(out)
         out = weights.guidance_in_mlp_2.apply(out)
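All three hunks change the same pattern: the sinusoidal timestep embedding is computed in fp32, and only the final cast follows the runtime dtype. Factored out as a standalone sketch (mirroring the hunks above, which use a half-dimension of 128):

    import math
    import torch

    def timestep_embedding(t, half_dim=128, out_dtype=torch.bfloat16):
        # sinusoid computed in fp32 for accuracy, cast exactly once at the end
        freqs = torch.exp(-math.log(10000) * torch.arange(0, half_dim, dtype=torch.float32) / half_dim).to(t.device)
        args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(out_dtype)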
lightx2v/models/networks/hunyuan/infer/utils_fp32.py

@@ -2,11 +2,13 @@ from typing import Tuple, Union
 import torch
 
+from lightx2v.utils.envs import *
+
 
 def rms_norm(x, weight, eps):
     x = x.float()
     x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
-    x = x.to(torch.bfloat16)
+    x = x.to(GET_DTYPE())
     x = x * weight
     return x
@@ -18,7 +20,7 @@ def rotate_half(x, shape_0, shape_1):
 def rotary_emb(x, shape_0, shape_1, cos, sin):
     x_out = x * cos + rotate_half(x, shape_0, shape_1) * sin
-    return x_out.to(torch.bfloat16)
+    return x_out.to(GET_DTYPE())
 
 
 def apply_rotary_emb(
lightx2v/models/networks/hunyuan/model.py

@@ -78,7 +78,7 @@ class HunyuanModel:
             for k in f.keys():
                 weight_dict[k] = f.get_tensor(k)
                 if weight_dict[k].dtype == torch.float:
-                    weight_dict[k] = weight_dict[k].to(torch.bfloat16)
+                    weight_dict[k] = weight_dict[k].to(GET_DTYPE())
         return weight_dict
lightx2v/models/networks/wan/audio_adapter.py

@@ -13,6 +13,8 @@ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
 from transformers import AutoModel
 
+from lightx2v.utils.envs import *
+
 
 def load_safetensors(in_path: str):
     if os.path.isdir(in_path):
@@ -57,7 +59,7 @@ def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
     model.load_state_dict(state_dict, strict=strict)
     if dist.is_initialized():
         dist.barrier()
-    return model.to(dtype=torch.bfloat16, device="cuda")
+    return model.to(dtype=GET_DTYPE(), device="cuda")
 
 
 def linear_interpolation(features, output_len: int):
lightx2v/models/networks/wan/audio_model.py

@@ -67,10 +67,10 @@ class WanAudioModel(WanModel):
 class Wan22MoeAudioModel(WanAudioModel):
-    def _load_ckpt(self, use_bf16, skip_bf16):
+    def _load_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
         weight_dict = {}
         for file_path in safetensors_files:
-            file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
+            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
             weight_dict.update(file_weights)
         return weight_dict
lightx2v/models/networks/wan/causvid_model.py

@@ -31,23 +31,25 @@ class WanCausVidModel(WanModel):
         self.post_infer_class = WanPostInfer
         self.transformer_infer_class = WanTransformerInferCausVid
 
-    def _load_ckpt(self, use_bf16, skip_bf16):
+    def _load_ckpt(self, unified_dtype, sensitive_layer):
         ckpt_folder = "causvid_models"
         safetensors_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.safetensors")
         if os.path.exists(safetensors_path):
             with safe_open(safetensors_path, framework="pt") as f:
-                weight_dict = {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
+                weight_dict = {
+                    key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()
+                }
             return weight_dict
 
         ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.pt")
         if os.path.exists(ckpt_path):
             weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
             weight_dict = {
-                key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
+                key: (weight_dict[key].to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
             }
             return weight_dict
 
-        return super()._load_ckpt(use_bf16, skip_bf16)
+        return super()._load_ckpt(unified_dtype, sensitive_layer)
 
     @torch.no_grad()
     def infer(self, inputs, kv_start, kv_end):
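The renamed parameters generalize the old bf16-only logic: unified_dtype means "cast every tensor to the runtime dtype", while sensitive_layer lists name substrings whose tensors keep their checkpoint precision. The per-key rule, isolated as a sketch (the substring example is hypothetical):

    import torch

    def cast_for_inference(key, tensor, unified_dtype, sensitive_layer, runtime_dtype=torch.bfloat16):
        # cast unless this key names a sensitive layer and mixed precision is on
        if unified_dtype or all(s not in key for s in sensitive_layer):
            return tensor.to(runtime_dtype)
        return tensor  # keep e.g. a "norm" weight in its original fp32

    w = cast_for_inference("blocks.0.norm1.weight", torch.ones(8, dtype=torch.float32),
                           unified_dtype=False, sensitive_layer=["norm"])
    assert w.dtype == torch.float32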
lightx2v/models/networks/wan/distill_model.py

@@ -20,27 +20,27 @@ class WanDistillModel(WanModel):
     def __init__(self, model_path, config, device):
         super().__init__(model_path, config, device)
 
-    def _load_ckpt(self, use_bf16, skip_bf16):
+    def _load_ckpt(self, unified_dtype, sensitive_layer):
         if self.config.get("enable_dynamic_cfg", False):
             ckpt_path = os.path.join(self.model_path, "distill_cfg_models", "distill_model.safetensors")
         else:
             ckpt_path = os.path.join(self.model_path, "distill_models", "distill_model.safetensors")
         if os.path.exists(ckpt_path):
             logger.info(f"Loading weights from {ckpt_path}")
-            return self._load_safetensor_to_dict(ckpt_path, use_bf16, skip_bf16)
+            return self._load_safetensor_to_dict(ckpt_path, unified_dtype, sensitive_layer)
-        return super()._load_ckpt(use_bf16, skip_bf16)
+        return super()._load_ckpt(unified_dtype, sensitive_layer)
 
 
 class Wan22MoeDistillModel(WanDistillModel, Wan22MoeModel):
     def __init__(self, model_path, config, device):
         WanDistillModel.__init__(self, model_path, config, device)
 
-    def _load_ckpt(self, use_bf16, skip_bf16):
+    def _load_ckpt(self, unified_dtype, sensitive_layer):
         ckpt_path = os.path.join(self.model_path, "distill_model.safetensors")
         if os.path.exists(ckpt_path):
             logger.info(f"Loading weights from {ckpt_path}")
-            return self._load_safetensor_to_dict(ckpt_path, use_bf16, skip_bf16)
+            return self._load_safetensor_to_dict(ckpt_path, unified_dtype, sensitive_layer)
 
     @torch.no_grad()
     def infer(self, inputs):
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py

@@ -54,7 +54,7 @@ class WanAudioPreInfer(WanPreInfer):
         clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
         ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"]
         batch_size = len(x)
-        num_channels, num_frames, height, width = x[0].shape
+        num_channels, _, height, width = x[0].shape
         _, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
         if ref_num_channels != num_channels:
lightx2v/models/networks/wan/infer/post_infer.py

@@ -10,6 +10,8 @@ class WanPostInfer:
         self.out_dim = config["out_dim"]
         self.patch_size = (1, 2, 2)
         self.clean_cuda_cache = config.get("clean_cuda_cache", False)
+        self.infer_dtype = GET_DTYPE()
+        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
 
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
@@ -26,11 +28,11 @@ class WanPostInfer:
         x = weights.norm.apply(x)
-        if GET_DTYPE() != "BF16":
-            x = x.float()
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            x = x.to(self.sensitive_layer_dtype)
         x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
-        if GET_DTYPE() != "BF16":
-            x = x.to(torch.bfloat16)
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            x = x.to(self.infer_dtype)
         x = weights.head.apply(x)
         x = self.unpatchify(x, grid_sizes)
lightx2v/models/networks/wan/infer/pre_infer.py

@@ -25,6 +25,8 @@ class WanPreInfer:
         self.text_len = config["text_len"]
         self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
         self.cfg_scale = config.get("cfg_scale", 4.0)
+        self.infer_dtype = GET_DTYPE()
+        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
 
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
@@ -79,8 +81,8 @@ class WanPreInfer:
             cfg_embed = torch.nn.functional.silu(cfg_embed)
             cfg_embed = weights.cfg_cond_proj_2.apply(cfg_embed)
             embed = embed + cfg_embed
-        if GET_DTYPE() != "BF16":
-            embed = weights.time_embedding_0.apply(embed.float())
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            embed = weights.time_embedding_0.apply(embed.to(self.sensitive_layer_dtype))
         else:
             embed = weights.time_embedding_0.apply(embed)
         embed = torch.nn.functional.silu(embed)
@@ -100,8 +102,8 @@ class WanPreInfer:
         # text embeddings
         stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
-        if GET_DTYPE() != "BF16":
-            out = weights.text_embedding_0.apply(stacked.squeeze(0).float())
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            out = weights.text_embedding_0.apply(stacked.squeeze(0).to(self.sensitive_layer_dtype))
         else:
             out = weights.text_embedding_0.apply(stacked.squeeze(0))
         out = torch.nn.functional.gelu(out, approximate="tanh")
lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -30,6 +30,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.apply_rotary_emb_func = apply_rotary_emb
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.mask_map = None
+        self.infer_dtype = GET_DTYPE()
+        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
         if self.config.get("cpu_offload", False):
             if torch.cuda.get_device_capability(0) == (9, 0):
@@ -342,13 +344,13 @@ class WanTransformerInfer(BaseTransformerInfer):
         norm1_out = weights.norm1.apply(x)
-        if GET_DTYPE() != "BF16":
-            norm1_out = norm1_out.float()
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            norm1_out = norm1_out.to(self.sensitive_layer_dtype)
         norm1_out.mul_(norm1_weight).add_(norm1_bias)
-        if GET_DTYPE() != "BF16":
-            norm1_out = norm1_out.to(torch.bfloat16)
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            norm1_out = norm1_out.to(self.infer_dtype)
 
         s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
@@ -402,8 +404,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         return y
 
     def infer_cross_attn(self, weights, x, context, y_out, gate_msa):
-        if GET_DTYPE() != "BF16":
-            x = x.float() + y_out.float() * gate_msa.squeeze()
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            x = x.to(self.sensitive_layer_dtype) + y_out.to(self.sensitive_layer_dtype) * gate_msa.squeeze()
         else:
             x.add_(y_out * gate_msa.squeeze())
@@ -414,10 +416,10 @@ class WanTransformerInfer(BaseTransformerInfer):
         else:
             context_img = None
 
-        if GET_DTYPE() != "BF16":
-            context = context.to(torch.bfloat16)
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            context = context.to(self.infer_dtype)
             if self.task == "i2v" and self.config.get("use_image_encoder", True):
-                context_img = context_img.to(torch.bfloat16)
+                context_img = context_img.to(self.infer_dtype)
 
         n, d = self.num_heads, self.head_dim
@@ -485,11 +487,11 @@ class WanTransformerInfer(BaseTransformerInfer):
             norm2_bias = c_shift_msa.squeeze()
 
         norm2_out = weights.norm2.apply(x)
-        if GET_DTYPE() != "BF16":
-            norm2_out = norm2_out.float()
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            norm2_out = norm2_out.to(self.sensitive_layer_dtype)
         norm2_out.mul_(norm2_weight).add_(norm2_bias)
-        if GET_DTYPE() != "BF16":
-            norm2_out = norm2_out.to(torch.bfloat16)
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            norm2_out = norm2_out.to(self.infer_dtype)
 
         y = weights.ffn_0.apply(norm2_out)
         if self.clean_cuda_cache:
@@ -503,8 +505,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         return y
 
     def post_process(self, x, y, c_gate_msa):
-        if GET_DTYPE() != "BF16":
-            x = x.float() + y.float() * c_gate_msa.squeeze()
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze()
         else:
             x.add_(y * c_gate_msa.squeeze())
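The norm/modulation hunks all follow one shape: upcast to the sensitive dtype, do the in-place mul_/add_ modulation there, and cast back; when the two dtypes coincide, the tensor is modified in place with no extra copies. Factored out as a sketch (the file keeps it inline; weight and bias are expected in x's compute dtype):

    import torch

    def modulate(x, weight, bias, infer_dtype, sensitive_dtype):
        if sensitive_dtype != infer_dtype:
            x = x.to(sensitive_dtype)   # fp32 copy for the sensitive arithmetic
            x.mul_(weight).add_(bias)
            return x.to(infer_dtype)
        x.mul_(weight).add_(bias)       # uniform precision: pure in-place update
        return x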
lightx2v/models/networks/wan/infer/utils.py

@@ -68,7 +68,7 @@ def apply_rotary_emb(x, freqs_i):
     # Apply rotary embedding
     x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
     x_i = torch.cat([x_i, x[seq_len:]])
 
-    return x_i.to(torch.bfloat16)
+    return x_i.to(GET_DTYPE())
 
 
 def apply_rotary_emb_chunk(x, freqs_i, chunk_size, remaining_chunk_size=100):
@@ -82,7 +82,7 @@ def apply_rotary_emb_chunk(x, freqs_i, chunk_size, remaining_chunk_size=100):
             freqs_chunk = freqs_i[start:end]
             x_chunk_complex = torch.view_as_complex(x_chunk.to(torch.float32).reshape(end - start, n, -1, 2))
-            x_chunk_embedded = torch.view_as_real(x_chunk_complex * freqs_chunk).flatten(2).to(torch.bfloat16)
+            x_chunk_embedded = torch.view_as_real(x_chunk_complex * freqs_chunk).flatten(2).to(GET_DTYPE())
             output_chunks.append(x_chunk_embedded)
             del x_chunk_complex, x_chunk_embedded
             torch.cuda.empty_cache()
@@ -101,7 +101,7 @@ def apply_rotary_emb_chunk(x, freqs_i, chunk_size, remaining_chunk_size=100):
         del result
         torch.cuda.empty_cache()
 
-    return x_i.to(torch.bfloat16)
+    return x_i.to(GET_DTYPE())
 
 
 def rope_params(max_seq_len, dim, theta=10000):
@@ -123,8 +123,7 @@ def sinusoidal_embedding_1d(dim, position):
     # calculation
     sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
     x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
-    if GET_DTYPE() == "BF16":
-        x = x.to(torch.bfloat16)
+    x = x.to(GET_SENSITIVE_DTYPE())
     return x
@@ -140,15 +139,15 @@ def guidance_scale_embedding(w, embedding_dim=256, cfg_range=(1.0, 6.0), target_
     """
     assert len(w.shape) == 1
    cfg_min, cfg_max = cfg_range
-    # w = torch.round(w)
-    # w = torch.clamp(w, min=cfg_min, max=cfg_max)
+    w = torch.round(w)
+    w = torch.clamp(w, min=cfg_min, max=cfg_max)
     w = (w - cfg_min) / (cfg_max - cfg_min)  # [0, 1]
     w = w * target_range
     half_dim = embedding_dim // 2
     emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
     emb = torch.exp(torch.arange(half_dim, dtype=dtype).to(w.device) * -emb).to(w.device)
     emb = w.to(dtype)[:, None] * emb[None, :]
-    emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1)
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
     if embedding_dim % 2 == 1:  # zero pad
         emb = torch.nn.functional.pad(emb, (0, 1).to(w.device))
     assert emb.shape == (w.shape[0], embedding_dim)
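Two behavioral edits ride along with the dtype plumbing in guidance_scale_embedding: the round/clamp of w is now active rather than commented out, and the concatenation order flips to [sin, cos]. The rotary helpers above, by contrast, keep their math in complex fp32 and only parameterize the final cast, e.g. (standalone sketch, not the repo's exact code):

    import torch

    def apply_rotary_ref(x, freqs, out_dtype):
        # x: [seq, heads, dim] real; freqs: [seq, 1, dim // 2] complex64
        xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        return torch.view_as_real(xc * freqs).flatten(2).to(out_dtype)

    x = torch.randn(16, 8, 64, dtype=torch.float16)
    freqs = torch.polar(torch.ones(16, 1, 32), torch.randn(16, 1, 32))
    y = apply_rotary_ref(x, freqs, torch.float16)  # fp16 in, fp16 out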