renzhc / diffusers_dcu · Commits

Commit db5fa430 (unverified)
Authored Aug 22, 2022 by Patrick von Platen, committed by GitHub on Aug 22, 2022
[Loading] allow modules to be loaded in fp16 (#230)
Parent: 0ab94856

Showing 2 changed files with 14 additions and 2 deletions:
  src/diffusers/modeling_utils.py  (+7, -0)
  src/diffusers/pipeline_utils.py  (+7, -2)
src/diffusers/modeling_utils.py
@@ -315,6 +315,7 @@ class ModelMixin(torch.nn.Module):
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         from_auto_class = kwargs.pop("_from_auto", False)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         subfolder = kwargs.pop("subfolder", None)

         user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -334,6 +335,12 @@ class ModelMixin(torch.nn.Module):
             subfolder=subfolder,
             **kwargs,
         )

+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+            raise ValueError(
+                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+            )
+        elif torch_dtype is not None:
+            model = model.to(torch_dtype)

         model.register_to_config(_name_or_path=pretrained_model_name_or_path)
         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # Load model
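A minimal usage sketch of the model-side change (illustrative, not part of the commit): the class name UNet2DModel and the checkpoint id "google/ddpm-cat-256" are assumptions; the mechanism is simply that from_pretrained pops torch_dtype, validates it, and casts the loaded module with model.to(torch_dtype).

import torch
from diffusers import UNet2DModel  # assumed ModelMixin subclass; any diffusers model behaves the same way

# Weights are loaded as usual, then the whole module is cast to the requested dtype.
unet = UNet2DModel.from_pretrained("google/ddpm-cat-256", torch_dtype=torch.float16)
print(next(unet.parameters()).dtype)  # torch.float16

# Anything that is not a torch.dtype now fails fast with the ValueError added above:
# UNet2DModel.from_pretrained("google/ddpm-cat-256", torch_dtype="fp16")  # raises ValueError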
src/diffusers/pipeline_utils.py
@@ -146,6 +146,7 @@ class DiffusionPipeline(ConfigMixin):
         local_files_only = kwargs.pop("local_files_only", False)
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
+        torch_dtype = kwargs.pop("torch_dtype", None)

         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -237,12 +238,16 @@ class DiffusionPipeline(ConfigMixin):
             load_method = getattr(class_obj, load_method_name)

+            loading_kwargs = {}
+            if issubclass(class_obj, torch.nn.Module):
+                loading_kwargs["torch_dtype"] = torch_dtype
+
             # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
-                loaded_sub_model = load_method(os.path.join(cached_folder, name))
+                loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
             else:
                 # else load from the root directory
-                loaded_sub_model = load_method(cached_folder)
+                loaded_sub_model = load_method(cached_folder, **loading_kwargs)

             init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)
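A minimal sketch of the pipeline-side behaviour (illustrative, not part of the commit): torch_dtype is forwarded only to sub-models that are torch.nn.Module subclasses, while non-module components such as the scheduler are loaded unchanged. DDPMPipeline and the checkpoint id below are assumptions for the example.

import torch
from diffusers import DDPMPipeline  # assumed pipeline class; checkpoint id is illustrative

# Every nn.Module sub-model (here, the UNet) receives torch_dtype via loading_kwargs;
# the scheduler is not an nn.Module, so it is loaded without the extra keyword.
pipeline = DDPMPipeline.from_pretrained("google/ddpm-cat-256", torch_dtype=torch.float16)
print(next(pipeline.unet.parameters()).dtype)  # torch.float16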