xuwx1 / LightX2V · Commit daf4c74e

Commit message: "first commit"
Authored Mar 24, 2025 by helloyongyang; committed Apr 08, 2025 by Yang Yong(雍洋)
Parent: 6c79160f
Changes: 107 files in total; this page shows 20 changed files with 1597 additions and 0 deletions (+1597 −0):

  lightx2v/common/ops/mm/mm_weight_calib.py  (+53 −0)
  lightx2v/common/ops/norm/__init__.py  (+1 −0)
  lightx2v/common/ops/norm/layer_norm_weight.py  (+45 −0)
  lightx2v/common/ops/norm/rms_norm_weight.py  (+65 −0)
  lightx2v/image2v/__init__.py  (+0 −0)
  lightx2v/image2v/models/wan/__init__.py  (+0 −0)
  lightx2v/image2v/models/wan/model.py  (+542 −0)
  lightx2v/image2v/models/wan/xlm_roberta.py  (+170 −0)
  lightx2v/text2v/__init__.py  (+0 −0)
  lightx2v/text2v/models/__init__.py  (+0 −0)
  lightx2v/text2v/models/networks/hunyuan/__init__.py  (+0 −0)
  lightx2v/text2v/models/networks/hunyuan/infer/__init__.py  (+0 −0)
  lightx2v/text2v/models/networks/hunyuan/infer/feature_caching/__init__.py  (+0 −0)
  lightx2v/text2v/models/networks/hunyuan/infer/feature_caching/transformer_infer.py  (+284 −0)
  lightx2v/text2v/models/networks/hunyuan/infer/post_infer.py  (+30 −0)
  lightx2v/text2v/models/networks/hunyuan/infer/pre_infer.py  (+133 −0)
  lightx2v/text2v/models/networks/hunyuan/infer/transformer_infer.py  (+205 −0)
  lightx2v/text2v/models/networks/hunyuan/infer/utils.py  (+9 −0)
  lightx2v/text2v/models/networks/hunyuan/infer/utils_bf16.py  (+29 −0)
  lightx2v/text2v/models/networks/hunyuan/infer/utils_fp32.py  (+31 −0)
lightx2v/common/ops/mm/mm_weight_calib.py (new file, mode 100644)

import torch
from .mm_weight import MMWeight
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer


@MM_WEIGHT_REGISTER('Calib')
class MMWeightCalib(MMWeight):
    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        assert self.config and self.config.get('mm_type', 'Default') != 'Default'
        self.weight = weight_dict[self.weight_name]
        self.get_quantizer()
        shape_and_dtype = self.get_quant_shape_and_dtype(self.weight.shape)
        self.realq_weight, self.scales, self.zeros = self.w_quantizer.real_quant_tensor(self.weight)
        self.realq_weight = self.realq_weight.view(shape_and_dtype['tensor'][0]).contiguous().to(shape_and_dtype['tensor'][1])
        self.scales = self.scales.view(shape_and_dtype['scales'][0]).contiguous().to(shape_and_dtype['scales'][1])
        if self.zeros is not None:
            self.zeros = self.zeros.view(shape_and_dtype['zeros'][0]).contiguous().to(shape_and_dtype['zeros'][1])

    def apply(self, input_tensor):
        return super().apply(input_tensor)

    def get_quantizer(self):
        if self.config['mm_type'] == 'W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm':
            self.w_setting = {'bit': 'e4m3', 'symmetric': True, 'granularity': 'channel'}
            self.a_setting = {'bit': 'e4m3', 'symmetric': True, 'granularity': 'channel'}
            self.w_quantizer = FloatQuantizer(**self.w_setting)
            self.a_quantizer = FloatQuantizer(**self.a_setting)
            self.act_dynamic_quant = True
        else:
            raise NotImplementedError(f'Unsupported mm_type: {self.config["mm_type"]}')

    def get_quant_shape_and_dtype(self, shape):
        if self.config['mm_type'] == 'W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm':
            return {
                'tensor': (shape, torch.float8_e5m2),
                'scales': ((shape[0], 1), torch.float32),
                'zeros': None,
            }
        else:
            raise NotImplementedError(f'Unsupported mm_type: {self.config["mm_type"]}')
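For orientation, the snippet below is a minimal sketch of what per-output-channel symmetric FP8 quantization of a weight matrix typically looks like; it assumes real_quant_tensor follows this standard recipe (the actual FloatQuantizer in lightx2v.utils.quant_utils may differ) and uses torch.float8_e4m3fn purely for illustration, even though the shape/dtype table above stores torch.float8_e5m2.

import torch

def fp8_channel_sym_quant(weight: torch.Tensor):
    # Hypothetical stand-in for FloatQuantizer.real_quant_tensor: one scale per output
    # channel (row), symmetric (no zero point), values clipped to the FP8 range.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scales = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / fp8_max   # (out_features, 1)
    q = (weight / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)       # quantized weight
    return q, scales.to(torch.float32), None                                     # zeros is None when symmetric

w = torch.randn(128, 64)
q, scales, zeros = fp8_channel_sym_quant(w)
w_hat = q.to(torch.float32) * scales    # dequantize to check the round trip
print((w - w_hat).abs().max())          # small quantization error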
lightx2v/common/ops/norm/__init__.py (new file, mode 100755)

from .rms_norm_weight import *
lightx2v/common/ops/norm/layer_norm_weight.py (new file, mode 100644)

import torch
from abc import ABCMeta, abstractmethod
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER


class LNWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name, eps=1e-6):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.eps = eps
        self.config = {}

    def load(self, weight_dict):
        self.weight = weight_dict[self.weight_name].cuda() if self.weight_name is not None else None
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config

    def to_cpu(self):
        if self.weight is not None:
            self.weight = self.weight.cpu()
        if self.bias is not None:
            self.bias = self.bias.cpu()

    def to_cuda(self):
        if self.weight is not None:
            self.weight = self.weight.cuda()
        if self.bias is not None:
            self.bias = self.bias.cuda()


@LN_WEIGHT_REGISTER('Default')
class LNWeight(LNWeightTemplate):
    def __init__(self, weight_name, bias_name, eps=1e-6):
        super().__init__(weight_name, bias_name, eps)

    def apply(self, input_tensor):
        input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[1],), self.weight, self.bias, self.eps)
        return input_tensor
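A small usage sketch of the Default LNWeight wrapper (it assumes a CUDA device, since load() moves the tensors to the GPU; the weight_dict keys are made up for illustration):

import torch

ln = LNWeight('norm.weight', 'norm.bias')                               # hypothetical parameter names
ln.load({'norm.weight': torch.ones(64), 'norm.bias': torch.zeros(64)})  # load() moves both to CUDA
x = torch.randn(8, 64, device='cuda')
y = ln.apply(x)                                                         # LayerNorm over the last dim (size 64)
print(y.shape)                                                          # torch.Size([8, 64])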
lightx2v/common/ops/norm/rms_norm_weight.py (new file, mode 100644)

import torch
from abc import ABCMeta, abstractmethod
from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER
import sgl_kernel


class RMSWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, eps=1e-6):
        self.weight_name = weight_name
        self.eps = eps
        self.config = {}

    def load(self, weight_dict):
        self.weight = weight_dict[self.weight_name].cuda()

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config

    def to_cpu(self):
        self.weight = self.weight.cpu()

    def to_cuda(self):
        self.weight = self.weight.cuda()


@RMS_WEIGHT_REGISTER('Default')
class RMSWeight(RMSWeightTemplate):
    def __init__(self, weight_name, eps=1e-6):
        super().__init__(weight_name, eps)

    def apply(self, input_tensor):
        input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
        input_tensor = input_tensor * self.weight
        return input_tensor


@RMS_WEIGHT_REGISTER('FP32')
class RMSWeightFP32(RMSWeight):
    def __init__(self, weight_name, eps=1e-6):
        super().__init__(weight_name, eps)

    def apply(self, input_tensor):
        input_tensor = input_tensor.float()
        input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
        input_tensor = input_tensor.to(torch.bfloat16)
        input_tensor = input_tensor * self.weight
        return input_tensor


@RMS_WEIGHT_REGISTER('sgl-kernel')
class RMSWeightSgl(RMSWeight):
    def __init__(self, weight_name, eps=1e-6):
        super().__init__(weight_name, eps)

    def apply(self, input_tensor):
        input_tensor = input_tensor.contiguous()
        orig_shape = input_tensor.shape
        input_tensor = input_tensor.view(-1, orig_shape[-1])
        input_tensor = sgl_kernel.rmsnorm(input_tensor, self.weight, self.eps).view(orig_shape)
        return input_tensor
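The registered RMS-norm variants are interchangeable at the call site. A minimal sketch (the weight key is hypothetical, load() requires a CUDA device, and the 'sgl-kernel' variant additionally needs sgl_kernel installed):

import torch

x = torch.randn(8, 64, dtype=torch.bfloat16, device='cuda')
w = {'norm.weight': torch.ones(64, dtype=torch.bfloat16)}   # hypothetical key

rms = RMSWeight('norm.weight')            # pure-PyTorch reference implementation
rms.load(w)
y_ref = rms.apply(x)

rms_fp32 = RMSWeightFP32('norm.weight')   # accumulates the norm in fp32, then casts back to bf16
rms_fp32.load(w)
y_fp32 = rms_fp32.apply(x)

print(y_ref.shape, y_fp32.dtype)          # torch.Size([8, 64]) torch.bfloat16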
lightx2v/image2v/__init__.py (new empty file, mode 100755)
lightx2v/image2v/models/wan/__init__.py (new empty file, mode 100755)
lightx2v/image2v/models/wan/model.py (new file, mode 100755)

# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from lightx2v.attentions import attention
from lightx2v.text2v.models.text_encoders.hf.t5.tokenizer import HuggingfaceTokenizer

from .xlm_roberta import XLMRoberta

__all__ = [
    'XLMRobertaCLIP',
    'clip_xlm_roberta_vit_h_14',
    'CLIPModel',
]


def pos_interpolate(pos, seq_len):
    if pos.size(1) == seq_len:
        return pos
    else:
        src_grid = int(math.sqrt(pos.size(1)))
        tar_grid = int(math.sqrt(seq_len))
        n = pos.size(1) - src_grid * src_grid
        return torch.cat([
            pos[:, :n],
            F.interpolate(
                pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2),
                size=(tar_grid, tar_grid),
                mode='bicubic',
                align_corners=False).flatten(2).transpose(1, 2)
        ], dim=1)


class QuickGELU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(1.702 * x)


class LayerNorm(nn.LayerNorm):
    def forward(self, x):
        return super().forward(x.float()).type_as(x)


class SelfAttention(nn.Module):
    def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.causal = causal
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout

        # layers
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """
        x: [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)

        # compute attention
        x = attention(q=q, k=k, v=v, attention_type='torch_sdpa')
        x = x.reshape(b, s, c)

        # output
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)
        return x


class SwiGLU(nn.Module):
    def __init__(self, dim, mid_dim):
        super().__init__()
        self.dim = dim
        self.mid_dim = mid_dim

        # layers
        self.fc1 = nn.Linear(dim, mid_dim)
        self.fc2 = nn.Linear(dim, mid_dim)
        self.fc3 = nn.Linear(mid_dim, dim)

    def forward(self, x):
        x = F.silu(self.fc1(x)) * self.fc2(x)
        x = self.fc3(x)
        return x


class AttentionBlock(nn.Module):
    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 post_norm=False,
                 causal=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert activation in ['quick_gelu', 'gelu', 'swi_glu']
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.causal = causal
        self.norm_eps = norm_eps

        # layers
        self.norm1 = LayerNorm(dim, eps=norm_eps)
        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout)
        self.norm2 = LayerNorm(dim, eps=norm_eps)
        if activation == 'swi_glu':
            self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
        else:
            self.mlp = nn.Sequential(
                nn.Linear(dim, int(dim * mlp_ratio)),
                QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
                nn.Linear(int(dim * mlp_ratio), dim),
                nn.Dropout(proj_dropout))

    def forward(self, x):
        if self.post_norm:
            x = x + self.norm1(self.attn(x))
            x = x + self.norm2(self.mlp(x))
        else:
            x = x + self.attn(self.norm1(x))
            x = x + self.mlp(self.norm2(x))
        return x


class AttentionPool(nn.Module):
    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 activation='gelu',
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.proj_dropout = proj_dropout
        self.norm_eps = norm_eps

        # layers
        gain = 1.0 / math.sqrt(dim)
        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.norm = LayerNorm(dim, eps=norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(proj_dropout))

    def forward(self, x):
        """
        x: [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
        k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)

        # compute attention
        x = attention(q=q, k=k, v=v, attention_type='torch_sdpa')
        x = x.reshape(b, 1, c)

        # output
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)

        # mlp
        x = x + self.mlp(self.norm(x))
        return x[:, 0]


class VisionTransformer(nn.Module):
    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 dim=768,
                 mlp_ratio=4,
                 out_dim=512,
                 num_heads=12,
                 num_layers=12,
                 pool_type='token',
                 pre_norm=True,
                 post_norm=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        if image_size % patch_size != 0:
            print('[WARNING] image_size is not divisible by patch_size', flush=True)
        assert pool_type in ('token', 'token_fc', 'attn_pool')
        out_dim = out_dim or dim
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pool_type = pool_type
        self.post_norm = post_norm
        self.norm_eps = norm_eps

        # embeddings
        gain = 1.0 / math.sqrt(dim)
        self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm)
        if pool_type in ('token', 'token_fc'):
            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(gain * torch.randn(
            1, self.num_patches + (1 if pool_type in ('token', 'token_fc') else 0), dim))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
        self.transformer = nn.Sequential(*[
            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps)
            for _ in range(num_layers)
        ])
        self.post_norm = LayerNorm(dim, eps=norm_eps)

        # head
        if pool_type == 'token':
            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
        elif pool_type == 'token_fc':
            self.head = nn.Linear(dim, out_dim)
        elif pool_type == 'attn_pool':
            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps)

    def forward(self, x, interpolation=False, use_31_block=False):
        b = x.size(0)

        # embeddings
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        if self.pool_type in ('token', 'token_fc'):
            x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
        if interpolation:
            e = pos_interpolate(self.pos_embedding, x.size(1))
        else:
            e = self.pos_embedding
        x = self.dropout(x + e)
        if self.pre_norm is not None:
            x = self.pre_norm(x)

        # transformer
        if use_31_block:
            x = self.transformer[:-1](x)
            return x
        else:
            x = self.transformer(x)
            return x


class XLMRobertaWithHead(XLMRoberta):
    def __init__(self, **kwargs):
        self.out_dim = kwargs.pop('out_dim')
        super().__init__(**kwargs)

        # head
        mid_dim = (self.dim + self.out_dim) // 2
        self.head = nn.Sequential(
            nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
            nn.Linear(mid_dim, self.out_dim, bias=False))

    def forward(self, ids):
        # xlm-roberta
        x = super().forward(ids)

        # average pooling
        mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
        x = (x * mask).sum(dim=1) / mask.sum(dim=1)

        # head
        x = self.head(x)
        return x


class XLMRobertaCLIP(nn.Module):
    def __init__(self,
                 embed_dim=1024,
                 image_size=224,
                 patch_size=14,
                 vision_dim=1280,
                 vision_mlp_ratio=4,
                 vision_heads=16,
                 vision_layers=32,
                 vision_pool='token',
                 vision_pre_norm=True,
                 vision_post_norm=False,
                 activation='gelu',
                 vocab_size=250002,
                 max_text_len=514,
                 type_size=1,
                 pad_id=1,
                 text_dim=1024,
                 text_heads=16,
                 text_layers=24,
                 text_post_norm=True,
                 text_dropout=0.1,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.activation = activation
        self.vocab_size = vocab_size
        self.max_text_len = max_text_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.text_dim = text_dim
        self.text_heads = text_heads
        self.text_layers = text_layers
        self.text_post_norm = text_post_norm
        self.norm_eps = norm_eps

        # models
        self.visual = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps)
        self.textual = XLMRobertaWithHead(
            vocab_size=vocab_size,
            max_seq_len=max_text_len,
            type_size=type_size,
            pad_id=pad_id,
            dim=text_dim,
            out_dim=embed_dim,
            num_heads=text_heads,
            num_layers=text_layers,
            post_norm=text_post_norm,
            dropout=text_dropout)
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

    def forward(self, imgs, txt_ids):
        """
        imgs: [B, 3, H, W] of torch.float32.
        - mean: [0.48145466, 0.4578275, 0.40821073]
        - std: [0.26862954, 0.26130258, 0.27577711]
        txt_ids: [B, L] of torch.long.
        Encoded by data.CLIPTokenizer.
        """
        xi = self.visual(imgs)
        xt = self.textual(txt_ids)
        return xi, xt

    def param_groups(self):
        groups = [{
            'params': [p for n, p in self.named_parameters() if 'norm' in n or n.endswith('bias')],
            'weight_decay': 0.0
        }, {
            'params': [p for n, p in self.named_parameters() if not ('norm' in n or n.endswith('bias'))]
        }]
        return groups


def _clip(pretrained=False,
          pretrained_name=None,
          model_cls=XLMRobertaCLIP,
          return_transforms=False,
          return_tokenizer=False,
          tokenizer_padding='eos',
          dtype=torch.float32,
          device='cpu',
          **kwargs):
    # init a model on device
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)
    output = (model,)

    # init transforms
    if return_transforms:
        # mean and std
        if 'siglip' in pretrained_name.lower():
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]

        # transforms
        transforms = T.Compose([
            T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std)
        ])
        output += (transforms,)
    return output[0] if len(output) == 1 else output


def clip_xlm_roberta_vit_h_14(pretrained=False,
                              pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
                              **kwargs):
    cfg = dict(
        embed_dim=1024,
        image_size=224,
        patch_size=14,
        vision_dim=1280,
        vision_mlp_ratio=4,
        vision_heads=16,
        vision_layers=32,
        vision_pool='token',
        activation='gelu',
        vocab_size=250002,
        max_text_len=514,
        type_size=1,
        pad_id=1,
        text_dim=1024,
        text_heads=16,
        text_layers=24,
        text_post_norm=True,
        text_dropout=0.1,
        attn_dropout=0.0,
        proj_dropout=0.0,
        embedding_dropout=0.0)
    cfg.update(**kwargs)
    return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)


class CLIPModel:
    def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        self.model, self.transforms = clip_xlm_roberta_vit_h_14(
            pretrained=False,
            return_transforms=True,
            return_tokenizer=False,
            dtype=dtype,
            device=device)
        self.model = self.model.eval().requires_grad_(False)
        logging.info(f'loading {checkpoint_path}')
        self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu', weights_only=True))

        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path,
            seq_len=self.model.max_text_len - 2,
            clean='whitespace')

    def visual(self, videos):
        # preprocess
        size = (self.model.image_size,) * 2
        videos = torch.cat([
            F.interpolate(u.transpose(0, 1), size=size, mode='bicubic', align_corners=False)
            for u in videos
        ])
        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

        # forward
        with torch.amp.autocast('cuda', dtype=self.dtype):
            out = self.model.visual(videos, use_31_block=True)
        return out
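pos_interpolate resizes a learned ViT positional table to a new token count by bicubic-resampling the patch grid while leaving any leading class token untouched. A quick standalone check with made-up sizes:

import torch

dim = 32
pos = torch.randn(1, 1 + 14 * 14, dim)        # [CLS] token + 14x14 patch grid (197 positions)
new_pos = pos_interpolate(pos, 1 + 16 * 16)   # resample the grid part to 16x16
print(new_pos.shape)                          # torch.Size([1, 257, 32])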
lightx2v/image2v/models/wan/xlm_roberta.py (new file, mode 100755)

# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['XLMRoberta', 'xlm_roberta_large']


class SelfAttention(nn.Module):
    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        x: [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)

        # compute attention
        p = self.dropout.p if self.training else 0.0
        x = F.scaled_dot_product_attention(q, k, v, mask, p)
        x = x.permute(0, 2, 1, 3).reshape(b, s, c)

        # output
        x = self.o(x)
        x = self.dropout(x)
        return x


class AttentionBlock(nn.Module):
    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.eps = eps

        # layers
        self.attn = SelfAttention(dim, num_heads, dropout, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
            nn.Dropout(dropout))
        self.norm2 = nn.LayerNorm(dim, eps=eps)

    def forward(self, x, mask):
        if self.post_norm:
            x = self.norm1(x + self.attn(x, mask))
            x = self.norm2(x + self.ffn(x))
        else:
            x = x + self.attn(self.norm1(x), mask)
            x = x + self.ffn(self.norm2(x))
        return x


class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.
    """
    def __init__(self,
                 vocab_size=250002,
                 max_seq_len=514,
                 type_size=1,
                 pad_id=1,
                 dim=1024,
                 num_heads=16,
                 num_layers=24,
                 post_norm=True,
                 dropout=0.1,
                 eps=1e-5):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # embeddings
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # blocks
        self.blocks = nn.ModuleList([
            AttentionBlock(dim, num_heads, post_norm, dropout, eps)
            for _ in range(num_layers)
        ])

        # norm layer
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, ids):
        """
        ids: [B, L] of torch.LongTensor.
        """
        b, s = ids.shape
        mask = ids.ne(self.pad_id).long()

        # embeddings
        x = self.token_embedding(ids) + \
            self.type_embedding(torch.zeros_like(ids)) + \
            self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
        if self.post_norm:
            x = self.norm(x)
        x = self.dropout(x)

        # blocks
        mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
        for block in self.blocks:
            x = block(x, mask)

        # output
        if not self.post_norm:
            x = self.norm(x)
        return x


def xlm_roberta_large(pretrained=False, return_tokenizer=False, device='cpu', **kwargs):
    """
    XLMRobertaLarge adapted from Huggingface.
    """
    # params
    cfg = dict(
        vocab_size=250002,
        max_seq_len=514,
        type_size=1,
        pad_id=1,
        dim=1024,
        num_heads=16,
        num_layers=24,
        post_norm=True,
        dropout=0.1,
        eps=1e-5)
    cfg.update(**kwargs)

    # init a model on device
    with torch.device(device):
        model = XLMRoberta(**cfg)
    return model
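A small CPU smoke test of the text encoder (the tiny num_layers override only keeps the example fast; the default config is the 24-layer XLM-R Large, and the token ids here are arbitrary):

import torch

model = xlm_roberta_large(num_layers=2)          # shrink depth for a quick check
ids = torch.tensor([[0, 1234, 5678, 2, 1, 1]])   # pad_id is 1 in the default config
feats = model(ids)
print(feats.shape)                               # torch.Size([1, 6, 1024])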
lightx2v/text2v/__init__.py (new empty file, mode 100755)
lightx2v/text2v/models/__init__.py (new empty file, mode 100755)
lightx2v/text2v/models/networks/hunyuan/__init__.py (new empty file, mode 100755)
lightx2v/text2v/models/networks/hunyuan/infer/__init__.py (new empty file, mode 100755)
lightx2v/text2v/models/networks/hunyuan/infer/feature_caching/__init__.py (new empty file, mode 100755)
lightx2v/text2v/models/networks/hunyuan/infer/feature_caching/transformer_infer.py (new file, mode 100755)

import torch
from einops import rearrange
from lightx2v.attentions import attention
from ..utils_bf16 import apply_rotary_emb
from typing import Dict
import math
from ..transformer_infer import HunyuanTransformerInfer


def taylor_cache_init(cache_dic: Dict, current: Dict):
    """
    Initialize Taylor cache, expanding storage areas for Taylor series derivatives
    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    if current['step'] == 0:
        cache_dic['cache'][-1][current['stream']][current['layer']][current['module']] = {}


def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
    """
    Compute derivative approximation
    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    difference_distance = current['activated_steps'][-1] - current['activated_steps'][-2]
    # difference_distance = current['activated_times'][-1] - current['activated_times'][-2]

    updated_taylor_factors = {}
    updated_taylor_factors[0] = feature

    for i in range(cache_dic['max_order']):
        if (cache_dic['cache'][-1][current['stream']][current['layer']][current['module']].get(i, None) is not None) and (current['step'] > cache_dic['first_enhance'] - 2):
            updated_taylor_factors[i + 1] = (updated_taylor_factors[i] - cache_dic['cache'][-1][current['stream']][current['layer']][current['module']][i]) / difference_distance
        else:
            break

    cache_dic['cache'][-1][current['stream']][current['layer']][current['module']] = updated_taylor_factors


def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
    """
    Compute Taylor expansion error
    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    x = current['step'] - current['activated_steps'][-1]
    # x = current['t'] - current['activated_times'][-1]

    output = 0
    for i in range(len(cache_dic['cache'][-1][current['stream']][current['layer']][current['module']])):
        output += (1 / math.factorial(i)) * cache_dic['cache'][-1][current['stream']][current['layer']][current['module']][i] * (x ** i)
    return output


class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
    def __init__(self, config):
        super().__init__(config)

    def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
        txt_seq_len = txt.shape[0]
        img_seq_len = img.shape[0]

        self.scheduler.current['stream'] = 'double_stream'
        for i in range(self.double_blocks_num):
            self.scheduler.current['layer'] = i
            img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)

        x = torch.cat((img, txt), 0)

        self.scheduler.current['stream'] = 'single_stream'
        for i in range(self.single_blocks_num):
            self.scheduler.current['layer'] = i
            x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)

        img = x[:img_seq_len, ...]
        return img, vec

    def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
        vec_silu = torch.nn.functional.silu(vec)

        img_mod_out = weights.img_mod.apply(vec_silu)
        (
            img_mod1_shift,
            img_mod1_scale,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
        ) = img_mod_out.chunk(6, dim=-1)

        txt_mod_out = weights.txt_mod.apply(vec_silu)
        (
            txt_mod1_shift,
            txt_mod1_scale,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = txt_mod_out.chunk(6, dim=-1)

        if self.scheduler.current['type'] == 'full':
            img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, freqs_cis)
            txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)

            q = torch.cat((img_q, txt_q), dim=0)
            k = torch.cat((img_k, txt_k), dim=0)
            v = torch.cat((img_v, txt_v), dim=0)

            if not self.parallel_attention:
                attn = attention(
                    attention_type=self.attention_type,
                    q=q,
                    k=k,
                    v=v,
                    cu_seqlens_q=cu_seqlens_qkv,
                    cu_seqlens_kv=cu_seqlens_qkv,
                    max_seqlen_q=max_seqlen_qkv,
                    max_seqlen_kv=max_seqlen_qkv,
                )
            else:
                # world_size = dist.get_world_size()
                attn = self.parallel_attention(
                    attention_type=self.attention_type,
                    q=q,
                    k=k,
                    v=v,
                    img_qkv_len=img_q.shape[0],
                    cu_seqlens_qkv=cu_seqlens_qkv
                    # cu_seqlens_qkv=cu_seqlens_qkv,
                    # max_seqlen_qkv=max_seqlen_qkv,
                )

            img_attn, txt_attn = attn[:img.shape[0]], attn[img.shape[0]:]

            img = self.infer_double_block_img_post_atten(weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate)
            txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
            return img, txt

        elif self.scheduler.current['type'] == 'taylor_cache':
            self.scheduler.current['module'] = 'img_attn'
            out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current)
            out = out * img_mod1_gate
            img = img + out

            self.scheduler.current['module'] = 'img_mlp'
            out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current)
            out = out * img_mod2_gate
            img = img + out

            self.scheduler.current['module'] = 'txt_attn'
            out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current)
            out = out * txt_mod1_gate
            txt = txt + out

            self.scheduler.current['module'] = 'txt_mlp'
            out = out * txt_mod2_gate
            txt = txt + out
            return img, txt

    def infer_double_block_img_post_atten(self, weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate):
        self.scheduler.current['module'] = 'img_attn'
        taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
        out = weights.img_attn_proj.apply(img_attn)
        derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out)
        out = out * img_mod1_gate
        img = img + out

        self.scheduler.current['module'] = 'img_mlp'
        taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
        out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
        out = out * (1 + img_mod2_scale) + img_mod2_shift
        out = weights.img_mlp_fc1.apply(out)
        out = torch.nn.functional.gelu(out, approximate='tanh')
        out = weights.img_mlp_fc2.apply(out)
        derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out)
        out = out * img_mod2_gate
        img = img + out
        return img

    def infer_double_block_txt_post_atten(self, weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate):
        self.scheduler.current['module'] = 'txt_attn'
        taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
        out = weights.txt_attn_proj.apply(txt_attn)
        derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out)
        out = out * txt_mod1_gate
        txt = txt + out

        self.scheduler.current['module'] = 'txt_mlp'
        taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
        out = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
        out = out * (1 + txt_mod2_scale) + txt_mod2_shift
        out = weights.txt_mlp_fc1.apply(out)
        out = torch.nn.functional.gelu(out, approximate='tanh')
        out = weights.txt_mlp_fc2.apply(out)
        derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out)
        out = out * txt_mod2_gate
        txt = txt + out
        return txt

    def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
        out = torch.nn.functional.silu(vec)
        out = weights.modulation.apply(out)
        mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)

        if self.scheduler.current['type'] == 'full':
            out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
            x_mod = out * (1 + mod_scale) + mod_shift
            x_mod = weights.linear1.apply(x_mod)
            qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

            self.scheduler.current['module'] = 'attn'
            taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
            q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
            q = weights.q_norm.apply(q)
            k = weights.k_norm.apply(k)

            img_q, txt_q = q[:-txt_seq_len, :, :], q[-txt_seq_len:, :, :]
            img_k, txt_k = k[:-txt_seq_len, :, :], k[-txt_seq_len:, :, :]
            img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
            q = torch.cat((img_q, txt_q), dim=0)
            k = torch.cat((img_k, txt_k), dim=0)

            if not self.parallel_attention:
                attn = attention(
                    attention_type=self.attention_type,
                    q=q,
                    k=k,
                    v=v,
                    cu_seqlens_q=cu_seqlens_qkv,
                    cu_seqlens_kv=cu_seqlens_qkv,
                    max_seqlen_q=max_seqlen_qkv,
                    max_seqlen_kv=max_seqlen_qkv,
                )
            else:
                attn = self.parallel_attention(
                    attention_type=self.attention_type,
                    q=q,
                    k=k,
                    v=v,
                    img_qkv_len=img_q.shape[0],
                    cu_seqlens_qkv=cu_seqlens_qkv
                    # cu_seqlens_qkv=cu_seqlens_qkv,
                    # max_seqlen_qkv=max_seqlen_qkv,
                )

            derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, attn)

            self.scheduler.current['module'] = 'total'
            taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
            out = torch.nn.functional.gelu(mlp, approximate='tanh')
            out = torch.cat((attn, out), 1)
            out = weights.linear2.apply(out)
            derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out)
            out = out * mod_gate
            x = x + out
            return x

        elif self.scheduler.current['type'] == 'taylor_cache':
            self.scheduler.current['module'] = 'total'
            out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current)
            out = out * mod_gate
            x = x + out
            return x
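To make the caching flow above concrete, here is a toy walkthrough of the three helpers with a hand-built cache_dic/current pair. The dictionary layout mimics what the scheduler is assumed to provide; real runs store per-module transformer activations instead of 4-element vectors.

import torch

cache_dic = {'cache': [{'double_stream': {0: {}}}],   # one stream, one layer
             'max_order': 1,                          # keep 0th- and 1st-order terms
             'first_enhance': 0}
current = {'stream': 'double_stream', 'layer': 0, 'module': 'img_attn',
           'step': 0, 'activated_steps': [0, 0]}

taylor_cache_init(cache_dic, current)                        # allocate the per-module factor dict at step 0
derivative_approximation(cache_dic, current, torch.ones(4))  # step 0: only the 0th-order term is stored

current['step'] = 2
current['activated_steps'] = [0, 2]                                   # a second fully computed ("full") step
derivative_approximation(cache_dic, current, torch.full((4,), 3.0))   # adds the finite difference (3 - 1) / 2 = 1

current['step'] = 3                                          # a cached ("taylor_cache") step
print(taylor_formula(cache_dic, current))                    # tensor([4., 4., 4., 4.]) = 3 + 1 * (3 - 2)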
lightx2v/text2v/models/networks/hunyuan/infer/post_infer.py (new file, mode 100755)

import torch


class HunyuanPostInfer():
    def __init__(self):
        pass

    def infer(self, weights, img, vec, shape):
        out = torch.nn.functional.silu(vec)
        out = weights.final_layer_adaLN_modulation_1.apply(out)
        shift, scale = out.chunk(2, dim=1)

        out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
        out = out * (1 + scale) + shift
        out = weights.final_layer_linear.apply(out.to(torch.float32))

        _, _, ot, oh, ow = shape
        patch_size = [1, 2, 2]
        tt, th, tw = (
            ot // patch_size[0],
            oh // patch_size[1],
            ow // patch_size[2],
        )
        c = 16
        pt, ph, pw = patch_size

        out = out.reshape(shape=(1, tt, th, tw, c, pt, ph, pw))
        out = torch.einsum("nthwcopq->nctohpwq", out)
        out = out.reshape(shape=(1, c, tt * pt, th * ph, tw * pw))
        return out
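The tail of HunyuanPostInfer.infer is an un-patchify: each token carries a (1, 2, 2) spatio-temporal patch of a 16-channel latent. A standalone shape check of that reshape/einsum, with a made-up latent size:

import torch

ot, oh, ow = 8, 16, 16                 # hypothetical latent video size (frames, height, width)
c, (pt, ph, pw) = 16, (1, 2, 2)
tt, th, tw = ot // pt, oh // ph, ow // pw

tokens = torch.randn(tt * th * tw, c * pt * ph * pw)   # (512, 64), as produced by final_layer_linear
out = tokens.reshape(1, tt, th, tw, c, pt, ph, pw)
out = torch.einsum("nthwcopq->nctohpwq", out)
out = out.reshape(1, c, tt * pt, th * ph, tw * pw)
print(out.shape)                                       # torch.Size([1, 16, 8, 16, 16])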
lightx2v/text2v/models/networks/hunyuan/infer/pre_infer.py (new file, mode 100755)

import torch
import math
from einops import rearrange
from lightx2v.attentions import attention


class HunyuanPreInfer():
    def __init__(self):
        self.heads_num = 24

    def infer(self, weights, x, t, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance):
        time_out = self.infer_time_in(weights, t)
        img_out = self.infer_img_in(weights, x)
        infer_text_out = self.infer_text_in(weights, text_states, text_mask, t)
        infer_vector_out = self.infer_vector_in(weights, text_states_2)
        vec = time_out + infer_vector_out
        guidance_out = self.infer_guidance_in(weights, guidance)
        vec = vec + guidance_out

        txt_seq_len = infer_text_out.shape[0]
        img_seq_len = img_out.shape[1]

        batch_size = text_mask.shape[0]
        text_len = text_mask.sum(dim=1)
        max_len = text_mask.shape[1] + img_seq_len
        cu_seqlens_qkv = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
        for i in range(batch_size):
            s = text_len[i] + img_seq_len
            s1 = i * max_len + s
            s2 = (i + 1) * max_len
            cu_seqlens_qkv[2 * i + 1] = s1
            cu_seqlens_qkv[2 * i + 2] = s2
        max_seqlen_qkv = img_seq_len + txt_seq_len

        return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin)

    def infer_time_in(self, weights, t):
        freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
        args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
        out = weights.time_in_mlp_0.apply(embedding)
        out = torch.nn.functional.silu(out)
        out = weights.time_in_mlp_2.apply(out)
        return out

    def infer_img_in(self, weights, x):
        out = weights.img_in_proj.apply(x)
        out = out.flatten(2).transpose(1, 2)
        return out

    def infer_text_in(self, weights, text_states, text_mask, t):
        freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
        args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
        out = weights.txt_in_t_embedder_mlp_0.apply(embedding)
        out = torch.nn.functional.silu(out)
        timestep_aware_representations = weights.txt_in_t_embedder_mlp_2.apply(out)

        mask_float = text_mask.float().unsqueeze(-1).to(torch.bfloat16)  # [b, s1, 1]
        context_aware_representations = (text_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
        context_aware_representations = context_aware_representations
        out = weights.txt_in_c_embedder_linear_1.apply(context_aware_representations)
        out = torch.nn.functional.silu(out)
        context_aware_representations = weights.txt_in_c_embedder_linear_2.apply(out)
        c = timestep_aware_representations + context_aware_representations

        txt_in_input_embed = weights.txt_in_input_embedder.apply(text_states[0])

        batch_size = text_mask.shape[0]
        seq_len = text_mask.shape[1]
        self_attn_mask_1 = text_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
        self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
        self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
        self_attn_mask[:, :, :, 0] = True

        cx = torch.nn.functional.silu(c)
        cx = weights.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1.apply(cx)
        gate_msa, gate_mlp = cx.chunk(2, dim=1)
        normx = weights.txt_in_individual_token_refiner_blocks_0_norm1.apply(txt_in_input_embed)
        qkv = weights.txt_in_individual_token_refiner_blocks_0_self_attn_qkv.apply(normx)
        q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        attn = attention(attention_type="torch_sdpa", q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
        out = weights.txt_in_individual_token_refiner_blocks_0_self_attn_proj.apply(attn)
        out_1 = txt_in_input_embed + out * gate_msa
        out = weights.txt_in_individual_token_refiner_blocks_0_norm2.apply(out_1)
        # mlp
        out = weights.txt_in_individual_token_refiner_blocks_0_mlp_fc1.apply(out)
        out = torch.nn.functional.silu(out)
        out = weights.txt_in_individual_token_refiner_blocks_0_mlp_fc2.apply(out)
        txt_in_input_embed = out_1 + out * gate_mlp

        cx = torch.nn.functional.silu(c)
        cx = weights.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1.apply(cx)
        gate_msa, gate_mlp = cx.chunk(2, dim=1)
        normx = weights.txt_in_individual_token_refiner_blocks_1_norm1.apply(txt_in_input_embed)
        qkv = weights.txt_in_individual_token_refiner_blocks_1_self_attn_qkv.apply(normx)
        q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        attn = attention(attention_type="torch_sdpa", q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
        out = weights.txt_in_individual_token_refiner_blocks_1_self_attn_proj.apply(attn)
        out_1 = txt_in_input_embed + out * gate_msa
        out = weights.txt_in_individual_token_refiner_blocks_1_norm2.apply(out_1)
        # mlp
        out = weights.txt_in_individual_token_refiner_blocks_1_mlp_fc1.apply(out)
        out = torch.nn.functional.silu(out)
        out = weights.txt_in_individual_token_refiner_blocks_1_mlp_fc2.apply(out)
        out = out_1 + out * gate_mlp
        return out

    def infer_vector_in(self, weights, text_states_2):
        out = weights.vector_in_in_layer.apply(text_states_2)
        out = torch.nn.functional.silu(out)
        out = weights.vector_in_out_layer.apply(out)
        return out

    def infer_guidance_in(self, weights, guidance):
        freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=guidance.device)
        args = guidance.float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
        out = weights.guidance_in_mlp_0.apply(embedding)
        out = torch.nn.functional.silu(out)
        out = weights.guidance_in_mlp_2.apply(out)
        return out
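infer_time_in, infer_text_in and infer_guidance_in all start from the same 256-dimensional sinusoidal embedding before their two-layer MLPs; isolated, that first step looks like this:

import math
import torch

t = torch.tensor(500.0)                                            # a diffusion timestep
freqs = torch.exp(-math.log(10000) * torch.arange(0, 128, dtype=torch.float32) / 128)
args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]           # broadcast t against 128 frequencies
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # 128 cosines + 128 sines
print(embedding.shape)                                             # torch.Size([1, 256])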
lightx2v/text2v/models/networks/hunyuan/infer/transformer_infer.py (new file, mode 100755)

import torch
from einops import rearrange
from lightx2v.attentions import attention
from .utils_bf16 import apply_rotary_emb


class HunyuanTransformerInfer():
    def __init__(self, config):
        self.config = config
        self.attention_type = config.get("attention_type", "flash_attn2")
        self.double_blocks_num = 20
        self.single_blocks_num = 40
        self.heads_num = 24
        self.hidden_size = 3072
        self.mlp_hidden_dim = 12288
        self.parallel_attention = None

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
        txt_seq_len = txt.shape[0]
        img_seq_len = img.shape[0]
        for i in range(self.double_blocks_num):
            img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
        x = torch.cat((img, txt), 0)
        for i in range(self.single_blocks_num):
            x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
        img = x[:img_seq_len, ...]
        return img, vec

    def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
        vec_silu = torch.nn.functional.silu(vec)

        img_mod_out = weights.img_mod.apply(vec_silu)
        (
            img_mod1_shift,
            img_mod1_scale,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
        ) = img_mod_out.chunk(6, dim=-1)

        txt_mod_out = weights.txt_mod.apply(vec_silu)
        (
            txt_mod1_shift,
            txt_mod1_scale,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = txt_mod_out.chunk(6, dim=-1)

        img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, freqs_cis)
        txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)

        q = torch.cat((img_q, txt_q), dim=0)
        k = torch.cat((img_k, txt_k), dim=0)
        v = torch.cat((img_v, txt_v), dim=0)

        if not self.parallel_attention:
            attn = attention(
                attention_type=self.attention_type,
                q=q,
                k=k,
                v=v,
                cu_seqlens_q=cu_seqlens_qkv,
                cu_seqlens_kv=cu_seqlens_qkv,
                max_seqlen_q=max_seqlen_qkv,
                max_seqlen_kv=max_seqlen_qkv,
            )
        else:
            # world_size = dist.get_world_size()
            attn = self.parallel_attention(
                attention_type=self.attention_type,
                q=q,
                k=k,
                v=v,
                img_qkv_len=img_q.shape[0],
                cu_seqlens_qkv=cu_seqlens_qkv
                # cu_seqlens_qkv=cu_seqlens_qkv,
                # max_seqlen_qkv=max_seqlen_qkv,
            )

        img_attn, txt_attn = attn[:img.shape[0]], attn[img.shape[0]:]

        img = self.infer_double_block_img_post_atten(weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate)
        txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
        return img, txt

    def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, freqs_cis):
        img_modulated = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
        img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
        img_qkv = weights.img_attn_qkv.apply(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
        img_q = weights.img_attn_q_norm.apply(img_q)
        img_k = weights.img_attn_k_norm.apply(img_k)
        img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
        return img_q, img_k, img_v

    def infer_double_block_txt_pre_atten(self, weights, txt, txt_mod1_scale, txt_mod1_shift):
        txt_modulated = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
        txt_modulated = txt_modulated * (1 + txt_mod1_scale) + txt_mod1_shift
        txt_qkv = weights.txt_attn_qkv.apply(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
        txt_q = weights.txt_attn_q_norm.apply(txt_q)
        txt_k = weights.txt_attn_k_norm.apply(txt_k)
        return txt_q, txt_k, txt_v

    def infer_double_block_img_post_atten(self, weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate):
        out = weights.img_attn_proj.apply(img_attn)
        out = out * img_mod1_gate
        img = img + out

        out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
        out = out * (1 + img_mod2_scale) + img_mod2_shift
        out = weights.img_mlp_fc1.apply(out)
        out = torch.nn.functional.gelu(out, approximate='tanh')
        out = weights.img_mlp_fc2.apply(out)
        out = out * img_mod2_gate
        img = img + out
        return img

    def infer_double_block_txt_post_atten(self, weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate):
        out = weights.txt_attn_proj.apply(txt_attn)
        out = out * txt_mod1_gate
        txt = txt + out

        out = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
        out = out * (1 + txt_mod2_scale) + txt_mod2_shift
        out = weights.txt_mlp_fc1.apply(out)
        out = torch.nn.functional.gelu(out, approximate='tanh')
        out = weights.txt_mlp_fc2.apply(out)
        out = out * txt_mod2_gate
        txt = txt + out
        return txt

    def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
        out = torch.nn.functional.silu(vec)
        out = weights.modulation.apply(out)
        mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)

        out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
        x_mod = out * (1 + mod_scale) + mod_shift
        x_mod = weights.linear1.apply(x_mod)
        qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
        q = weights.q_norm.apply(q)
        k = weights.k_norm.apply(k)

        img_q, txt_q = q[:-txt_seq_len, :, :], q[-txt_seq_len:, :, :]
        img_k, txt_k = k[:-txt_seq_len, :, :], k[-txt_seq_len:, :, :]
        img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
        q = torch.cat((img_q, txt_q), dim=0)
        k = torch.cat((img_k, txt_k), dim=0)

        if not self.parallel_attention:
            attn = attention(
                attention_type=self.attention_type,
                q=q,
                k=k,
                v=v,
                cu_seqlens_q=cu_seqlens_qkv,
                cu_seqlens_kv=cu_seqlens_qkv,
                max_seqlen_q=max_seqlen_qkv,
                max_seqlen_kv=max_seqlen_qkv,
            )
        else:
            attn = self.parallel_attention(
                attention_type=self.attention_type,
                q=q,
                k=k,
                v=v,
                img_qkv_len=img_q.shape[0],
                cu_seqlens_qkv=cu_seqlens_qkv
                # cu_seqlens_qkv=cu_seqlens_qkv,
                # max_seqlen_qkv=max_seqlen_qkv,
            )

        out = torch.nn.functional.gelu(mlp, approximate='tanh')
        out = torch.cat((attn, out), 1)
        out = weights.linear2.apply(out)
        out = out * mod_gate
        x = x + out
        return x
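The fused QKV projections are unpacked with a single einops pattern throughout; for example, with the hidden size and head count hard-coded above (3072, 24 heads, head dim 128):

import torch
from einops import rearrange

hidden_size, heads_num = 3072, 24
qkv = torch.randn(10, 3 * hidden_size)                       # 10 tokens, fused Q/K/V
q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=heads_num)
print(q.shape)                                               # torch.Size([10, 24, 128])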
lightx2v/text2v/models/networks/hunyuan/infer/utils.py (new file, mode 100644)

import sgl_kernel


def rms_norm(x, weight, eps):
    x = x.contiguous()
    orig_shape = x.shape
    x = x.view(-1, orig_shape[-1])
    x = sgl_kernel.rmsnorm(x, weight, eps).view(orig_shape)
    return x
lightx2v/text2v/models/networks/hunyuan/infer/utils_bf16.py (new file, mode 100644)

import torch
from typing import Any, List, Tuple, Optional, Union, Dict


def rms_norm(x, weight, eps):
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    x = x * weight
    return x


def rotate_half(x, shape_0, shape_1):
    x_real, x_imag = x.reshape(shape_0, shape_1, -1, 2).unbind(-1)
    return torch.stack([-x_imag, x_real], dim=-1).flatten(2)


def rotary_emb(x, shape_0, shape_1, cos, sin):
    x_out = (x * cos + rotate_half(x, shape_0, shape_1) * sin)
    return x_out


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    shape_0, shape_1, shape_2 = xq.shape
    cos = freqs_cis[0].view(shape_0, 1, shape_2)
    sin = freqs_cis[1].view(shape_0, 1, shape_2)
    xq_out = rotary_emb(xq, shape_0, shape_1, cos, sin)
    xk_out = rotary_emb(xk, shape_0, shape_1, cos, sin)
    return xq_out, xk_out
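A usage sketch of apply_rotary_emb with toy rotary tables. The cos/sin layout below (adjacent even/odd pairs sharing a frequency) matches what rotate_half expects; the exact frequency schedule used by the real pipeline is an assumption here.

import torch

L, H, D = 4, 2, 8                                            # tokens, heads, head dim (D must be even)
pos = torch.arange(L, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))   # (D/2,)
angles = torch.repeat_interleave(pos[:, None] * inv_freq[None, :], 2, dim=-1)  # (L, D), paired layout
freqs_cis = (torch.cos(angles), torch.sin(angles))

xq = torch.randn(L, H, D)
xk = torch.randn(L, H, D)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
print(xq_rot.shape)                                          # torch.Size([4, 2, 8])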
lightx2v/text2v/models/networks/hunyuan/infer/utils_fp32.py (new file, mode 100644)

import torch
from typing import Any, List, Tuple, Optional, Union, Dict


def rms_norm(x, weight, eps):
    x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    x = x.to(torch.bfloat16)
    x = x * weight
    return x


def rotate_half(x, shape_0, shape_1):
    x_real, x_imag = x.float().reshape(shape_0, shape_1, -1, 2).unbind(-1)
    return torch.stack([-x_imag, x_real], dim=-1).flatten(2)


def rotary_emb(x, shape_0, shape_1, cos, sin):
    x_out = (x * cos + rotate_half(x, shape_0, shape_1) * sin)
    return x_out.to(torch.bfloat16)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    shape_0, shape_1, shape_2 = xq.shape
    cos = freqs_cis[0].view(shape_0, 1, shape_2)
    sin = freqs_cis[1].view(shape_0, 1, shape_2)
    xq_out = rotary_emb(xq.float(), shape_0, shape_1, cos, sin)
    xk_out = rotary_emb(xk.float(), shape_0, shape_1, cos, sin)
    return xq_out, xk_out