renzhc / diffusers_dcu · Commits

Commit 07b6d0e7, authored Jun 07, 2022 by Patrick von Platen

rename modeling code

parent 80b86587
Showing 9 changed files with 41 additions and 41 deletions (+41 -41)
src/diffusers/__init__.py                  +1  -1
src/diffusers/configuration_utils.py       +4  -4
src/diffusers/modeling_utils.py            +13 -13
src/diffusers/models/unet.py               +3  -3
src/diffusers/pipeline_utils.py            +4  -4
src/diffusers/schedulers/gaussian_ddpm.py  +2  -2
utils/check_config_docstrings.py           +7  -7
utils/check_repo.py                        +5  -5
utils/check_table.py                       +2  -2
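The diffs below are a mechanical, repository-wide rename: PreTrainedModel becomes ModelMixin and Config becomes ConfigMixin. A minimal sketch of how such a rename could be scripted (an assumption; the commit itself ships no script, and a bare textual replace also rewrites prose, which is why strings like "ConfigMixinuration" appear in docstrings below):

# Hypothetical rename script, NOT part of this commit.
import pathlib

for path in [*pathlib.Path("src").rglob("*.py"), *pathlib.Path("utils").rglob("*.py")]:
    text = path.read_text()
    # Order matters: replace the longer identifier first.
    text = text.replace("PreTrainedModel", "ModelMixin")
    text = text.replace("Config", "ConfigMixin")
    path.write_text(text)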
src/diffusers/__init__.py

@@ -4,7 +4,7 @@
 __version__ = "0.0.1"

-from .modeling_utils import PreTrainedModel
+from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
 from .pipeline_utils import DiffusionPipeline
 from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
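After this change the renamed base class is exported from the package root. A minimal usage sketch, assuming only the public names visible in this diff:

# The old import `from diffusers import PreTrainedModel` no longer works.
from diffusers import ModelMixin, UNetModel

# UNetModel inherits ModelMixin (see src/diffusers/models/unet.py below),
# so it picks up the shared save/load machinery.
assert issubclass(UNetModel, ModelMixin)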
src/diffusers/configuration_utils.py

@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Configuration base class and utilities."""
+""" ConfigMixinuration base class and utilities."""

 import copy
@@ -44,7 +44,7 @@ logger = logging.get_logger(__name__)
 _re_configuration_file = re.compile(r"config\.(.*)\.json")


-class Config:
+class ConfigMixin:
     r"""
     Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
     methods for loading/downloading/saving configurations.
@@ -71,7 +71,7 @@ class Config:
     def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
         """
         Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
-        [`~Config.from_config`] class method.
+        [`~ConfigMixin.from_config`] class method.

         Args:
             save_directory (`str` or `os.PathLike`):
@@ -88,7 +88,7 @@ class Config:
         output_config_file = os.path.join(save_directory, self.config_name)

         self.to_json_file(output_config_file)
-        logger.info(f"Configuration saved in {output_config_file}")
+        logger.info(f"ConfigMixinuration saved in {output_config_file}")

     @classmethod
     def get_config_dict(
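The docstring above promises a save_config / from_config round trip. A hedged sketch of that flow (the directory name is made up, and from_config accepting a saved directory follows the docstring rather than code shown in this hunk):

from diffusers import GaussianDDPMScheduler  # a ConfigMixin subclass, see gaussian_ddpm.py below

scheduler = GaussianDDPMScheduler()
scheduler.save_config("./my_scheduler")  # writes ./my_scheduler/scheduler_config.json
restored = GaussianDDPMScheduler.from_config("./my_scheduler")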
src/diffusers/modeling_utils.py

@@ -122,11 +122,11 @@ def _load_state_dict_into_model(model_to_load, state_dict):
     return error_msgs


-class PreTrainedModel(torch.nn.Module):
+class ModelMixin(torch.nn.Module):
     r"""
     Base class for all models.

-    [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
+    [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading,
     downloading and saving models as well as a few methods common to all models to:

     - resize the input embeddings,
@@ -134,13 +134,13 @@ class PreTrainedModel(torch.nn.Module):
     Class attributes (overridden by derived classes):

-    - **config_class** ([`Config`]) -- A subclass of [`Config`] to use as configuration class
+    - **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class
       for this model architecture.
     - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
       taking as arguments:

-        - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
-        - **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
+        - **model** ([`ModelMixin`]) -- An instance of the model on which to load the TensorFlow checkpoint.
+        - **config** ([`PreTrainedConfigMixin`]) -- An instance of the configuration associated to the model.
         - **path** (`str`) -- A path to the TensorFlow checkpoint.

     - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
@@ -163,7 +163,7 @@ class PreTrainedModel(torch.nn.Module):
     ):
         """
         Save a model and its configuration file to a directory, so that it can be re-loaded using the
-        `[`~PreTrainedModel.from_pretrained`]` class method.
+        `[`~ModelMixin.from_pretrained`]` class method.

         Arguments:
             save_directory (`str` or `os.PathLike`):
@@ -231,20 +231,20 @@ class PreTrainedModel(torch.nn.Module):
                     Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                     user or organization name, like `dbmdz/bert-base-german-cased`.
                 - A path to a *directory* containing model weights saved using
-                  [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+                  [`~ModelMixin.save_pretrained`], e.g., `./my_model_directory/`.

-            config (`Union[Config, str, os.PathLike]`, *optional*):
+            config (`Union[ConfigMixin, str, os.PathLike]`, *optional*):
                 Can be either:

-                - an instance of a class derived from [`Config`],
-                - a string or path valid as input to [`~Config.from_pretrained`].
+                - an instance of a class derived from [`ConfigMixin`],
+                - a string or path valid as input to [`~ConfigMixin.from_pretrained`].

-                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
+                ConfigMixinuration for the model to use instead of an automatically loaded configuration. ConfigMixinuration can
                 be automatically loaded when:

                 - The model is a model provided by the library (loaded with the *model id* string of a pretrained
                   model).
-                - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
+                - The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the
                   save directory.
                 - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
                   configuration JSON file named *config.json* is found in the directory.
@@ -295,7 +295,7 @@ class PreTrainedModel(torch.nn.Module):
                   underlying model's `__init__` method (we assume all relevant updates to the configuration have
                   already been done)
                 - If a configuration is not provided, `kwargs` will be first passed to the configuration class
-                  initialization function ([`~Config.from_pretrained`]). Each key of `kwargs` that
+                  initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that
                   corresponds to a configuration attribute will be used to override said attribute with the
                   supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
                   will be passed to the underlying model's `__init__` function.
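The from_pretrained docstring above describes loading either a hub model id or a local directory saved with save_pretrained. A minimal local round trip, sketched under that contract (no-argument construction assumes all UNetModel parameters have defaults, as ch=128 in the next file suggests):

from diffusers import UNetModel

model = UNetModel()                              # default configuration
model.save_pretrained("./my_model_directory")    # weights plus config.json
model = UNetModel.from_pretrained("./my_model_directory")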
src/diffusers/models/unet.py

@@ -29,8 +29,8 @@ from torchvision import transforms, utils
 from PIL import Image
 from tqdm import tqdm

-from ..configuration_utils import Config
-from ..modeling_utils import PreTrainedModel
+from ..configuration_utils import ConfigMixin
+from ..modeling_utils import ModelMixin


 def get_timestep_embedding(timesteps, embedding_dim):
@@ -175,7 +175,7 @@ class AttnBlock(nn.Module):
         return x + h_


-class UNetModel(PreTrainedModel, Config):
+class UNetModel(ModelMixin, ConfigMixin):
     def __init__(
         self,
         ch=128,
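get_timestep_embedding, unchanged in the context above, maps integer diffusion timesteps to sinusoidal feature vectors. A usage sketch (the [N, embedding_dim] output shape is the conventional behavior for this function, assumed here rather than shown in the hunk):

import torch
from diffusers.models.unet import get_timestep_embedding

t = torch.arange(8)                                # 8 example timesteps
emb = get_timestep_embedding(t, embedding_dim=32)  # sinusoidal features per timestep
print(emb.shape)                                   # expected: torch.Size([8, 32])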
src/diffusers/pipeline_utils.py

@@ -22,7 +22,7 @@ from huggingface_hub import snapshot_download
 # CHANGE to diffusers.utils
 from transformers.utils import logging

-from .configuration_utils import Config
+from .configuration_utils import ConfigMixin


 INDEX_FILE = "diffusion_model.pt"
@@ -33,16 +33,16 @@ logger = logging.get_logger(__name__)
 LOADABLE_CLASSES = {
     "diffusers": {
-        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
+        "ModelMixin": ["save_pretrained", "from_pretrained"],
         "GaussianDDPMScheduler": ["save_config", "from_config"],
     },
     "transformers": {
-        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
+        "ModelMixin": ["save_pretrained", "from_pretrained"],
     },
 }


-class DiffusionPipeline(Config):
+class DiffusionPipeline(ConfigMixin):

     config_name = "model_index.json"
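LOADABLE_CLASSES maps each library's base class to the (save, load) method pair a pipeline should call on components of that type. A sketch of the intended lookup (a hypothetical helper; the pipeline's actual traversal code is not in this hunk):

import importlib

def find_save_method(obj, loadable_classes):
    # Walk library -> {class name: [save method, load method]} and return the
    # bound save method for the first base class that matches obj.
    for library_name, classes in loadable_classes.items():
        library = importlib.import_module(library_name)
        for class_name, (save_name, _load_name) in classes.items():
            base = getattr(library, class_name, None)
            if base is not None and isinstance(obj, base):
                return getattr(obj, save_name)
    return None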
src/diffusers/schedulers/gaussian_ddpm.py

@@ -14,7 +14,7 @@
 import torch
 from torch import nn

-from ..configuration_utils import Config
+from ..configuration_utils import ConfigMixin


 SAMPLING_CONFIG_NAME = "scheduler_config.json"
@@ -24,7 +24,7 @@ def linear_beta_schedule(timesteps, beta_start, beta_end):
     return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


-class GaussianDDPMScheduler(nn.Module, Config):
+class GaussianDDPMScheduler(nn.Module, ConfigMixin):

     config_name = SAMPLING_CONFIG_NAME
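linear_beta_schedule, shown above, is a plain torch.linspace over the beta range. A quick numeric check (the 1000 / 1e-4 / 0.02 values are the usual DDPM defaults, assumed here, not taken from this diff):

import torch

def linear_beta_schedule(timesteps, beta_start, beta_end):
    # Same body as in the hunk above.
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)

betas = linear_beta_schedule(1000, 1e-4, 0.02)
print(float(betas[0]), float(betas[-1]))  # 0.0001 0.02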
utils/check_config_docstrings.py

@@ -40,13 +40,13 @@ _re_checkpoint = re.compile("\[(.+?)\]\((https://huggingface\.co/.+?)\)")
 CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
-    "CLIPConfig",
+    "CLIPConfigMixin",
-    "DecisionTransformerConfig",
+    "DecisionTransformerConfigMixin",
-    "EncoderDecoderConfig",
+    "EncoderDecoderConfigMixin",
-    "RagConfig",
+    "RagConfigMixin",
-    "SpeechEncoderDecoderConfig",
+    "SpeechEncoderDecoderConfigMixin",
-    "VisionEncoderDecoderConfig",
+    "VisionEncoderDecoderConfigMixin",
-    "VisionTextDualEncoderConfig",
+    "VisionTextDualEncoderConfigMixin",
 }
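The _re_checkpoint pattern in the hunk header extracts (name, URL) pairs from markdown checkpoint links in config docstrings. For example:

import re

_re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")
m = _re_checkpoint.search("[bert-base-uncased](https://huggingface.co/bert-base-uncased)")
print(m.groups())  # ('bert-base-uncased', 'https://huggingface.co/bert-base-uncased')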
utils/check_repo.py

@@ -87,7 +87,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
     "ReformerForMaskedLM",  # Needs to be setup as decoder.
     "Speech2Text2DecoderWrapper",  # Building part of bigger (tested) model.
     "TFDPREncoder",  # Building part of bigger (tested) model.
-    "TFElectraMainLayer",  # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
+    "TFElectraMainLayer",  # Building part of bigger (tested) model (should it be a TFModelMixin ?)
     "TFRobertaForMultipleChoice",  # TODO: fix
     "TrOCRDecoderWrapper",  # Building part of bigger (tested) model.
     "SeparableConv1D",  # Building part of bigger (tested) model.
@@ -271,7 +271,7 @@ def get_model_modules():
 def get_models(module, include_pretrained=False):
     """Get the objects in module that are models."""
     models = []
-    model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
+    model_classes = (transformers.ModelMixin, transformers.TFModelMixin, transformers.FlaxModelMixin)
     for attr_name in dir(module):
         if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
             continue
@@ -372,7 +372,7 @@ def find_tested_models(test_file):
 def check_models_are_tested(module, test_file):
     """Check models defined in module are tested in test_file."""
-    # XxxPreTrainedModel are not tested
+    # XxxModelMixin are not tested
     defined_models = get_models(module)
     tested_models = find_tested_models(test_file)
     if tested_models is None:
@@ -625,9 +625,9 @@ def ignore_undocumented(name):
     # Constants uppercase are not documented.
     if name.isupper():
         return True
-    # PreTrainedModels / Encoders / Decoders / Layers / Embeddings / Attention are not documented.
+    # ModelMixins / Encoders / Decoders / Layers / Embeddings / Attention are not documented.
     if (
-        name.endswith("PreTrainedModel")
+        name.endswith("ModelMixin")
         or name.endswith("Decoder")
         or name.endswith("Encoder")
         or name.endswith("Layer")
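get_models filters a module's attributes down to model classes via the model_classes tuple above, skipping *Pretrained* names unless asked. A hedged usage sketch (assuming, as in transformers' own check_repo.py, that it returns (name, class) pairs):

import transformers.models.bert.modeling_bert as bert_module

for name, cls in get_models(bert_module):
    print(name)  # e.g. BertForMaskedLM, BertForSequenceClassification, ...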
utils/check_table.py

@@ -94,7 +94,7 @@ def get_model_table_from_auto_modules():
         for code, name in transformers_module.MODEL_NAMES_MAPPING.items()
         if code in config_maping_names
     }
-    model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
+    model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()}

     # Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
     slow_tokenizers = collections.defaultdict(bool)
@@ -190,7 +190,7 @@ def has_onnx(model_type):
     for part in config_module.split(".")[1:]:
         module = getattr(module, part)
     config_name = config.__name__
-    onnx_config_name = config_name.replace("Config", "OnnxConfig")
+    onnx_config_name = config_name.replace("ConfigMixin", "OnnxConfigMixin")
     return hasattr(module, onnx_config_name)
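What the renamed str.replace calls compute, on a made-up class name (illustration only; whether entries in MODEL_NAMES_MAPPING actually end in "ConfigMixin" is not shown in this diff):

config_name = "BertConfigMixin"                               # hypothetical
prefix = config_name.replace("ConfigMixin", "")               # -> "Bert"
onnx = config_name.replace("ConfigMixin", "OnnxConfigMixin")  # -> "BertOnnxConfigMixin"
print(prefix, onnx)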