Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
80de641c
Unverified
Commit
80de641c
authored
Sep 23, 2025
by
Dhruv Nair
Committed by
GitHub
Sep 23, 2025
Browse files
Allow Automodel to support custom model code (#12353)
* update * update
parent
76810eca
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
24 deletions
+52
-24
src/diffusers/models/auto_model.py
src/diffusers/models/auto_model.py
+47
-24
src/diffusers/utils/dynamic_modules_utils.py
src/diffusers/utils/dynamic_modules_utils.py
+5
-0
No files found.
src/diffusers/models/auto_model.py
View file @
80de641c
...
@@ -19,6 +19,7 @@ from huggingface_hub.utils import validate_hf_hub_args
...
@@ -19,6 +19,7 @@ from huggingface_hub.utils import validate_hf_hub_args
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..utils
import
logging
from
..utils
import
logging
from
..utils.dynamic_modules_utils
import
get_class_from_dynamic_module
,
resolve_trust_remote_code
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
...
@@ -114,6 +115,8 @@ class AutoModel(ConfigMixin):
...
@@ -114,6 +115,8 @@ class AutoModel(ConfigMixin):
disable_mmap ('bool', *optional*, defaults to 'False'):
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
trust_remote_cocde (`bool`, *optional*, defaults to `False`):
Whether to trust remote code
<Tip>
<Tip>
...
@@ -140,22 +143,22 @@ class AutoModel(ConfigMixin):
...
@@ -140,22 +143,22 @@ class AutoModel(ConfigMixin):
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
```
"""
"""
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
None
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
proxies
=
kwargs
.
pop
(
"proxies"
,
None
)
token
=
kwargs
.
pop
(
"token"
,
None
)
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
revision
=
kwargs
.
pop
(
"revision"
,
None
)
subfolder
=
kwargs
.
pop
(
"subfolder"
,
None
)
subfolder
=
kwargs
.
pop
(
"subfolder"
,
None
)
trust_remote_code
=
kwargs
.
pop
(
"trust_remote_code"
,
False
)
load_config_kwargs
=
{
"cache_dir"
:
cache_dir
,
hub_kwargs_names
=
[
"force_download"
:
force_download
,
"cache_dir"
,
"proxies"
:
proxies
,
"force_download"
,
"token"
:
token
,
"local_files_only"
,
"local_files_only"
:
local_files_only
,
"proxies"
,
"revision"
:
revision
,
"resume_download"
,
}
"revision"
,
"token"
,
]
hub_kwargs
=
{
name
:
kwargs
.
pop
(
name
,
None
)
for
name
in
hub_kwargs_names
}
# load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
load_config_kwargs
=
{
k
:
v
for
k
,
v
in
hub_kwargs
.
items
()
if
k
not
in
[
"subfolder"
,
"resume_download"
]}
library
=
None
library
=
None
orig_class_name
=
None
orig_class_name
=
None
...
@@ -189,15 +192,35 @@ class AutoModel(ConfigMixin):
...
@@ -189,15 +192,35 @@ class AutoModel(ConfigMixin):
else
:
else
:
raise
ValueError
(
f
"Couldn't find model associated with the config file at
{
pretrained_model_or_path
}
."
)
raise
ValueError
(
f
"Couldn't find model associated with the config file at
{
pretrained_model_or_path
}
."
)
from
..pipelines.pipeline_loading_utils
import
ALL_IMPORTABLE_CLASSES
,
get_class_obj_and_candidates
has_remote_code
=
"auto_map"
in
config
and
cls
.
__name__
in
config
[
"auto_map"
]
trust_remote_code
=
resolve_trust_remote_code
(
trust_remote_code
,
pretrained_model_or_path
,
has_remote_code
)
model_cls
,
_
=
get_class_obj_and_candidates
(
if
not
(
has_remote_code
and
trust_remote_code
):
library_name
=
library
,
raise
ValueError
(
class_name
=
orig_class_name
,
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
importable_classes
=
ALL_IMPORTABLE_CLASSES
,
)
pipelines
=
None
,
is_pipeline_module
=
False
,
if
has_remote_code
and
trust_remote_code
:
)
class_ref
=
config
[
"auto_map"
][
cls
.
__name__
]
module_file
,
class_name
=
class_ref
.
split
(
"."
)
module_file
=
module_file
+
".py"
model_cls
=
get_class_from_dynamic_module
(
pretrained_model_or_path
,
subfolder
=
subfolder
,
module_file
=
module_file
,
class_name
=
class_name
,
**
hub_kwargs
,
**
kwargs
,
)
else
:
from
..pipelines.pipeline_loading_utils
import
ALL_IMPORTABLE_CLASSES
,
get_class_obj_and_candidates
model_cls
,
_
=
get_class_obj_and_candidates
(
library_name
=
library
,
class_name
=
orig_class_name
,
importable_classes
=
ALL_IMPORTABLE_CLASSES
,
pipelines
=
None
,
is_pipeline_module
=
False
,
)
if
model_cls
is
None
:
if
model_cls
is
None
:
raise
ValueError
(
f
"AutoModel can't find a model linked to
{
orig_class_name
}
."
)
raise
ValueError
(
f
"AutoModel can't find a model linked to
{
orig_class_name
}
."
)
...
...
src/diffusers/utils/dynamic_modules_utils.py
View file @
80de641c
...
@@ -247,6 +247,7 @@ def find_pipeline_class(loaded_module):
...
@@ -247,6 +247,7 @@ def find_pipeline_class(loaded_module):
def
get_cached_module_file
(
def
get_cached_module_file
(
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
module_file
:
str
,
module_file
:
str
,
subfolder
:
Optional
[
str
]
=
None
,
cache_dir
:
Optional
[
Union
[
str
,
os
.
PathLike
]]
=
None
,
cache_dir
:
Optional
[
Union
[
str
,
os
.
PathLike
]]
=
None
,
force_download
:
bool
=
False
,
force_download
:
bool
=
False
,
proxies
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
proxies
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
...
@@ -353,6 +354,7 @@ def get_cached_module_file(
...
@@ -353,6 +354,7 @@ def get_cached_module_file(
resolved_module_file
=
hf_hub_download
(
resolved_module_file
=
hf_hub_download
(
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
module_file
,
module_file
,
subfolder
=
subfolder
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
force_download
=
force_download
,
proxies
=
proxies
,
proxies
=
proxies
,
...
@@ -410,6 +412,7 @@ def get_cached_module_file(
...
@@ -410,6 +412,7 @@ def get_cached_module_file(
get_cached_module_file
(
get_cached_module_file
(
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
f
"
{
module_needed
}
.py"
,
f
"
{
module_needed
}
.py"
,
subfolder
=
subfolder
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
force_download
=
force_download
,
proxies
=
proxies
,
proxies
=
proxies
,
...
@@ -424,6 +427,7 @@ def get_cached_module_file(
...
@@ -424,6 +427,7 @@ def get_cached_module_file(
def
get_class_from_dynamic_module
(
def
get_class_from_dynamic_module
(
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
module_file
:
str
,
module_file
:
str
,
subfolder
:
Optional
[
str
]
=
None
,
class_name
:
Optional
[
str
]
=
None
,
class_name
:
Optional
[
str
]
=
None
,
cache_dir
:
Optional
[
Union
[
str
,
os
.
PathLike
]]
=
None
,
cache_dir
:
Optional
[
Union
[
str
,
os
.
PathLike
]]
=
None
,
force_download
:
bool
=
False
,
force_download
:
bool
=
False
,
...
@@ -497,6 +501,7 @@ def get_class_from_dynamic_module(
...
@@ -497,6 +501,7 @@ def get_class_from_dynamic_module(
final_module
=
get_cached_module_file
(
final_module
=
get_cached_module_file
(
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
module_file
,
module_file
,
subfolder
=
subfolder
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
force_download
=
force_download
,
proxies
=
proxies
,
proxies
=
proxies
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment