xuwx1 / LightX2V
"092/backend_request_func.py" did not exist on "49c10c0c9cefa0184cfd85ac20f3cfb9a085ff76"
Commit 6c5817f8
Authored Jun 09, 2025 by 谷石桥

Fix

Parent: 2ef8e74e
Showing 1 changed file with 8 additions and 8 deletions
lightx2v/models/networks/wan/model.py  +8 −8
@@ -66,9 +66,9 @@ class WanModel:
         use_bfloat16 = self.config.get("use_bfloat16", True)
         with safe_open(file_path, framework="pt") as f:
             if use_bfloat16:
-                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).to(self.device) for key in f.keys()}
+                tensor_dict = {key: f.get_tensor(key).pin_memory().to(torch.bfloat16).to(self.device) for key in f.keys()}
             else:
-                tensor_dict = {key: f.get_tensor(key).to(self.device) for key in f.keys()}
+                tensor_dict = {key: f.get_tensor(key).pin_memory().to(self.device) for key in f.keys()}
         return tensor_dict

     def _load_ckpt(self):
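Note: the snippet below is a standalone sketch of the loading pattern this hunk arrives at, written so it can run outside the class. The function name, the explicit device argument, and the non_blocking=True copy are illustrative assumptions rather than code from this repository; pin_memory() stages each tensor in page-locked host memory so the host-to-device copy can be issued asynchronously.

import torch
from safetensors import safe_open

def load_file_pinned(file_path: str, device: torch.device, use_bfloat16: bool = True) -> dict:
    # Hypothetical helper mirroring the hunk above: read every tensor from a
    # safetensors file, optionally cast it to bfloat16, stage it in pinned
    # (page-locked) host memory, then move it to the target device.
    tensor_dict = {}
    with safe_open(file_path, framework="pt") as f:
        for key in f.keys():
            t = f.get_tensor(key)
            if use_bfloat16:
                t = t.to(torch.bfloat16)
            # Pin after the dtype cast so the copy below reads from page-locked
            # memory; non_blocking copies are only asynchronous from pinned tensors.
            tensor_dict[key] = t.pin_memory().to(device, non_blocking=True)
    return tensor_dict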
@@ -107,9 +107,9 @@ class WanModel:
         with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
             logger.info(f"Loading weights from {safetensor_path}")
             for k in f.keys():
-                weight_dict[k] = f.get_tensor(k)
+                weight_dict[k] = f.get_tensor(k).pin_memory()
                 if weight_dict[k].dtype == torch.float:
-                    weight_dict[k] = weight_dict[k].to(torch.bfloat16)
+                    weight_dict[k] = weight_dict[k].pin_memory().to(torch.bfloat16)
         return weight_dict
@@ -121,9 +121,9 @@ class WanModel:
         safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
         with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
             for k in f.keys():
-                pre_post_weight_dict[k] = f.get_tensor(k)
+                pre_post_weight_dict[k] = f.get_tensor(k).pin_memory()
                 if pre_post_weight_dict[k].dtype == torch.float:
-                    pre_post_weight_dict[k] = pre_post_weight_dict[k].to(torch.bfloat16)
+                    pre_post_weight_dict[k] = pre_post_weight_dict[k].pin_memory().to(torch.bfloat16)
         safetensors_pattern = os.path.join(lazy_load_model_path, "block_*.safetensors")
         safetensors_files = glob.glob(safetensors_pattern)
@@ -134,9 +134,9 @@ class WanModel:
             with safe_open(file_path, framework="pt") as f:
                 for k in f.keys():
                     if "modulation" in k:
-                        transformer_weight_dict[k] = f.get_tensor(k)
+                        transformer_weight_dict[k] = f.get_tensor(k).pin_memory()
                         if transformer_weight_dict[k].dtype == torch.float:
-                            transformer_weight_dict[k] = transformer_weight_dict[k].to(torch.bfloat16)
+                            transformer_weight_dict[k] = transformer_weight_dict[k].pin_memory().to(torch.bfloat16)
         return pre_post_weight_dict, transformer_weight_dict
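Hunks 2–4 apply the same change to the per-key loading loops: each tensor returned by safe_open is pinned before the dtype check and the eventual move to self.device. As a minimal, hypothetical illustration of what pinning buys (none of the names below come from the repository), a page-locked host tensor is the only kind whose non_blocking=True copy to the GPU actually runs asynchronously and can overlap with work on another CUDA stream:

import torch

if torch.cuda.is_available():
    # Page-locked host buffer; an ordinary (pageable) tensor would make the
    # copy below effectively synchronous even with non_blocking=True.
    host = torch.randn(1024, 1024).pin_memory()
    copy_stream = torch.cuda.Stream()

    with torch.cuda.stream(copy_stream):
        # Asynchronous host-to-device copy issued on a side stream.
        on_gpu = host.to("cuda", non_blocking=True)

    # Make the default stream wait for the transfer before using the tensor.
    torch.cuda.current_stream().wait_stream(copy_stream)
    print(on_gpu.shape, on_gpu.device)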