xuwx1 / LightX2V · Commit f21528e7

Support q8f kernel and fix bugs. (#6)

Authored Apr 01, 2025 by gushiqiao
Committed by Yang Yong(雍洋) on Apr 08, 2025
Co-authored-by: gushiqiao <gushiqiao@sensetime.com>
Parent: bd0f840f

Showing 13 changed files with 185 additions and 98 deletions (+185, -98)
README.md (+1, -1)
lightx2v/common/ops/mm/mm_weight.py (+91, -1)
lightx2v/text2v/models/networks/wan/infer/transformer_infer.py (+6, -22)
lightx2v/text2v/models/networks/wan/infer/utils.py (+0, -8)
lightx2v/text2v/models/networks/wan/weights/pre_weights.py (+3, -3)
lightx2v/text2v/models/networks/wan/weights/transformer_weights.py (+19, -15)
lightx2v/text2v/models/schedulers/wan/feature_caching/scheduler.py (+55, -38)
scripts/run_hunyuan_t2v.sh (+0, -0)
scripts/run_hunyuan_t2v_dist.sh (+1, -1)
scripts/run_hunyuan_t2v_taylorseer.sh (+0, -0)
scripts/run_wan_i2v.sh (+0, -0)
scripts/run_wan_t2v.sh (+9, -9)
scripts/run_wan_t2v_dist.sh (+0, -0)
README.md

@@ -21,7 +21,7 @@ docker run --gpus all -itd --ipc=host --name [name] -v /mnt:/mnt --entrypoint /b
 ```
 git clone https://gitlab.bj.sensetime.com/video-gen/lightx2v.git
-cd lightx2v
+cd lightx2v/scripts
 # Modify the parameters of the running script
 bash run_hunyuan_t2v.sh
lightx2v/common/ops/mm/mm_weight.py

@@ -3,6 +3,10 @@ from abc import ABCMeta, abstractmethod
 from vllm import _custom_ops as ops
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
 from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
+try:
+    import q8_kernels.functional as Q8F
+except ImportError:
+    Q8F = None

 class MMWeightTemplate(metaclass=ABCMeta):

@@ -113,7 +117,7 @@ class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
 @MM_WEIGHT_REGISTER('W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm')
-class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
+class MMWeightWint8channelAint8channeldynamicVllm(MMWeightTemplate):
     '''
     Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm

@@ -159,6 +163,92 @@ class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
             self.bias = self.bias.cuda()

+@MM_WEIGHT_REGISTER('W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F')
+class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightTemplate):
+    '''
+    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F
+    Quant MM:
+        Weight: int8 perchannel sym
+        Act: int8 perchannel dynamic sym
+        Kernel: Q8F
+    '''
+    def __init__(self, weight_name, bias_name):
+        super().__init__(weight_name, bias_name)
+
+    def load(self, weight_dict):
+        if self.config.get('weight_auto_quant', True):
+            self.weight = weight_dict[self.weight_name].cuda()
+            w_quantizer = IntegerQuantizer(8, True, 'channel')
+            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
+            self.weight = self.weight.to(torch.int8)
+            self.weight_scale = self.weight_scale.to(torch.float32)
+        else:
+            self.weight = weight_dict[self.weight_name].cuda()
+            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + '.weight_scale'].cuda()
+        self.bias = weight_dict[self.bias_name].float().cuda() if self.bias_name is not None else None
+
+    def apply(self, input_tensor, act=None):
+        qinput, x_scale, _ = ops.scaled_int8_quant(input_tensor, scale=None, azp=None, symmetric=True)
+        output_tensor = Q8F.linear.q8_linear(qinput, self.weight, self.bias, x_scale, self.weight_scale, fuse_gelu=False, out_dtype=torch.bfloat16)
+        return output_tensor.squeeze(0)
+
+    def to_cpu(self):
+        self.weight = self.weight.cpu()
+        self.weight_scale = self.weight_scale.cpu()
+        if self.bias is not None:
+            self.bias = self.bias.cpu()
+
+    def to_cuda(self):
+        self.weight = self.weight.cuda()
+        self.weight_scale = self.weight_scale.cuda()
+        if self.bias is not None:
+            self.bias = self.bias.cuda()
+
+
+@MM_WEIGHT_REGISTER('W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F')
+class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightTemplate):
+    '''
+    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F
+    Quant MM:
+        Weight: fp8 perchannel sym
+        Act: fp8 perchannel dynamic sym
+        Kernel: Q8F
+    '''
+    def __init__(self, weight_name, bias_name):
+        super().__init__(weight_name, bias_name)
+
+    def load(self, weight_dict):
+        if self.config.get('weight_auto_quant', True):
+            self.weight = weight_dict[self.weight_name].cuda()
+            w_quantizer = FloatQuantizer('e4m3', True, 'channel')
+            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
+            self.weight = self.weight.to(torch.float8_e4m3fn)
+            self.weight_scale = self.weight_scale.to(torch.float32)
+        else:
+            self.weight = weight_dict[self.weight_name].cuda()
+            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + '.weight_scale'].cuda()
+        self.bias = weight_dict[self.bias_name].float().cuda() if self.bias_name is not None else None
+
+    def apply(self, input_tensor):
+        qinput, x_scale = ops.scaled_fp8_quant(input_tensor, None, scale_ub=None, use_per_token_if_dynamic=True)
+        output_tensor = Q8F.linear.fp8_linear(qinput, self.weight, self.bias, x_scale, self.weight_scale, out_dtype=torch.bfloat16)
+        return output_tensor.squeeze(0)
+
+    def to_cpu(self):
+        self.weight = self.weight.cpu()
+        self.weight_scale = self.weight_scale.cpu()
+        if self.bias is not None:
+            self.bias = self.bias.cpu()
+
+    def to_cuda(self):
+        self.weight = self.weight.cuda()
+        self.weight_scale = self.weight_scale.cuda()
+        if self.bias is not None:
+            self.bias = self.bias.cuda()
+
 if __name__ == '__main__':
     weight_dict = {
         'xx.weight': torch.randn(8192, 4096).to(torch.float8_e4m3fn),
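A note on the new Q8F path: load() quantizes the weight once per output channel (symmetric int8), apply() quantizes the activation per token at runtime via ops.scaled_int8_quant, and Q8F.linear.q8_linear consumes both scale vectors to dequantize the int8 GEMM output into bf16. The diff does not show IntegerQuantizer.real_quant_tensor, so the following is only a sketch of the standard absmax recipe such a per-channel symmetric quantizer typically implements:

import torch

def absmax_int8_per_channel(weight: torch.Tensor):
    # weight: [out_features, in_features]; one scale per output channel.
    # Sketch only; the actual IntegerQuantizer internals are not in this diff.
    absmax = weight.abs().amax(dim=1, keepdim=True)
    scale = (absmax / 127.0).clamp(min=1e-8).to(torch.float32)
    qweight = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)
    return qweight, scale  # approximate dequant: qweight.float() * scale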
lightx2v/text2v/models/networks/wan/infer/transformer_infer.py

 import torch
-from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb, rms_norm
+from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb
 from lightx2v.attentions import attention

@@ -60,15 +60,8 @@ class WanTransformerInfer:
         norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
         s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
-        q = rms_norm(weights.self_attn_q.apply(norm1_out), weights.self_attn_norm_q_weight, 1e-6).view(s, n, d)
-        k = rms_norm(weights.self_attn_k.apply(norm1_out), weights.self_attn_norm_k_weight, 1e-6).view(s, n, d)
+        q = weights.self_attn_norm_q.apply(weights.self_attn_q.apply(norm1_out)).view(s, n, d)
+        k = weights.self_attn_norm_k.apply(weights.self_attn_k.apply(norm1_out)).view(s, n, d)
         v = weights.self_attn_v.apply(norm1_out).view(s, n, d)

         if not self.parallel_attention:

@@ -114,21 +107,12 @@ class WanTransformerInfer:
             context = context[257:]
         n, d = self.num_heads, self.head_dim
-        q = rms_norm(weights.cross_attn_q.apply(norm3_out), weights.cross_attn_norm_q_weight, 1e-6).view(-1, n, d)
-        k = rms_norm(weights.cross_attn_k.apply(context), weights.cross_attn_norm_k_weight, 1e-6).view(-1, n, d)
+        q = weights.cross_attn_norm_q.apply(weights.cross_attn_q.apply(norm3_out)).view(-1, n, d)
+        k = weights.cross_attn_norm_k.apply(weights.cross_attn_k.apply(context)).view(-1, n, d)
         v = weights.cross_attn_v.apply(context).view(-1, n, d)
         if self.task == 'i2v':
-            k_img = rms_norm(weights.cross_attn_k_img.apply(context_img), weights.cross_attn_norm_k_img_weight, 1e-6).view(-1, n, d)
+            k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
             v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
             cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(
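The shape of this refactor: q/k normalization no longer calls the free rms_norm(x, weight, eps) helper with a raw weight tensor; each norm is now a registered weight object whose apply() owns its parameter and dispatches to the kernel, matching the MM weight objects around it. The real RMSWeightTemplate (lightx2v/common/ops/norm/rms_norm_weight.py) is not part of this diff; a hypothetical sketch of the wrapper the new call sites assume:

import sgl_kernel  # same kernel the deleted utils.rms_norm helper called

class RMSWeightSketch:
    # Hypothetical illustration, not the actual RMSWeightTemplate.
    def __init__(self, weight_name, eps=1e-6):
        self.weight_name = weight_name
        self.eps = eps

    def load(self, weight_dict):
        self.weight = weight_dict[self.weight_name].cuda()

    def apply(self, x):
        orig_shape = x.shape
        x = x.contiguous().view(-1, orig_shape[-1])
        return sgl_kernel.rmsnorm(x, self.weight, self.eps).view(orig_shape)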
lightx2v/text2v/models/networks/wan/infer/utils.py

@@ -4,14 +4,6 @@ import torch.cuda.amp as amp
 import torch.distributed as dist

-def rms_norm(x, weight, eps):
-    x = x.contiguous()
-    orig_shape = x.shape
-    x = x.view(-1, orig_shape[-1])
-    x = sgl_kernel.rmsnorm(x, weight, eps).view(orig_shape)
-    return x
-
 def compute_freqs(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes[0].tolist()
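The deleted helper flattened to 2-D and delegated to sgl_kernel.rmsnorm. For reference, a pure-PyTorch equivalent of what it computed (standard RMSNorm over the last dimension), useful when the sgl_kernel extension is unavailable:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # y = x / sqrt(mean(x^2, last dim) + eps) * weight
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight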
lightx2v/text2v/models/networks/wan/weights/pre_weights.py

@@ -42,16 +42,16 @@ class WanPreWeights:
         self.weight_list.append(self.proj_4)

         for mm_weight in self.weight_list:
-            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
+            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
                 mm_weight.set_config(self.config['mm_config'])
                 mm_weight.load(weight_dict)

     def to_cpu(self):
         for mm_weight in self.weight_list:
-            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
+            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
                 mm_weight.to_cpu()

     def to_cuda(self):
         for mm_weight in self.weight_list:
-            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
+            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
                 mm_weight.to_cuda()
\ No newline at end of file
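The change here is purely stylistic: isinstance accepts a tuple of types, so the chained or-checks collapse into a single call. The two forms are equivalent:

# before
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
    ...
# after
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
    ...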
lightx2v/text2v/models/networks/wan/weights/transformer_weights.py

-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
 from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
 from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
+from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate

 class WanTransformerWeights:

@@ -42,15 +43,17 @@ class WanTransformerAttentionBlock:
         self.self_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.self_attn.k.weight', f'blocks.{self.block_index}.self_attn.k.bias')
         self.self_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.self_attn.v.weight', f'blocks.{self.block_index}.self_attn.v.bias')
         self.self_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.self_attn.o.weight', f'blocks.{self.block_index}.self_attn.o.bias')
-        self.self_attn_norm_q_weight = weight_dict[f'blocks.{self.block_index}.self_attn.norm_q.weight']
-        self.self_attn_norm_k_weight = weight_dict[f'blocks.{self.block_index}.self_attn.norm_k.weight']
-        self.norm3 = LN_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.norm3.weight', f'blocks.{self.block_index}.norm3.bias', eps=1e-6)
+        self.self_attn_norm_q = RMS_WEIGHT_REGISTER['sgl-kernel'](f'blocks.{self.block_index}.self_attn.norm_q.weight')
+        self.self_attn_norm_k = RMS_WEIGHT_REGISTER['sgl-kernel'](f'blocks.{self.block_index}.self_attn.norm_k.weight')
+        self.norm3 = LN_WEIGHT_REGISTER['Default'](f'blocks.{self.block_index}.norm3.weight', f'blocks.{self.block_index}.norm3.bias', eps=1e-6)
         self.cross_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.q.weight', f'blocks.{self.block_index}.cross_attn.q.bias')
         self.cross_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.k.weight', f'blocks.{self.block_index}.cross_attn.k.bias')
         self.cross_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.v.weight', f'blocks.{self.block_index}.cross_attn.v.bias')
         self.cross_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.o.weight', f'blocks.{self.block_index}.cross_attn.o.bias')
-        self.cross_attn_norm_q_weight = weight_dict[f'blocks.{self.block_index}.cross_attn.norm_q.weight']
-        self.cross_attn_norm_k_weight = weight_dict[f'blocks.{self.block_index}.cross_attn.norm_k.weight']
+        self.cross_attn_norm_q = RMS_WEIGHT_REGISTER['sgl-kernel'](f'blocks.{self.block_index}.cross_attn.norm_q.weight')
+        self.cross_attn_norm_k = RMS_WEIGHT_REGISTER['sgl-kernel'](f'blocks.{self.block_index}.cross_attn.norm_k.weight')
         self.ffn_0 = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.ffn.0.weight', f'blocks.{self.block_index}.ffn.0.bias')
         self.ffn_2 = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.ffn.2.weight', f'blocks.{self.block_index}.ffn.2.bias')
         self.modulation = weight_dict[f'blocks.{self.block_index}.modulation']

@@ -60,15 +63,15 @@ class WanTransformerAttentionBlock:
             self.self_attn_k,
             self.self_attn_v,
             self.self_attn_o,
-            self.self_attn_norm_q_weight,
-            self.self_attn_norm_k_weight,
+            self.self_attn_norm_q,
+            self.self_attn_norm_k,
             self.norm3,
             self.cross_attn_q,
             self.cross_attn_k,
             self.cross_attn_v,
             self.cross_attn_o,
-            self.cross_attn_norm_q_weight,
-            self.cross_attn_norm_k_weight,
+            self.cross_attn_norm_q,
+            self.cross_attn_norm_k,
             self.ffn_0,
             self.ffn_2,
             self.modulation,

@@ -77,26 +80,27 @@ class WanTransformerAttentionBlock:
         if self.task == 'i2v':
             self.cross_attn_k_img = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.k_img.weight', f'blocks.{self.block_index}.cross_attn.k_img.bias')
             self.cross_attn_v_img = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.v_img.weight', f'blocks.{self.block_index}.cross_attn.v_img.bias')
-            self.cross_attn_norm_k_img_weight = weight_dict[f'blocks.{self.block_index}.cross_attn.norm_k_img.weight']
+            # self.cross_attn_norm_k_img_weight = weight_dict[f'blocks.{self.block_index}.cross_attn.norm_k_img.weight']
+            self.cross_attn_norm_k_img = RMS_WEIGHT_REGISTER['sgl-kernel'](f'blocks.{self.block_index}.cross_attn.norm_k_img.weight')
             self.weight_list.append(self.cross_attn_k_img)
             self.weight_list.append(self.cross_attn_v_img)
-            self.weight_list.append(self.cross_attn_norm_k_img_weight)
+            self.weight_list.append(self.cross_attn_norm_k_img)

         for mm_weight in self.weight_list:
-            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate):
+            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
                 mm_weight.set_config(self.config['mm_config'])
                 mm_weight.load(weight_dict)

     def to_cpu(self):
         for mm_weight in self.weight_list:
-            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate):
+            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
                 mm_weight.to_cpu()
             else:
                 mm_weight.cpu()

     def to_cuda(self):
         for mm_weight in self.weight_list:
-            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate):
+            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
                 mm_weight.to_cuda()
             else:
                 mm_weight.cuda()
\ No newline at end of file
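All three registries used here (MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER) follow the decorator-registry idiom visible in mm_weight.py above: a class registers itself under a string key, and call sites look the key up at construction time. A minimal sketch of that idiom, assuming registry_factory behaves like a keyed dict (the actual implementation is not in this diff):

class Register(dict):
    # Minimal sketch of a string-keyed class registry.
    def __call__(self, key):
        def decorator(cls):
            self[key] = cls
            return cls
        return decorator

RMS_WEIGHT_REGISTER = Register()

@RMS_WEIGHT_REGISTER('sgl-kernel')
class RMSWeightSglKernel:
    def __init__(self, weight_name):
        self.weight_name = weight_name

# Lookup mirrors the call sites in this file:
norm_q = RMS_WEIGHT_REGISTER['sgl-kernel']('blocks.0.self_attn.norm_q.weight')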
lightx2v/text2v/models/schedulers/wan/feature_caching/scheduler.py

@@ -16,41 +16,58 @@ class WanSchedulerFeatureCaching(WanScheduler):
         self.previous_residual_odd = None
         self.use_ret_steps = self.args.use_ret_steps
-        if self.use_ret_steps:
-            if self.args.target_width == 480 or self.args.target_height == 480:
-                self.coefficients = [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01]
-            if self.args.target_width == 720 or self.args.target_height == 720:
-                self.coefficients = [8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02]
-            self.ret_steps = 5 * 2
-            self.cutoff_steps = self.args.infer_steps * 2
-        else:
-            if self.args.target_width == 480 or self.args.target_height == 480:
-                self.coefficients = [-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01]
-            if self.args.target_width == 720 or self.args.target_height == 720:
-                self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
-            self.ret_steps = 1 * 2
-            self.cutoff_steps = self.args.infer_steps * 2 - 2
\ No newline at end of file
+        if self.args.task == 'i2v':
+            if self.use_ret_steps:
+                if self.args.target_width == 480 or self.args.target_height == 480:
+                    self.coefficients = [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01]
+                if self.args.target_width == 720 or self.args.target_height == 720:
+                    self.coefficients = [8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02]
+                self.ret_steps = 5 * 2
+                self.cutoff_steps = self.args.infer_steps * 2
+            else:
+                if self.args.target_width == 480 or self.args.target_height == 480:
+                    self.coefficients = [-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01]
+                if self.args.target_width == 720 or self.args.target_height == 720:
+                    self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
+                self.ret_steps = 1 * 2
+                self.cutoff_steps = self.args.infer_steps * 2 - 2
+        elif self.args.task == 't2v':
+            if self.use_ret_steps:
+                if '1.3B' in self.args.model_path:
+                    self.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
+                if '14B' in self.args.model_path:
+                    self.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
+                self.ret_steps = 5 * 2
+                self.cutoff_steps = self.args.infer_steps * 2
+            else:
+                if '1.3B' in self.args.model_path:
+                    self.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
+                if '14B' in self.args.model_path:
+                    self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
+                self.ret_steps = 1 * 2
+                self.cutoff_steps = self.args.infer_steps * 2 - 2
\ No newline at end of file
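This hunk only selects coefficient vectors per task, resolution, and model size; in TeaCache-style feature caching, five numbers like these are typically a degree-4 polynomial used to rescale the raw relative change of the timestep-modulated input before it is accumulated against a skip threshold. The accumulation logic lives outside this hunk, so the following is a hedged sketch of that conventional usage, not the repository's exact code:

import numpy as np

def rescaled_distance(curr: np.ndarray, prev: np.ndarray, coefficients: list) -> float:
    # Relative L1 change between consecutive steps, then polynomial rescale.
    # Assumes np.polyval ordering: coefficients from highest degree to constant.
    rel_l1 = float(np.abs(curr - prev).mean() / np.abs(prev).mean())
    return float(np.polyval(coefficients, rel_l1))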
run_hunyuan_t2v.sh → scripts/run_hunyuan_t2v.sh
File moved
run_hunyuan_t2v_dist.sh → scripts/run_hunyuan_t2v_dist.sh

@@ -4,7 +4,7 @@ model_path=/workspace/ckpts_link # H800-14
 export CUDA_VISIBLE_DEVICES=0,1,2,3
-torchrun --nproc_per_node=4 main.py \
+torchrun --nproc_per_node=4 ../main.py \
     --model_cls hunyuan \
    --model_path $model_path \
    --prompt "A cat walks on the grass, realistic style." \
run_hunyuan_t2v_taylorseer.sh → scripts/run_hunyuan_t2v_taylorseer.sh
File moved

run_wan_i2v.sh → scripts/run_wan_i2v.sh
File moved
run_wan_t2v.sh → scripts/run_wan_t2v.sh

 #!/bin/bash
 # model_path=/mnt/nvme1/yongyang/models/hy/ckpts # H800-13
 # model_path=/workspace/wan/Wan2.1-T2V-1.3B # H800-14
 # config_path=/workspace/wan/Wan2.1-T2V-1.3B/config.json
-model_path=/mnt/nvme0/yongyang/projects/wan/Wan2.1-T2V-1.3B # H800-14
-config_path=/mnt/nvme0/yongyang/projects/wan/Wan2.1-T2V-1.3B/config.json
+model_path=/workspace/wan/Wan2.1-T2V-1.3B # H800-14
+config_path=/workspace/wan/Wan2.1-T2V-1.3B/config.json

 export CUDA_VISIBLE_DEVICES=0
-python main.py \
+python ../main.py \
     --model_cls wan2.1 \
     --task t2v \
     --model_path $model_path \

@@ -20,7 +17,10 @@ python main.py \
     --seed 42 \
     --sample_neg_promp 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
     --config_path $config_path \
-    --save_video_path ./output_lightx2v_seed42.mp4 \
+    --save_video_path ./output_lightx2v_seed42_q8f1_teacache.mp4 \
     --sample_guide_scale 6 \
-    --sample_shift 8
-    # --mm_config '{"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
\ No newline at end of file
+    --sample_shift 8 \
+    # --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F", "weight_auto_quant": true}' \
+    # --feature_caching Tea \
+    # --use_ret_steps \
+    # --teacache_thresh 0.2
\ No newline at end of file
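The commented --mm_config line ties back to the registry key added in mm_weight.py: the JSON's mm_type must match a registered class name, and "weight_auto_quant": true makes load() quantize a bf16 checkpoint on the fly instead of reading pre-computed .weight_scale tensors. Parsed in Python, the value looks like this:

import json

mm_config = json.loads('{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F", "weight_auto_quant": true}')
# mm_config['mm_type'] selects MMWeightWint8channelAint8channeldynamicQ8F via
# MM_WEIGHT_REGISTER; weight_auto_quant=True takes the IntegerQuantizer branch
# in load() rather than loading saved weight_scale tensors.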
run_wan_t2v_dist.sh → scripts/run_wan_t2v_dist.sh
File moved