LightX2V commit 6347b21b
Authored Apr 29, 2025 by root
Parent: aec90a0d

Support convert weight to diffusers.

This commit changes 23 files in total; this page lists 20 of them, with 215 additions and 44 deletions (+215, -44).
Changed files:

configs/wan_i2v.json (+2, -0)
examples/diffusers/converter.py (+150, -0)
lightx2v/api_server.py (+6, -5)
lightx2v/attentions/distributed/comm/ring_comm.py (+2, -1)
lightx2v/attentions/distributed/partial_heads_attn/tests/test_acc.py (+3, -2)
lightx2v/common/backend_infer/trt/common.py (+3, -2)
lightx2v/common/ops/mm/mm_weight.py (+4, -3)
lightx2v/infer.py (+4, -3)
lightx2v/models/input_encoders/hf/clip/model.py (+2, -1)
lightx2v/models/input_encoders/hf/llama/model.py (+2, -1)
lightx2v/models/input_encoders/hf/llava/model.py (+2, -1)
lightx2v/models/input_encoders/hf/t5/model.py (+2, -1)
lightx2v/models/input_encoders/hf/xlm_roberta/model.py (+2, -1)
lightx2v/models/networks/wan/model.py (+13, -11)
lightx2v/models/runners/default_runner.py (+2, -1)
lightx2v/models/runners/graph_runner.py (+3, -2)
lightx2v/models/runners/wan/wan_causal_runner.py (+7, -6)
lightx2v/models/runners/wan/wan_runner.py (+2, -1)
lightx2v/models/video_encoders/hf/wan/vae.py (+2, -1)
lightx2v/utils/profiler.py (+2, -1)

Apart from the new converter script and the enable_cfg/cpu_offload config plumbing, the bulk of the diff replaces print calls with loguru's logger.info across the codebase.
configs/wan_i2v.json

@@ -7,6 +7,8 @@
     "seed": 42,
     "sample_guide_scale": 5,
     "sample_shift": 5,
+    "enable_cfg": true,
+    "cpu_offload": false,
     "mm_config": {
         "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
         "weight_auto_quant": true
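These two flags move classifier-free guidance and CPU offload control into the config file (the --enable_cfg CLI flag is removed from lightx2v/infer.py later in this diff). A minimal sketch of a consumer, where the key names come from the JSON above but the surrounding logic is an illustrative assumption:

import json

# Hypothetical consumer; the key names match the committed JSON,
# the surrounding logic is illustrative only.
with open("configs/wan_i2v.json") as f:
    config = json.load(f)

if config.get("enable_cfg", False):
    pass  # run the extra unconditional pass and blend (see wan/model.py below)
if config.get("cpu_offload", False):
    pass  # park idle weights on CPU between stages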
examples/diffusers/converter.py  (new file, mode 100755)

import os
import re
import glob
import json
import argparse
import torch
from safetensors import safe_open, torch as st
from loguru import logger
from tqdm import tqdm


def get_key_mapping_rules(direction, model_type):
    # Each rule carries both directions, so forward (lightx2v -> Diffusers)
    # and backward (Diffusers -> lightx2v) renames stay in sync.
    if model_type == "wan":
        unified_rules = [
            {"forward": (r"^head\.head$", "proj_out"), "backward": (r"^proj_out$", "head.head")},
            {"forward": (r"^head\.modulation$", "scale_shift_table"), "backward": (r"^scale_shift_table$", "head.modulation")},
            {"forward": (r"^text_embedding\.0\.", "condition_embedder.text_embedder.linear_1."), "backward": (r"^condition_embedder.text_embedder.linear_1\.", "text_embedding.0.")},
            {"forward": (r"^text_embedding\.2\.", "condition_embedder.text_embedder.linear_2."), "backward": (r"^condition_embedder.text_embedder.linear_2\.", "text_embedding.2.")},
            {"forward": (r"^time_embedding\.0\.", "condition_embedder.time_embedder.linear_1."), "backward": (r"^condition_embedder.time_embedder.linear_1\.", "time_embedding.0.")},
            {"forward": (r"^time_embedding\.2\.", "condition_embedder.time_embedder.linear_2."), "backward": (r"^condition_embedder.time_embedder.linear_2\.", "time_embedding.2.")},
            {"forward": (r"^time_projection\.1\.", "condition_embedder.time_proj."), "backward": (r"^condition_embedder.time_proj\.", "time_projection.1.")},
            {"forward": (r"blocks\.(\d+)\.self_attn\.q\.", r"blocks.\1.attn1.to_q."), "backward": (r"blocks\.(\d+)\.attn1\.to_q\.", r"blocks.\1.self_attn.q.")},
            {"forward": (r"blocks\.(\d+)\.self_attn\.k\.", r"blocks.\1.attn1.to_k."), "backward": (r"blocks\.(\d+)\.attn1\.to_k\.", r"blocks.\1.self_attn.k.")},
            {"forward": (r"blocks\.(\d+)\.self_attn\.v\.", r"blocks.\1.attn1.to_v."), "backward": (r"blocks\.(\d+)\.attn1\.to_v\.", r"blocks.\1.self_attn.v.")},
            {"forward": (r"blocks\.(\d+)\.self_attn\.o\.", r"blocks.\1.attn1.to_out.0."), "backward": (r"blocks\.(\d+)\.attn1\.to_out\.0\.", r"blocks.\1.self_attn.o.")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.q\.", r"blocks.\1.attn2.to_q."), "backward": (r"blocks\.(\d+)\.attn2\.to_q\.", r"blocks.\1.cross_attn.q.")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.k\.", r"blocks.\1.attn2.to_k."), "backward": (r"blocks\.(\d+)\.attn2\.to_k\.", r"blocks.\1.cross_attn.k.")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.v\.", r"blocks.\1.attn2.to_v."), "backward": (r"blocks\.(\d+)\.attn2\.to_v\.", r"blocks.\1.cross_attn.v.")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.o\.", r"blocks.\1.attn2.to_out.0."), "backward": (r"blocks\.(\d+)\.attn2\.to_out\.0\.", r"blocks.\1.cross_attn.o.")},
            {"forward": (r"blocks\.(\d+)\.norm3\.", r"blocks.\1.norm2."), "backward": (r"blocks\.(\d+)\.norm2\.", r"blocks.\1.norm3.")},
            {"forward": (r"blocks\.(\d+)\.ffn\.0\.", r"blocks.\1.ffn.net.0.proj."), "backward": (r"blocks\.(\d+)\.ffn\.net\.0\.proj\.", r"blocks.\1.ffn.0.")},
            {"forward": (r"blocks\.(\d+)\.ffn\.2\.", r"blocks.\1.ffn.net.2."), "backward": (r"blocks\.(\d+)\.ffn\.net\.2\.", r"blocks.\1.ffn.2.")},
            {"forward": (r"blocks\.(\d+)\.modulation\.", r"blocks.\1.scale_shift_table."), "backward": (r"blocks\.(\d+)\.scale_shift_table(?=\.|$)", r"blocks.\1.modulation")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.k_img\.", r"blocks.\1.attn2.add_k_proj."), "backward": (r"blocks\.(\d+)\.attn2\.add_k_proj\.", r"blocks.\1.cross_attn.k_img.")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.v_img\.", r"blocks.\1.attn2.add_v_proj."), "backward": (r"blocks\.(\d+)\.attn2\.add_v_proj\.", r"blocks.\1.cross_attn.v_img.")},
            {
                "forward": (r"blocks\.(\d+)\.cross_attn\.norm_k_img\.weight", r"blocks.\1.attn2.norm_added_k.weight"),
                "backward": (r"blocks\.(\d+)\.attn2\.norm_added_k\.weight", r"blocks.\1.cross_attn.norm_k_img.weight"),
            },
            {"forward": (r"img_emb\.proj\.0\.", r"condition_embedder.image_embedder.norm1."), "backward": (r"condition_embedder\.image_embedder\.norm1\.", r"img_emb.proj.0.")},
            {"forward": (r"img_emb\.proj\.1\.", r"condition_embedder.image_embedder.ff.net.0.proj."), "backward": (r"condition_embedder\.image_embedder\.ff\.net\.0\.proj\.", r"img_emb.proj.1.")},
            {"forward": (r"img_emb\.proj\.3\.", r"condition_embedder.image_embedder.ff.net.2."), "backward": (r"condition_embedder\.image_embedder\.ff\.net\.2\.", r"img_emb.proj.3.")},
            {"forward": (r"img_emb\.proj\.4\.", r"condition_embedder.image_embedder.norm2."), "backward": (r"condition_embedder\.image_embedder\.norm2\.", r"img_emb.proj.4.")},
            {"forward": (r"blocks\.(\d+)\.self_attn\.norm_q\.weight", r"blocks.\1.attn1.norm_q.weight"), "backward": (r"blocks\.(\d+)\.attn1\.norm_q\.weight", r"blocks.\1.self_attn.norm_q.weight")},
            {"forward": (r"blocks\.(\d+)\.self_attn\.norm_k\.weight", r"blocks.\1.attn1.norm_k.weight"), "backward": (r"blocks\.(\d+)\.attn1\.norm_k\.weight", r"blocks.\1.self_attn.norm_k.weight")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.norm_q\.weight", r"blocks.\1.attn2.norm_q.weight"), "backward": (r"blocks\.(\d+)\.attn2\.norm_q\.weight", r"blocks.\1.cross_attn.norm_q.weight")},
            {"forward": (r"blocks\.(\d+)\.cross_attn\.norm_k\.weight", r"blocks.\1.attn2.norm_k.weight"), "backward": (r"blocks\.(\d+)\.attn2\.norm_k\.weight", r"blocks.\1.cross_attn.norm_k.weight")},
            # head projection mapping
            {"forward": (r"^head\.head\.", "proj_out."), "backward": (r"^proj_out\.", "head.head.")},
        ]
        if direction == "forward":
            return [rule["forward"] for rule in unified_rules]
        elif direction == "backward":
            return [rule["backward"] for rule in unified_rules]
        else:
            raise ValueError(f"Invalid direction: {direction}")
    else:
        raise ValueError(f"Unsupported model type: {model_type}")


def convert_weights(args):
    # Gather source checkpoint files: a directory of safetensors shards, or a single file.
    if os.path.isdir(args.source):
        src_files = glob.glob(os.path.join(args.source, "*.safetensors"), recursive=True)
    elif args.source.endswith((".pth", ".safetensors", "pt")):
        src_files = [args.source]
    else:
        raise ValueError("Invalid input path")

    merged_weights = {}
    logger.info(f"Processing source files: {src_files}")
    for file_path in tqdm(src_files, desc="Loading weights"):
        logger.info(f"Loading weights from: {file_path}")
        if file_path.endswith(".pt") or file_path.endswith(".pth"):
            weights = torch.load(file_path, map_location="cpu", weights_only=True)
        elif file_path.endswith(".safetensors"):
            with safe_open(file_path, framework="pt") as f:
                weights = {k: f.get_tensor(k) for k in f.keys()}
        duplicate_keys = set(weights.keys()) & set(merged_weights.keys())
        if duplicate_keys:
            raise ValueError(f"Duplicate keys found: {duplicate_keys} in file {file_path}")
        merged_weights.update(weights)

    # Rename every key according to the mapping rules.
    rules = get_key_mapping_rules(args.direction, args.model_type)
    converted_weights = {}
    logger.info("Converting keys...")
    for key in tqdm(merged_weights.keys(), desc="Converting keys"):
        new_key = key
        for pattern, replacement in rules:
            new_key = re.sub(pattern, replacement, new_key)
        converted_weights[new_key] = merged_weights[key]

    os.makedirs(args.output, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(args.source))[0] if args.source.endswith((".pth", ".safetensors")) else "converted_model"
    index = {"metadata": {"total_size": 0}, "weight_map": {}}
    chunk_idx = 0
    current_chunk = {}
    # Save in chunks of args.chunk_size tensors; chunk_size == 0 disables chunking
    # (the zero check must come first to avoid a modulo-by-zero).
    for idx, (k, v) in tqdm(enumerate(converted_weights.items()), desc="Saving chunks"):
        current_chunk[k] = v
        if args.chunk_size > 0 and (idx + 1) % args.chunk_size == 0:
            output_filename = f"{base_name}_part{chunk_idx}.safetensors"
            output_path = os.path.join(args.output, output_filename)
            logger.info(f"Saving chunk to: {output_path}")
            st.save_file(current_chunk, output_path)
            for key in current_chunk:
                index["weight_map"][key] = output_filename
            index["metadata"]["total_size"] += os.path.getsize(output_path)
            current_chunk = {}
            chunk_idx += 1
    if current_chunk:
        output_filename = f"{base_name}_part{chunk_idx}.safetensors"
        output_path = os.path.join(args.output, output_filename)
        logger.info(f"Saving final chunk to: {output_path}")
        st.save_file(current_chunk, output_path)
        for key in current_chunk:
            index["weight_map"][key] = output_filename
        index["metadata"]["total_size"] += os.path.getsize(output_path)

    # Save index file mapping each tensor name to the shard that holds it.
    index_path = os.path.join(args.output, "diffusion_pytorch_model.safetensors.index.json")
    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(index, f, indent=2)
    logger.info(f"Index file written to: {index_path}")


def main():
    parser = argparse.ArgumentParser(description="Model weight format converter")
    parser.add_argument("-s", "--source", required=True, help="Input path (file or directory)")
    parser.add_argument("-o", "--output", required=True, help="Output directory path")
    parser.add_argument("-d", "--direction", choices=["forward", "backward"], default="forward", help="Conversion direction: forward = 'lightx2v' -> 'Diffusers', backward = reverse")
    parser.add_argument("-c", "--chunk-size", type=int, default=100, help="Chunk size for saving (only applies to forward), 0 = no chunking")
    parser.add_argument("-t", "--model_type", choices=["wan"], default="wan", help="Model type")
    args = parser.parse_args()
    if os.path.isfile(args.output):
        raise ValueError("Output path must be a directory, not a file")
    logger.info("Starting model weight conversion...")
    convert_weights(args)
    logger.info(f"Conversion completed! Files saved to: {args.output}")


if __name__ == "__main__":
    main()
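The script is driven from the command line, e.g. python examples/diffusers/converter.py -s <lightx2v_checkpoint> -o <output_dir> -d forward -t wan. To make the renaming concrete, here is a small self-contained check that applies the forward rules to a few sample keys; the key names are illustrative, and the import assumes examples/diffusers/ is on sys.path:

import re

from converter import get_key_mapping_rules  # assumes examples/diffusers/ is on sys.path

rules = get_key_mapping_rules("forward", "wan")
for key in ["blocks.0.self_attn.q.weight", "text_embedding.0.bias", "head.head.weight"]:
    new_key = key
    for pattern, replacement in rules:
        new_key = re.sub(pattern, replacement, new_key)
    print(f"{key} -> {new_key}")

# Expected output:
#   blocks.0.self_attn.q.weight -> blocks.0.attn1.to_q.weight
#   text_embedding.0.bias -> condition_embedder.text_embedder.linear_1.bias
#   head.head.weight -> proj_out.weight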
lightx2v/api_server.py

@@ -4,6 +4,7 @@ import psutil
 import argparse
 from fastapi import FastAPI, Request
 from pydantic import BaseModel
+from loguru import logger
 import uvicorn
 import json
 import asyncio
@@ -26,15 +27,15 @@ def kill_all_related_processes():
         try:
             child.kill()
         except Exception as e:
-            print(f"Failed to kill child process {child.pid}: {e}")
+            logger.info(f"Failed to kill child process {child.pid}: {e}")
     try:
         current_process.kill()
     except Exception as e:
-        print(f"Failed to kill main process: {e}")
+        logger.info(f"Failed to kill main process: {e}")

 def signal_handler(sig, frame):
-    print("\nReceived Ctrl+C, shutting down all related processes...")
+    logger.info("\nReceived Ctrl+C, shutting down all related processes...")
     kill_all_related_processes()
     sys.exit(0)
@@ -79,11 +80,11 @@ if __name__ == "__main__":
     parser.add_argument("--config_json", type=str, required=True)
     parser.add_argument("--port", type=int, default=8000)
     args = parser.parse_args()
-    print(f"args: {args}")
+    logger.info(f"args: {args}")
     with ProfilingContext("Init Server Cost"):
         config = set_config(args)
-        print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
+        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         runner = init_runner(config)
     uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
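This print-to-logger.info swap repeats across nearly every file below. loguru needs no setup (it logs to stderr by default), but a sink can be added once at startup; the file name and rotation policy in this sketch are illustrative assumptions:

from loguru import logger

# Optional one-time setup; path and rotation policy are illustrative.
logger.add("lightx2v_server.log", rotation="100 MB", level="INFO")

logger.info("args: {}", {"port": 8000})  # loguru formats lazily with {} placeholders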
lightx2v/attentions/distributed/comm/ring_comm.py

 from typing import Optional
+from loguru import logger
 import torch
 import torch.distributed as dist
@@ -21,7 +22,7 @@ class RingComm:
     def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
         if recv_tensor is None:
             res = torch.empty_like(to_send)
-            # print(f"send_recv: empty_like {to_send.shape}")
+            # logger.info(f"send_recv: empty_like {to_send.shape}")
         else:
             res = recv_tensor
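send_recv allocates a fresh tensor with torch.empty_like only when the caller passes no receive buffer, so callers can preallocate one buffer and reuse it across ring steps. The buffer-selection logic, distilled into a runnable, torch.distributed-free sketch (the function name is illustrative):

from typing import Optional

import torch

def pick_recv_buffer(to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Same selection logic as RingComm.send_recv: allocate only when the
    # caller did not supply a reusable receive tensor.
    return torch.empty_like(to_send) if recv_tensor is None else recv_tensor

x = torch.randn(4, 8)
buf = torch.empty_like(x)
assert pick_recv_buffer(x) is not buf   # fresh allocation path
assert pick_recv_buffer(x, buf) is buf  # caller-managed buffer path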
lightx2v/attentions/distributed/partial_heads_attn/tests/test_acc.py

@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist
 from lightx2v.attentions import attention
 from lightx2v.utils.utils import seed_all
+from loguru import logger

 seed_all(42)
@@ -65,10 +66,10 @@ def test_part_head():
     # Verify result consistency
     if cur_rank == 0:
         # import pdb; pdb.set_trace()
-        print("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))
+        logger.info("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))

     # # Verify result consistency
-    # print("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))
+    # logger.info("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))

 if __name__ == "__main__":
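One caveat worth knowing about this particular swap: print joins extra positional arguments into the output, but loguru passes them to str.format, and since the message has no {} placeholder the allclose result is silently dropped from the logged line. A placeholder or f-string preserves it:

import torch
from loguru import logger

single = torch.randn(8)
combined = single + 1e-5
ok = torch.allclose(single, combined, rtol=1e-3, atol=1e-3)

logger.info("Outputs match:", ok)     # extra arg feeds str.format; value is silently dropped
logger.info("Outputs match: {}", ok)  # loguru's lazy {}-formatting keeps it
logger.info(f"Outputs match: {ok}")   # or format eagerly with an f-string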
lightx2v/common/backend_infer/trt/common.py

@@ -20,6 +20,7 @@ import os
 import tensorrt as trt
 from .common_runtime import *
+from loguru import logger

 try:
     # Sometimes python does not understand FileNotFoundError
@@ -67,11 +68,11 @@ def find_sample_data(description="Runs a TensorRT Python sample", subfolder="",
         data_path = os.path.join(data_dir, subfolder)
         if not os.path.exists(data_path):
             if data_dir != kDEFAULT_DATA_ROOT:
-                print("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.")
+                logger.info("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.")
             data_path = data_dir
         # Make sure data directory exists.
         if not (os.path.exists(data_path)) and data_dir != kDEFAULT_DATA_ROOT:
-            print("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path))
+            logger.info("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path))
         return data_path

     data_paths = [get_data_path(data_dir) for data_dir in args.datadir]
lightx2v/common/ops/mm/mm_weight.py

@@ -4,6 +4,7 @@ from vllm import _custom_ops as ops
 import sgl_kernel
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
 from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
+from loguru import logger

 try:
     import q8_kernels.functional as Q8F
@@ -461,7 +462,7 @@ if __name__ == "__main__":
     mm_weight.load(weight_dict)
     input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
     output_tensor = mm_weight.apply(input_tensor)
-    print(output_tensor.shape)
+    logger.info(output_tensor.shape)

     weight_dict = {
         "xx.weight": torch.randn(8192, 4096),
@@ -473,7 +474,7 @@ if __name__ == "__main__":
     mm_weight.load(weight_dict)
     input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
     output_tensor = mm_weight.apply(input_tensor)
-    print(output_tensor.shape)
+    logger.info(output_tensor.shape)

     weight_dict = {
         "xx.weight": torch.randn(8192, 4096),
@@ -485,4 +486,4 @@ if __name__ == "__main__":
     mm_weight.load(weight_dict)
     input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
     output_tensor = mm_weight.apply(input_tensor)
-    print(output_tensor.shape)
+    logger.info(output_tensor.shape)
lightx2v/infer.py

@@ -15,6 +15,7 @@ from lightx2v.models.runners.wan.wan_causal_runner import WanCausalRunner
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.common.ops import *
+from loguru import logger

 def init_runner(config):
@@ -37,17 +38,17 @@ if __name__ == "__main__":
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
-    parser.add_argument("--enable_cfg", type=bool, default=False)
     parser.add_argument("--prompt", type=str, required=True)
     parser.add_argument("--negative_prompt", type=str, default="")
     parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
     parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
     args = parser.parse_args()
-    print(f"args: {args}")
+    logger.info(f"args: {args}")
     with ProfilingContext("Total Cost"):
         config = set_config(args)
-        print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
+        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         runner = init_runner(config)
         runner.run_pipeline()
lightx2v/models/input_encoders/hf/clip/model.py

 import torch
 from transformers import CLIPTextModel, AutoTokenizer
+from loguru import logger

 class TextEncoderHFClipModel:
@@ -54,4 +55,4 @@ if __name__ == "__main__":
     model = TextEncoderHFClipModel(model_path, torch.device("cuda"))
     text = "A cat walks on the grass, realistic style."
     outputs = model.infer(text)
-    print(outputs)
+    logger.info(outputs)
lightx2v/models/input_encoders/hf/llama/model.py

 import torch
 from transformers import AutoModel, AutoTokenizer
+from loguru import logger

 class TextEncoderHFLlamaModel:
@@ -67,4 +68,4 @@ if __name__ == "__main__":
     model = TextEncoderHFLlamaModel(model_path, torch.device("cuda"))
     text = "A cat walks on the grass, realistic style."
     outputs = model.infer(text)
-    print(outputs)
+    logger.info(outputs)
lightx2v/models/input_encoders/hf/llava/model.py

@@ -3,6 +3,7 @@ from PIL import Image
 import numpy as np
 import torchvision.transforms as transforms
 from transformers import LlavaForConditionalGeneration, CLIPImageProcessor, AutoTokenizer
+from loguru import logger

 def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
@@ -158,4 +159,4 @@ if __name__ == "__main__":
     img_path = "/mtc/yongyang/projects/lightx2v/assets/inputs/imgs/img_1.jpg"
     img = Image.open(img_path).convert("RGB")
     outputs = model.infer(text, img, None)
-    print(outputs)
+    logger.info(outputs)
lightx2v/models/input_encoders/hf/t5/model.py

@@ -8,6 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from .tokenizer import HuggingfaceTokenizer
+from loguru import logger

 __all__ = [
     "T5Model",
@@ -522,4 +523,4 @@ if __name__ == "__main__":
     )
     text = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
     outputs = model.infer(text)
-    print(outputs)
+    logger.info(outputs)
lightx2v/models/input_encoders/hf/xlm_roberta/model.py

@@ -10,6 +10,7 @@ import torchvision.transforms as T
 from lightx2v.attentions import attention
 from lightx2v.models.input_encoders.hf.t5.tokenizer import HuggingfaceTokenizer
+from loguru import logger
 from .xlm_roberta import XLMRoberta
@@ -190,7 +191,7 @@ class VisionTransformer(nn.Module):
         norm_eps=1e-5,
     ):
         if image_size % patch_size != 0:
-            print("[WARNING] image_size is not divisible by patch_size", flush=True)
+            logger.info("[WARNING] image_size is not divisible by patch_size", flush=True)
         assert pool_type in ("token", "token_fc", "attn_pool")
         out_dim = out_dim or dim
         super().__init__()
lightx2v/models/networks/wan/model.py

@@ -123,7 +123,9 @@ class WanModel:
         self.scheduler.cnt += 1
         if self.scheduler.cnt >= self.scheduler.num_steps:
             self.scheduler.cnt = 0
+        self.scheduler.noise_pred = noise_pred_cond
+        if self.config["enable_cfg"]:
         embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
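The two added lines make the unconditional pass optional: the conditional prediction is stored as the default, and the context lines computing the unconditional branch move under the new if (their whitespace-only re-indentation is suppressed by the diff viewer, which is also why this file's totals are +13/-11). For reference, the standard classifier-free-guidance combine that such a branch typically feeds, with illustrative tensors; the combine formula itself is not part of the hunk shown:

import torch

# Illustrative shapes; sample_guide_scale matches the value in configs/wan_i2v.json.
sample_guide_scale = 5.0
noise_pred_cond = torch.randn(2, 16, 32, 32)
enable_cfg = True

if enable_cfg:
    noise_pred_uncond = torch.randn(2, 16, 32, 32)  # stands in for the second forward pass
    noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
else:
    noise_pred = noise_pred_cond  # single-pass path, as the new branch selects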
lightx2v/models/runners/default_runner.py

@@ -4,6 +4,7 @@ import torch.distributed as dist
 from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
 from lightx2v.utils.utils import save_videos_grid, cache_video
 from lightx2v.utils.envs import *
+from loguru import logger

 class DefaultRunner:
@@ -32,7 +33,7 @@ class DefaultRunner:
     def run(self):
         for step_index in range(self.model.scheduler.infer_steps):
-            print(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
+            logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
             with ProfilingContext4Debug("step_pre"):
                 self.model.scheduler.step_pre(step_index=step_index)
lightx2v/models/runners/graph_runner.py

 from lightx2v.utils.profiler import ProfilingContext4Debug
+from loguru import logger

 class GraphRunner:
@@ -7,10 +8,10 @@ class GraphRunner:
         self.compile()

     def compile(self):
-        print("start compile...")
+        logger.info("start compile...")
         with ProfilingContext4Debug("compile"):
             self.runner.run_step()
-        print("end compile...")
+        logger.info("end compile...")

     def run_pipeline(self):
         return self.runner.run_pipeline()
lightx2v/models/runners/wan/wan_causal_runner.py

@@ -14,6 +14,7 @@ from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
 from lightx2v.models.networks.wan.causal_model import WanCausalModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
+from loguru import logger
 import torch.distributed as dist
@@ -54,7 +55,7 @@ class WanCausalRunner(WanRunner):
             lora_wrapper = WanLoraWrapper(model)
             lora_name = lora_wrapper.load_lora(self.config.lora_path)
             lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-            print(f"Loaded LoRA: {lora_name}")
+            logger.info(f"Loaded LoRA: {lora_name}")
         vae_model = WanVAE(vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=self.config.parallel_vae)
         if self.config.task == "i2v":
@@ -95,13 +96,13 @@ class WanCausalRunner(WanRunner):
         start_block_idx = 0
         for fragment_idx in range(self.num_fragments):
-            print(f"=======> fragment_idx: {fragment_idx + 1} / {self.num_fragments}")
+            logger.info(f"=======> fragment_idx: {fragment_idx + 1} / {self.num_fragments}")
             kv_start = 0
             kv_end = kv_start + self.num_frame_per_block * self.frame_seq_length
             if fragment_idx > 0:
-                print("recompute the kv_cache ...")
+                logger.info("recompute the kv_cache ...")
                 with ProfilingContext4Debug("step_pre"):
                     self.model.scheduler.latents = self.model.scheduler.last_sample
                     self.model.scheduler.step_pre(step_index=self.model.scheduler.infer_steps - 1)
@@ -115,12 +116,12 @@ class WanCausalRunner(WanRunner):
             infer_blocks = self.infer_blocks - (fragment_idx > 0)
             for block_idx in range(infer_blocks):
-                print(f"=======> block_idx: {block_idx + 1} / {infer_blocks}")
-                print(f"=======> kv_start: {kv_start}, kv_end: {kv_end}")
+                logger.info(f"=======> block_idx: {block_idx + 1} / {infer_blocks}")
+                logger.info(f"=======> kv_start: {kv_start}, kv_end: {kv_end}")
                 self.model.scheduler.reset()
                 for step_index in range(self.model.scheduler.infer_steps):
-                    print(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
+                    logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
                     with ProfilingContext4Debug("step_pre"):
                         self.model.scheduler.step_pre(step_index=step_index)
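The new kv_start/kv_end logging traces a KV-cache window of num_frame_per_block * frame_seq_length positions per block. A small arithmetic sketch with illustrative sizes, assuming the window advances one block at a time (the advancing code sits outside the hunks shown above):

# Illustrative numbers, not values from this commit.
num_frame_per_block = 3
frame_seq_length = 1560                                # tokens per latent frame (assumed)
block_tokens = num_frame_per_block * frame_seq_length  # 4680 tokens per block

# Assumed advancement: slide the window by one block per iteration.
kv_start, kv_end = 0, block_tokens
for block_idx in range(3):
    print(f"block {block_idx}: kv_start={kv_start}, kv_end={kv_end}")
    kv_start, kv_end = kv_end, kv_end + block_tokens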
lightx2v/models/runners/wan/wan_runner.py

@@ -14,6 +14,7 @@ from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 import torch.distributed as dist
+from loguru import logger

 @RUNNER_REGISTER("wan2.1")
@@ -47,7 +48,7 @@ class WanRunner(DefaultRunner):
             lora_wrapper = WanLoraWrapper(model)
             lora_name = lora_wrapper.load_lora(self.config.lora_path)
             lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-            print(f"Loaded LoRA: {lora_name}")
+            logger.info(f"Loaded LoRA: {lora_name}")
         vae_model = WanVAE(vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=self.config.parallel_vae)
         if self.config.task == "i2v":
lightx2v/models/video_encoders/hf/wan/vae.py

@@ -7,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
 from einops import rearrange
+from loguru import logger

 __all__ = [
     "WanVAE",
@@ -801,7 +802,7 @@ class WanVAE:
                 split_dim = 2
                 images = self.decode_dist(zs, world_size, cur_rank, split_dim)
             else:
-                print("Fall back to naive decode mode")
+                logger.info("Fall back to naive decode mode")
                 images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
         else:
             images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
lightx2v/utils/profiler.py

@@ -2,6 +2,7 @@ import time
 import torch
 from contextlib import ContextDecorator
 from lightx2v.utils.envs import *
+from loguru import logger

 class _ProfilingContext(ContextDecorator):
@@ -16,7 +17,7 @@ class _ProfilingContext(ContextDecorator):
     def __exit__(self, exc_type, exc_val, exc_tb):
         torch.cuda.synchronize()
         elapsed = time.perf_counter() - self.start_time
-        print(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
+        logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
         return False
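Since __exit__ calls torch.cuda.synchronize() before reading the clock, the reported time includes queued GPU work. And because _ProfilingContext subclasses ContextDecorator, the public ProfilingContext works both as a with-block and as a decorator; a usage sketch, assuming lightx2v is importable and a CUDA device is present:

from lightx2v.utils.profiler import ProfilingContext

# As a context manager, exactly as lightx2v/infer.py uses it:
with ProfilingContext("Total Cost"):
    run_pipeline()  # hypothetical workload

# As a decorator, which ContextDecorator enables for free:
@ProfilingContext("encode")
def encode_batch():
    ...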