Project: xuwx1/LightX2V

Commit 6347b21b, authored Apr 29, 2025 by root
Parent: aec90a0d

    Support convert weight to diffusers.

Changes: 23 files in total; this page shows 20 changed files with 215 additions and 44 deletions (+215 −44).
configs/wan_i2v.json                                                    +2   −0
examples/diffusers/converter.py                                         +150 −0
lightx2v/api_server.py                                                  +6   −5
lightx2v/attentions/distributed/comm/ring_comm.py                       +2   −1
lightx2v/attentions/distributed/partial_heads_attn/tests/test_acc.py    +3   −2
lightx2v/common/backend_infer/trt/common.py                             +3   −2
lightx2v/common/ops/mm/mm_weight.py                                     +4   −3
lightx2v/infer.py                                                       +4   −3
lightx2v/models/input_encoders/hf/clip/model.py                         +2   −1
lightx2v/models/input_encoders/hf/llama/model.py                        +2   −1
lightx2v/models/input_encoders/hf/llava/model.py                        +2   −1
lightx2v/models/input_encoders/hf/t5/model.py                           +2   −1
lightx2v/models/input_encoders/hf/xlm_roberta/model.py                  +2   −1
lightx2v/models/networks/wan/model.py                                   +13  −11
lightx2v/models/runners/default_runner.py                               +2   −1
lightx2v/models/runners/graph_runner.py                                 +3   −2
lightx2v/models/runners/wan/wan_causal_runner.py                        +7   −6
lightx2v/models/runners/wan/wan_runner.py                               +2   −1
lightx2v/models/video_encoders/hf/wan/vae.py                            +2   −1
lightx2v/utils/profiler.py                                              +2   −1
configs/wan_i2v.json

@@ -7,6 +7,8 @@
     "seed": 42,
     "sample_guide_scale": 5,
     "sample_shift": 5,
+    "enable_cfg": true,
+    "cpu_offload": false,
     "mm_config": {
         "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
         "weight_auto_quant": true

examples/diffusers/converter.py (new file, mode 100755)

import os
import re
import glob
import json
import argparse
import torch
from safetensors import safe_open, torch as st
from loguru import logger
from tqdm import tqdm


def get_key_mapping_rules(direction, model_type):
    if model_type == "wan":
        unified_rules = [
            {"forward": (r"^head\.head$", "proj_out"), "backward": (r"^proj_out$", "head.head")},
            {"forward": (r"^head\.modulation$", "scale_shift_table"), "backward": (r"^scale_shift_table$", "head.modulation")},
            {"forward": (r"^text_embedding\.0\.", "condition_embedder.text_embedder.linear_1."), "backward": (r"^condition_embedder.text_embedder.linear_1\.", "text_embedding.0.")},
            {"forward": (r"^text_embedding\.2\.", "condition_embedder.text_embedder.linear_2."), "backward": (r"^condition_embedder.text_embedder.linear_2\.", "text_embedding.2.")},
            {"forward": (r"^time_embedding\.0\.", "condition_embedder.time_embedder.linear_1."), "backward": (r"^condition_embedder.time_embedder.linear_1\.", "time_embedding.0.")},
            {"forward": (r"^time_embedding\.2\.", "condition_embedder.time_embedder.linear_2."), "backward": (r"^condition_embedder.time_embedder.linear_2\.", "time_embedding.2.")},
            {"forward": (r"^time_projection\.1\.", "condition_embedder.time_proj."), "backward": (r"^condition_embedder.time_proj\.", "time_projection.1.")},
            {"forward": (r"blocks\.(\d+)\.self_attn\.q\.", r"blocks.\1.attn1.to_q."), "backward": (r"blocks\.(\d+)\.attn1\.to_q\.", r"blocks.\1.self_attn.q.")},
            {"forward": (r"blocks\.(\d+)\.self_attn\.k\.", r"blocks.\1.attn1.to_k."), "backward": (r"blocks\.(\d+)\.attn1\.to_k\.", r"blocks.\1.self_attn.k.")},
            {"forward": (r"blocks\.(\d+)\.self_attn\.v\.", r"blocks.\1.attn1.to_v."), "backward": (r"blocks\.(\d+)\.attn1\.to_v\.", r"blocks.\1.self_attn.v.")},
            {"forward": (r"blocks\.(\d+)\.self_attn\.o\.", r"blocks.\1.attn1.to_out.0."), "backward": (r"blocks\.(\d+)\.attn1\.to_out\.0\.", r"blocks.\1.self_attn.o.")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.q\.", r"blocks.\1.attn2.to_q."), "backward": (r"blocks\.(\d+)\.attn2\.to_q\.", r"blocks.\1.cross_attn.q.")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.k\.", r"blocks.\1.attn2.to_k."), "backward": (r"blocks\.(\d+)\.attn2\.to_k\.", r"blocks.\1.cross_attn.k.")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.v\.", r"blocks.\1.attn2.to_v."), "backward": (r"blocks\.(\d+)\.attn2\.to_v\.", r"blocks.\1.cross_attn.v.")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.o\.", r"blocks.\1.attn2.to_out.0."), "backward": (r"blocks\.(\d+)\.attn2\.to_out\.0\.", r"blocks.\1.cross_attn.o.")},
            {"forward": (r"blocks\.(\d+)\.norm3\.", r"blocks.\1.norm2."), "backward": (r"blocks\.(\d+)\.norm2\.", r"blocks.\1.norm3.")},
            {"forward": (r"blocks\.(\d+)\.ffn\.0\.", r"blocks.\1.ffn.net.0.proj."), "backward": (r"blocks\.(\d+)\.ffn\.net\.0\.proj\.", r"blocks.\1.ffn.0.")},
            {"forward": (r"blocks\.(\d+)\.ffn\.2\.", r"blocks.\1.ffn.net.2."), "backward": (r"blocks\.(\d+)\.ffn\.net\.2\.", r"blocks.\1.ffn.2.")},
            {"forward": (r"blocks\.(\d+)\.modulation\.", r"blocks.\1.scale_shift_table."), "backward": (r"blocks\.(\d+)\.scale_shift_table(?=\.|$)", r"blocks.\1.modulation")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.k_img\.", r"blocks.\1.attn2.add_k_proj."), "backward": (r"blocks\.(\d+)\.attn2\.add_k_proj\.", r"blocks.\1.cross_attn.k_img.")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.v_img\.", r"blocks.\1.attn2.add_v_proj."), "backward": (r"blocks\.(\d+)\.attn2\.add_v_proj\.", r"blocks.\1.cross_attn.v_img.")},
            {
                "forward": (r"blocks\.(\d+)\.cross_attn\.norm_k_img\.weight", r"blocks.\1.attn2.norm_added_k.weight"),
                "backward": (r"blocks\.(\d+)\.attn2\.norm_added_k\.weight", r"blocks.\1.cross_attn.norm_k_img.weight"),
            },
            {"forward": (r"img_emb\.proj\.0\.", r"condition_embedder.image_embedder.norm1."), "backward": (r"condition_embedder\.image_embedder\.norm1\.", r"img_emb.proj.0.")},
            {"forward": (r"img_emb\.proj\.1\.", r"condition_embedder.image_embedder.ff.net.0.proj."), "backward": (r"condition_embedder\.image_embedder\.ff\.net\.0\.proj\.", r"img_emb.proj.1.")},
            {"forward": (r"img_emb\.proj\.3\.", r"condition_embedder.image_embedder.ff.net.2."), "backward": (r"condition_embedder\.image_embedder\.ff\.net\.2\.", r"img_emb.proj.3.")},
            {"forward": (r"img_emb\.proj\.4\.", r"condition_embedder.image_embedder.norm2."), "backward": (r"condition_embedder\.image_embedder\.norm2\.", r"img_emb.proj.4.")},
            {"forward": (r"blocks\.(\d+)\.self_attn\.norm_q\.weight", r"blocks.\1.attn1.norm_q.weight"), "backward": (r"blocks\.(\d+)\.attn1\.norm_q\.weight", r"blocks.\1.self_attn.norm_q.weight")},
            {"forward": (r"blocks\.(\d+)\.self_attn\.norm_k\.weight", r"blocks.\1.attn1.norm_k.weight"), "backward": (r"blocks\.(\d+)\.attn1\.norm_k\.weight", r"blocks.\1.self_attn.norm_k.weight")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.norm_q\.weight", r"blocks.\1.attn2.norm_q.weight"), "backward": (r"blocks\.(\d+)\.attn2\.norm_q\.weight", r"blocks.\1.cross_attn.norm_q.weight")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.norm_k\.weight", r"blocks.\1.attn2.norm_k.weight"), "backward": (r"blocks\.(\d+)\.attn2\.norm_k\.weight", r"blocks.\1.cross_attn.norm_k.weight")},
            # head projection mapping
            {"forward": (r"^head\.head\.", "proj_out."), "backward": (r"^proj_out\.", "head.head.")},
        ]
        if direction == "forward":
            return [rule["forward"] for rule in unified_rules]
        elif direction == "backward":
            return [rule["backward"] for rule in unified_rules]
        else:
            raise ValueError(f"Invalid direction: {direction}")
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
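
A quick aside on how these rules behave (an illustrative snippet, not part of converter.py): each rule is a (pattern, replacement) pair applied with re.sub, so a forward pass renames a native LightX2V key into its Diffusers counterpart. The sample keys below are illustrative examples of the naming scheme, not dumped from a real checkpoint.

# Illustrative only -- not part of converter.py.
import re

rules = get_key_mapping_rules(direction="forward", model_type="wan")
for key in ["blocks.0.self_attn.q.weight", "time_embedding.0.bias"]:
    new_key = key
    for pattern, replacement in rules:
        new_key = re.sub(pattern, replacement, new_key)
    print(key, "->", new_key)
    # blocks.0.self_attn.q.weight -> blocks.0.attn1.to_q.weight
    # time_embedding.0.bias       -> condition_embedder.time_embedder.linear_1.bias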

def convert_weights(args):
    # Collect the source checkpoint file(s).
    if os.path.isdir(args.source):
        src_files = glob.glob(os.path.join(args.source, "*.safetensors"), recursive=True)
    elif args.source.endswith((".pth", ".safetensors", "pt")):
        src_files = [args.source]
    else:
        raise ValueError("Invalid input path")

    # Merge all shards into a single state dict, refusing duplicate keys.
    merged_weights = {}
    logger.info(f"Processing source files: {src_files}")
    for file_path in tqdm(src_files, desc="Loading weights"):
        logger.info(f"Loading weights from: {file_path}")
        if file_path.endswith(".pt") or file_path.endswith(".pth"):
            weights = torch.load(file_path, map_location="cpu", weights_only=True)
        elif file_path.endswith(".safetensors"):
            with safe_open(file_path, framework="pt") as f:
                weights = {k: f.get_tensor(k) for k in f.keys()}
        duplicate_keys = set(weights.keys()) & set(merged_weights.keys())
        if duplicate_keys:
            raise ValueError(f"Duplicate keys found: {duplicate_keys} in file {file_path}")
        merged_weights.update(weights)

    # Rename every key according to the direction-specific regex rules.
    rules = get_key_mapping_rules(args.direction, args.model_type)
    converted_weights = {}
    logger.info("Converting keys...")
    for key in tqdm(merged_weights.keys(), desc="Converting keys"):
        new_key = key
        for pattern, replacement in rules:
            new_key = re.sub(pattern, replacement, new_key)
        converted_weights[new_key] = merged_weights[key]

    os.makedirs(args.output, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(args.source))[0] if args.source.endswith((".pth", ".safetensors")) else "converted_model"

    # Write the converted tensors in chunks and build a weight-map index.
    index = {"metadata": {"total_size": 0}, "weight_map": {}}
    chunk_idx = 0
    current_chunk = {}
    for idx, (k, v) in tqdm(enumerate(converted_weights.items()), desc="Saving chunks"):
        current_chunk[k] = v
        if (idx + 1) % args.chunk_size == 0 and args.chunk_size > 0:
            output_filename = f"{base_name}_part{chunk_idx}.safetensors"
            output_path = os.path.join(args.output, output_filename)
            logger.info(f"Saving chunk to: {output_path}")
            st.save_file(current_chunk, output_path)
            for key in current_chunk:
                index["weight_map"][key] = output_filename
            index["metadata"]["total_size"] += os.path.getsize(output_path)
            current_chunk = {}
            chunk_idx += 1

    if current_chunk:
        output_filename = f"{base_name}_part{chunk_idx}.safetensors"
        output_path = os.path.join(args.output, output_filename)
        logger.info(f"Saving final chunk to: {output_path}")
        st.save_file(current_chunk, output_path)
        for key in current_chunk:
            index["weight_map"][key] = output_filename
        index["metadata"]["total_size"] += os.path.getsize(output_path)

    # Save index file
    index_path = os.path.join(args.output, "diffusion_pytorch_model.safetensors.index.json")
    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(index, f, indent=2)
    logger.info(f"Index file written to: {index_path}")

def main():
    parser = argparse.ArgumentParser(description="Model weight format converter")
    parser.add_argument("-s", "--source", required=True, help="Input path (file or directory)")
    parser.add_argument("-o", "--output", required=True, help="Output directory path")
    parser.add_argument("-d", "--direction", choices=["forward", "backward"], default="forward", help="Conversion direction: forward = 'lightx2v' -> 'Diffusers', backward = reverse")
    parser.add_argument("-c", "--chunk-size", type=int, default=100, help="Chunk size for saving (only applies to forward), 0 = no chunking")
    parser.add_argument("-t", "--model_type", choices=["wan"], default="wan", help="Model type")
    args = parser.parse_args()

    if os.path.isfile(args.output):
        raise ValueError("Output path must be a directory, not a file")

    logger.info("Starting model weight conversion...")
    convert_weights(args)
    logger.info(f"Conversion completed! Files saved to: {args.output}")


if __name__ == "__main__":
    main()
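
Given the flags registered in main(), a typical forward conversion of a Wan checkpoint could be launched as follows (the input and output paths are placeholders, not paths from the repo):

    python examples/diffusers/converter.py -s /path/to/lightx2v_wan_ckpt -o /path/to/wan_diffusers -d forward -c 100 -t wan

Passing -d backward applies the reverse rule set and maps a Diffusers-style checkpoint back to the native LightX2V key layout. The remaining files in this commit are a sweep that replaces print with loguru logging.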

lightx2v/api_server.py

@@ -4,6 +4,7 @@ import psutil
 import argparse
 from fastapi import FastAPI, Request
 from pydantic import BaseModel
+from loguru import logger
 import uvicorn
 import json
 import asyncio
@@ -26,15 +27,15 @@ def kill_all_related_processes():
         try:
             child.kill()
         except Exception as e:
-            print(f"Failed to kill child process {child.pid}: {e}")
+            logger.info(f"Failed to kill child process {child.pid}: {e}")
     try:
         current_process.kill()
     except Exception as e:
-        print(f"Failed to kill main process: {e}")
+        logger.info(f"Failed to kill main process: {e}")


 def signal_handler(sig, frame):
-    print("\nReceived Ctrl+C, shutting down all related processes...")
+    logger.info("\nReceived Ctrl+C, shutting down all related processes...")
     kill_all_related_processes()
     sys.exit(0)
@@ -79,11 +80,11 @@ if __name__ == "__main__":
     parser.add_argument("--config_json", type=str, required=True)
     parser.add_argument("--port", type=int, default=8000)
     args = parser.parse_args()
-    print(f"args: {args}")
+    logger.info(f"args: {args}")

     with ProfilingContext("Init Server Cost"):
         config = set_config(args)
-        print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
+        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         runner = init_runner(config)

     uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
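
The recurring change in this commit is swapping bare print calls for loguru's module-level logger, which needs no handler setup and stamps each record with time, level, and call site. A minimal standalone sketch of the pattern (not code from the repo; the args dict is a stand-in for the parsed argparse namespace):

# Minimal standalone sketch of the print -> loguru swap used throughout this commit.
from loguru import logger

args = {"task": "i2v", "port": 8000}   # stand-in for the parsed argparse namespace
logger.info(f"args: {args}")           # replaces: print(f"args: {args}")
# emits roughly: 2025-04-29 12:00:00.000 | INFO | __main__:<module>:5 - args: {'task': 'i2v', 'port': 8000}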

lightx2v/attentions/distributed/comm/ring_comm.py

 from typing import Optional
+from loguru import logger
 import torch
 import torch.distributed as dist
@@ -21,7 +22,7 @@ class RingComm:
     def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
         if recv_tensor is None:
             res = torch.empty_like(to_send)
-            # print(f"send_recv: empty_like {to_send.shape}")
+            # logger.info(f"send_recv: empty_like {to_send.shape}")
         else:
             res = recv_tensor

lightx2v/attentions/distributed/partial_heads_attn/tests/test_acc.py

@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist
 from lightx2v.attentions import attention
 from lightx2v.utils.utils import seed_all
+from loguru import logger

 seed_all(42)
@@ -65,10 +66,10 @@ def test_part_head():
     # Verify result consistency
     if cur_rank == 0:
         # import pdb; pdb.set_trace()
-        print("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))
+        logger.info("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))

     # # Verify result consistency
-    # print("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))
+    # logger.info("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))


 if __name__ == "__main__":

lightx2v/common/backend_infer/trt/common.py

@@ -20,6 +20,7 @@ import os
 import tensorrt as trt

 from .common_runtime import *
+from loguru import logger

 try:
     # Sometimes python does not understand FileNotFoundError
@@ -67,11 +68,11 @@ def find_sample_data(description="Runs a TensorRT Python sample", subfolder="",
         data_path = os.path.join(data_dir, subfolder)
         if not os.path.exists(data_path):
             if data_dir != kDEFAULT_DATA_ROOT:
-                print("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.")
+                logger.info("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.")
             data_path = data_dir

         # Make sure data directory exists.
         if not (os.path.exists(data_path)) and data_dir != kDEFAULT_DATA_ROOT:
-            print("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path))
+            logger.info("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path))

         return data_path

     data_paths = [get_data_path(data_dir) for data_dir in args.datadir]

lightx2v/common/ops/mm/mm_weight.py

@@ -4,6 +4,7 @@ from vllm import _custom_ops as ops
 import sgl_kernel
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
 from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
+from loguru import logger

 try:
     import q8_kernels.functional as Q8F
@@ -461,7 +462,7 @@ if __name__ == "__main__":
     mm_weight.load(weight_dict)
     input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
     output_tensor = mm_weight.apply(input_tensor)
-    print(output_tensor.shape)
+    logger.info(output_tensor.shape)

     weight_dict = {
         "xx.weight": torch.randn(8192, 4096),
@@ -473,7 +474,7 @@ if __name__ == "__main__":
     mm_weight.load(weight_dict)
     input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
     output_tensor = mm_weight.apply(input_tensor)
-    print(output_tensor.shape)
+    logger.info(output_tensor.shape)

     weight_dict = {
         "xx.weight": torch.randn(8192, 4096),
@@ -485,4 +486,4 @@ if __name__ == "__main__":
     mm_weight.load(weight_dict)
     input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
     output_tensor = mm_weight.apply(input_tensor)
-    print(output_tensor.shape)
+    logger.info(output_tensor.shape)

lightx2v/infer.py

@@ -15,6 +15,7 @@ from lightx2v.models.runners.wan.wan_causal_runner import WanCausalRunner
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.common.ops import *
+from loguru import logger


 def init_runner(config):
@@ -37,17 +38,17 @@ if __name__ == "__main__":
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
-    parser.add_argument("--enable_cfg", type=bool, default=False)
     parser.add_argument("--prompt", type=str, required=True)
     parser.add_argument("--negative_prompt", type=str, default="")
     parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
     parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
     args = parser.parse_args()
-    print(f"args: {args}")
+    logger.info(f"args: {args}")

     with ProfilingContext("Total Cost"):
         config = set_config(args)
-        print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
+        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         runner = init_runner(config)
         runner.run_pipeline()

lightx2v/models/input_encoders/hf/clip/model.py

 import torch
 from transformers import CLIPTextModel, AutoTokenizer
+from loguru import logger


 class TextEncoderHFClipModel:
@@ -54,4 +55,4 @@ if __name__ == "__main__":
     model = TextEncoderHFClipModel(model_path, torch.device("cuda"))
     text = "A cat walks on the grass, realistic style."
     outputs = model.infer(text)
-    print(outputs)
+    logger.info(outputs)

lightx2v/models/input_encoders/hf/llama/model.py

 import torch
 from transformers import AutoModel, AutoTokenizer
+from loguru import logger


 class TextEncoderHFLlamaModel:
@@ -67,4 +68,4 @@ if __name__ == "__main__":
     model = TextEncoderHFLlamaModel(model_path, torch.device("cuda"))
     text = "A cat walks on the grass, realistic style."
     outputs = model.infer(text)
-    print(outputs)
+    logger.info(outputs)

lightx2v/models/input_encoders/hf/llava/model.py

@@ -3,6 +3,7 @@ from PIL import Image
 import numpy as np
 import torchvision.transforms as transforms
 from transformers import LlavaForConditionalGeneration, CLIPImageProcessor, AutoTokenizer
+from loguru import logger


 def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
@@ -158,4 +159,4 @@ if __name__ == "__main__":
     img_path = "/mtc/yongyang/projects/lightx2v/assets/inputs/imgs/img_1.jpg"
     img = Image.open(img_path).convert("RGB")
     outputs = model.infer(text, img, None)
-    print(outputs)
+    logger.info(outputs)

lightx2v/models/input_encoders/hf/t5/model.py

@@ -8,6 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from .tokenizer import HuggingfaceTokenizer
+from loguru import logger

 __all__ = [
     "T5Model",
@@ -522,4 +523,4 @@ if __name__ == "__main__":
     )
     text = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
     outputs = model.infer(text)
-    print(outputs)
+    logger.info(outputs)

lightx2v/models/input_encoders/hf/xlm_roberta/model.py

@@ -10,6 +10,7 @@ import torchvision.transforms as T
 from lightx2v.attentions import attention
 from lightx2v.models.input_encoders.hf.t5.tokenizer import HuggingfaceTokenizer
+from loguru import logger
 from .xlm_roberta import XLMRoberta
@@ -190,7 +191,7 @@ class VisionTransformer(nn.Module):
         norm_eps=1e-5,
     ):
         if image_size % patch_size != 0:
-            print("[WARNING] image_size is not divisible by patch_size", flush=True)
+            logger.info("[WARNING] image_size is not divisible by patch_size", flush=True)
         assert pool_type in ("token", "token_fc", "attn_pool")
         out_dim = out_dim or dim
         super().__init__()

lightx2v/models/networks/wan/model.py

@@ -123,18 +123,20 @@ class WanModel:
             self.scheduler.cnt += 1
             if self.scheduler.cnt >= self.scheduler.num_steps:
                 self.scheduler.cnt = 0
-        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
-        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
-        noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
+        self.scheduler.noise_pred = noise_pred_cond
+        if self.config["enable_cfg"]:
+            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
+            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
+            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
             if self.config["feature_caching"] == "Tea":
                 self.scheduler.cnt += 1
                 if self.scheduler.cnt >= self.scheduler.num_steps:
                     self.scheduler.cnt = 0
             self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
         if self.config["cpu_offload"]:
             self.pre_weight.to_cpu()
             self.post_weight.to_cpu()
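
The reshaped block above makes the classifier-free-guidance pass optional: with enable_cfg false only the conditional prediction is kept, otherwise the unconditional pass runs and the two are blended with sample_guide_scale. A dummy-tensor sketch of that guidance arithmetic (illustrative, not repo code):

# Dummy-tensor sketch of the guidance blend used in WanModel above (not repo code).
import torch

noise_pred_cond = torch.randn(4, 16, 16)
noise_pred_uncond = torch.randn(4, 16, 16)
sample_guide_scale = 5.0   # matches the value in configs/wan_i2v.json
enable_cfg = True

if enable_cfg:
    # push the prediction away from the unconditional direction
    noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
else:
    # skip the second forward pass entirely and keep the conditional prediction
    noise_pred = noise_pred_cond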

lightx2v/models/runners/default_runner.py

@@ -4,6 +4,7 @@ import torch.distributed as dist
 from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
 from lightx2v.utils.utils import save_videos_grid, cache_video
 from lightx2v.utils.envs import *
+from loguru import logger


 class DefaultRunner:
@@ -32,7 +33,7 @@ class DefaultRunner:
     def run(self):
         for step_index in range(self.model.scheduler.infer_steps):
-            print(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
+            logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")

             with ProfilingContext4Debug("step_pre"):
                 self.model.scheduler.step_pre(step_index=step_index)

lightx2v/models/runners/graph_runner.py

 from lightx2v.utils.profiler import ProfilingContext4Debug
+from loguru import logger


 class GraphRunner:
@@ -7,10 +8,10 @@ class GraphRunner:
         self.compile()

     def compile(self):
-        print("start compile...")
+        logger.info("start compile...")
         with ProfilingContext4Debug("compile"):
             self.runner.run_step()
-        print("end compile...")
+        logger.info("end compile...")

     def run_pipeline(self):
         return self.runner.run_pipeline()

lightx2v/models/runners/wan/wan_causal_runner.py

@@ -14,6 +14,7 @@ from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
 from lightx2v.models.networks.wan.causal_model import WanCausalModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
+from loguru import logger
 import torch.distributed as dist
@@ -54,7 +55,7 @@ class WanCausalRunner(WanRunner):
             lora_wrapper = WanLoraWrapper(model)
             lora_name = lora_wrapper.load_lora(self.config.lora_path)
             lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-            print(f"Loaded LoRA: {lora_name}")
+            logger.info(f"Loaded LoRA: {lora_name}")
         vae_model = WanVAE(vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=self.config.parallel_vae)
         if self.config.task == "i2v":
@@ -95,13 +96,13 @@ class WanCausalRunner(WanRunner):
         start_block_idx = 0

         for fragment_idx in range(self.num_fragments):
-            print(f"=======> fragment_idx: {fragment_idx + 1} / {self.num_fragments}")
+            logger.info(f"=======> fragment_idx: {fragment_idx + 1} / {self.num_fragments}")
             kv_start = 0
             kv_end = kv_start + self.num_frame_per_block * self.frame_seq_length

             if fragment_idx > 0:
-                print("recompute the kv_cache ...")
+                logger.info("recompute the kv_cache ...")
                 with ProfilingContext4Debug("step_pre"):
                     self.model.scheduler.latents = self.model.scheduler.last_sample
                     self.model.scheduler.step_pre(step_index=self.model.scheduler.infer_steps - 1)
@@ -115,12 +116,12 @@ class WanCausalRunner(WanRunner):
             infer_blocks = self.infer_blocks - (fragment_idx > 0)
             for block_idx in range(infer_blocks):
-                print(f"=======> block_idx: {block_idx + 1} / {infer_blocks}")
-                print(f"=======> kv_start: {kv_start}, kv_end: {kv_end}")
+                logger.info(f"=======> block_idx: {block_idx + 1} / {infer_blocks}")
+                logger.info(f"=======> kv_start: {kv_start}, kv_end: {kv_end}")

                 self.model.scheduler.reset()
                 for step_index in range(self.model.scheduler.infer_steps):
-                    print(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
+                    logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")

                     with ProfilingContext4Debug("step_pre"):
                         self.model.scheduler.step_pre(step_index=step_index)

lightx2v/models/runners/wan/wan_runner.py

@@ -14,6 +14,7 @@ from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 import torch.distributed as dist
+from loguru import logger


 @RUNNER_REGISTER("wan2.1")
@@ -47,7 +48,7 @@ class WanRunner(DefaultRunner):
             lora_wrapper = WanLoraWrapper(model)
             lora_name = lora_wrapper.load_lora(self.config.lora_path)
             lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-            print(f"Loaded LoRA: {lora_name}")
+            logger.info(f"Loaded LoRA: {lora_name}")
         vae_model = WanVAE(vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=self.config.parallel_vae)
         if self.config.task == "i2v":

lightx2v/models/video_encoders/hf/wan/vae.py

@@ -7,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
 from einops import rearrange
+from loguru import logger

 __all__ = [
     "WanVAE",
@@ -801,7 +802,7 @@ class WanVAE:
                 split_dim = 2
                 images = self.decode_dist(zs, world_size, cur_rank, split_dim)
             else:
-                print("Fall back to naive decode mode")
+                logger.info("Fall back to naive decode mode")
                 images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
         else:
             images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)

lightx2v/utils/profiler.py

@@ -2,6 +2,7 @@ import time
 import torch
 from contextlib import ContextDecorator
 from lightx2v.utils.envs import *
+from loguru import logger


 class _ProfilingContext(ContextDecorator):
@@ -16,7 +17,7 @@ class _ProfilingContext(ContextDecorator):
     def __exit__(self, exc_type, exc_val, exc_tb):
         torch.cuda.synchronize()
         elapsed = time.perf_counter() - self.start_time
-        print(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
+        logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
         return False
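
Because _ProfilingContext subclasses ContextDecorator, the profiler works both as a with-block (as seen in infer.py and api_server.py above) and as a decorator. A small sketch under the assumption that the exported ProfilingContext keeps the base class behavior; the workload functions are hypothetical:

# Sketch only; assumes ProfilingContext behaves like the _ProfilingContext base class shown above.
from lightx2v.utils.profiler import ProfilingContext

with ProfilingContext("VAE decode"):        # context-manager form, as used in the repo
    run_decode_step()                       # hypothetical workload

@ProfilingContext("text encoding")          # decorator form provided by ContextDecorator
def encode_prompt(prompt):
    ...                                     # hypothetical workload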