Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
27039cd3
Commit
27039cd3
authored
Jun 09, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
parents
8841d0d1
1f4d817c
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1430 additions
and
1437 deletions
+1430
-1437
models/vision/glide/convert_weights.py
models/vision/glide/convert_weights.py
+1
-2
models/vision/glide/modeling_glide.py
models/vision/glide/modeling_glide.py
+679
-5
models/vision/glide/run_glide.py
models/vision/glide/run_glide.py
+2
-1
models/vision/latent_diffusion/modeling_latent_diffusion.py
models/vision/latent_diffusion/modeling_latent_diffusion.py
+721
-0
src/diffusers/__init__.py
src/diffusers/__init__.py
+0
-2
src/diffusers/models/__init__.py
src/diffusers/models/__init__.py
+0
-2
src/diffusers/models/clip_text_transformer.py
src/diffusers/models/clip_text_transformer.py
+0
-685
src/diffusers/models/vqvae.py
src/diffusers/models/vqvae.py
+0
-721
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+27
-19
No files found.
models/vision/glide/convert_weights.py
View file @
27039cd3
...
...
@@ -3,12 +3,11 @@ from torch import nn
from
diffusers
import
(
ClassifierFreeGuidanceScheduler
,
CLIPTextModel
,
GlideDDIMScheduler
,
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
,
)
from
modeling_glide
import
GLIDE
from
modeling_glide
import
GLIDE
,
CLIPTextModel
from
transformers
import
CLIPTextConfig
,
GPT2Tokenizer
...
...
models/vision/glide/modeling_glide.py
View file @
27039cd3
# Copyright 2022 The HuggingFace Team. All rights reserved.
# coding=utf-8
# Copyright 2022 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -10,23 +11,696 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch CLIP model."""
import
math
from
dataclasses
import
dataclass
from
typing
import
Any
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
torch.utils.checkpoint
from
torch
import
nn
import
tqdm
from
diffusers
import
(
ClassifierFreeGuidanceScheduler
,
CLIPTextModel
,
DiffusionPipeline
,
GlideDDIMScheduler
,
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
,
)
from
transformers
import
GPT2Tokenizer
from
transformers
import
CLIPConfig
,
CLIPModel
,
CLIPTextConfig
,
CLIPVisionConfig
,
GPT2Tokenizer
from
transformers.activations
import
ACT2FN
from
transformers.modeling_outputs
import
BaseModelOutput
,
BaseModelOutputWithPooling
from
transformers.modeling_utils
import
PreTrainedModel
from
transformers.utils
import
(
ModelOutput
,
add_start_docstrings
,
add_start_docstrings_to_model_forward
,
logging
,
replace_return_docstrings
,
)
#####################
# START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
#####################
logger
=
logging
.
get_logger
(
__name__
)
_CHECKPOINT_FOR_DOC
=
"fusing/glide-base"
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST
=
[
"fusing/glide-base"
,
# See all CLIP models at https://huggingface.co/models?filter=clip
]
# Copied from transformers.models.bart.modeling_bart._expand_mask
def
_expand_mask
(
mask
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
tgt_len
:
Optional
[
int
]
=
None
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz
,
src_len
=
mask
.
size
()
tgt_len
=
tgt_len
if
tgt_len
is
not
None
else
src_len
expanded_mask
=
mask
[:,
None
,
None
,
:].
expand
(
bsz
,
1
,
tgt_len
,
src_len
).
to
(
dtype
)
inverted_mask
=
1.0
-
expanded_mask
return
inverted_mask
.
masked_fill
(
inverted_mask
.
to
(
torch
.
bool
),
torch
.
finfo
(
dtype
).
min
)
# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def
contrastive_loss
(
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
nn
.
functional
.
cross_entropy
(
logits
,
torch
.
arange
(
len
(
logits
),
device
=
logits
.
device
))
def
clip_loss
(
similarity
:
torch
.
Tensor
)
->
torch
.
Tensor
:
caption_loss
=
contrastive_loss
(
similarity
)
image_loss
=
contrastive_loss
(
similarity
.
T
)
return
(
caption_loss
+
image_loss
)
/
2.0
@
dataclass
class
CLIPOutput
(
ModelOutput
):
"""
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
Contrastive loss for image-text similarity.
logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
similarity scores.
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
similarity scores.
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
text_model_output(`BaseModelOutputWithPooling`):
The output of the [`CLIPTextModel`].
vision_model_output(`BaseModelOutputWithPooling`):
The output of the [`CLIPVisionModel`].
"""
loss
:
Optional
[
torch
.
FloatTensor
]
=
None
logits_per_image
:
torch
.
FloatTensor
=
None
logits_per_text
:
torch
.
FloatTensor
=
None
text_embeds
:
torch
.
FloatTensor
=
None
image_embeds
:
torch
.
FloatTensor
=
None
text_model_output
:
BaseModelOutputWithPooling
=
None
vision_model_output
:
BaseModelOutputWithPooling
=
None
def
to_tuple
(
self
)
->
Tuple
[
Any
]:
return
tuple
(
self
[
k
]
if
k
not
in
[
"text_model_output"
,
"vision_model_output"
]
else
getattr
(
self
,
k
).
to_tuple
()
for
k
in
self
.
keys
()
)
class
CLIPVisionEmbeddings
(
nn
.
Module
):
def
__init__
(
self
,
config
:
CLIPVisionConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
image_size
=
config
.
image_size
self
.
patch_size
=
config
.
patch_size
self
.
class_embedding
=
nn
.
Parameter
(
torch
.
randn
(
self
.
embed_dim
))
self
.
patch_embedding
=
nn
.
Conv2d
(
in_channels
=
3
,
out_channels
=
self
.
embed_dim
,
kernel_size
=
self
.
patch_size
,
stride
=
self
.
patch_size
,
bias
=
False
)
self
.
num_patches
=
(
self
.
image_size
//
self
.
patch_size
)
**
2
self
.
num_positions
=
self
.
num_patches
+
1
self
.
position_embedding
=
nn
.
Embedding
(
self
.
num_positions
,
self
.
embed_dim
)
self
.
register_buffer
(
"position_ids"
,
torch
.
arange
(
self
.
num_positions
).
expand
((
1
,
-
1
)))
def
forward
(
self
,
pixel_values
:
torch
.
FloatTensor
)
->
torch
.
Tensor
:
batch_size
=
pixel_values
.
shape
[
0
]
patch_embeds
=
self
.
patch_embedding
(
pixel_values
)
# shape = [*, width, grid, grid]
patch_embeds
=
patch_embeds
.
flatten
(
2
).
transpose
(
1
,
2
)
class_embeds
=
self
.
class_embedding
.
expand
(
batch_size
,
1
,
-
1
)
embeddings
=
torch
.
cat
([
class_embeds
,
patch_embeds
],
dim
=
1
)
embeddings
=
embeddings
+
self
.
position_embedding
(
self
.
position_ids
)
return
embeddings
class
CLIPTextEmbeddings
(
nn
.
Module
):
def
__init__
(
self
,
config
:
CLIPTextConfig
):
super
().
__init__
()
embed_dim
=
config
.
hidden_size
self
.
token_embedding
=
nn
.
Embedding
(
config
.
vocab_size
,
embed_dim
)
self
.
position_embedding
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
embed_dim
)
self
.
use_padding_embeddings
=
config
.
use_padding_embeddings
if
self
.
use_padding_embeddings
:
self
.
padding_embedding
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
embed_dim
)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self
.
register_buffer
(
"position_ids"
,
torch
.
arange
(
config
.
max_position_embeddings
).
expand
((
1
,
-
1
)))
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
seq_length
=
input_ids
.
shape
[
-
1
]
if
input_ids
is
not
None
else
inputs_embeds
.
shape
[
-
2
]
if
position_ids
is
None
:
position_ids
=
self
.
position_ids
[:,
:
seq_length
]
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
token_embedding
(
input_ids
)
position_embeddings
=
self
.
position_embedding
(
position_ids
)
embeddings
=
inputs_embeds
+
position_embeddings
if
self
.
use_padding_embeddings
and
attention_mask
is
not
None
:
padding_embeddings
=
self
.
padding_embedding
(
position_ids
)
embeddings
=
torch
.
where
(
attention_mask
.
bool
().
unsqueeze
(
-
1
),
embeddings
,
padding_embeddings
)
return
embeddings
class
CLIPAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
self
,
config
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
head_dim
=
self
.
embed_dim
//
self
.
num_heads
if
self
.
head_dim
*
self
.
num_heads
!=
self
.
embed_dim
:
raise
ValueError
(
f
"embed_dim must be divisible by num_heads (got `embed_dim`:
{
self
.
embed_dim
}
and `num_heads`:"
f
"
{
self
.
num_heads
}
)."
)
self
.
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
self
.
head_dim
))
self
.
qkv_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
embed_dim
*
3
)
self
.
out_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
embed_dim
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
causal_attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
"""Input shape: Batch x Time x Channel"""
bsz
,
tgt_len
,
embed_dim
=
hidden_states
.
size
()
qkv_states
=
self
.
qkv_proj
(
hidden_states
)
qkv_states
=
qkv_states
.
view
(
bsz
,
tgt_len
,
self
.
num_heads
,
-
1
)
query_states
,
key_states
,
value_states
=
torch
.
split
(
qkv_states
,
self
.
head_dim
,
dim
=-
1
)
attn_weights
=
torch
.
einsum
(
"bthc,bshc->bhts"
,
query_states
*
self
.
scale
,
key_states
*
self
.
scale
)
wdtype
=
attn_weights
.
dtype
attn_weights
=
nn
.
functional
.
softmax
(
attn_weights
.
float
(),
dim
=-
1
).
type
(
wdtype
)
attn_output
=
torch
.
einsum
(
"bhts,bshc->bthc"
,
attn_weights
,
value_states
)
attn_output
=
attn_output
.
reshape
(
bsz
,
tgt_len
,
-
1
)
attn_output
=
self
.
out_proj
(
attn_output
)
return
attn_output
,
attn_weights
class
CLIPMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
().
__init__
()
self
.
config
=
config
self
.
activation_fn
=
ACT2FN
[
config
.
hidden_act
]
self
.
fc1
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
intermediate_size
)
self
.
fc2
=
nn
.
Linear
(
config
.
intermediate_size
,
config
.
hidden_size
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
=
self
.
fc1
(
hidden_states
)
hidden_states
=
self
.
activation_fn
(
hidden_states
)
hidden_states
=
self
.
fc2
(
hidden_states
)
return
hidden_states
class
CLIPEncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
CLIPConfig
):
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
self
.
self_attn
=
CLIPAttention
(
config
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
self
.
embed_dim
)
self
.
mlp
=
CLIPMLP
(
config
)
self
.
layer_norm2
=
nn
.
LayerNorm
(
self
.
embed_dim
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
torch
.
Tensor
,
causal_attention_mask
:
torch
.
Tensor
,
output_attentions
:
Optional
[
bool
]
=
False
,
)
->
Tuple
[
torch
.
FloatTensor
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(config.encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual
=
hidden_states
hidden_states
=
self
.
layer_norm1
(
hidden_states
)
hidden_states
,
attn_weights
=
self
.
self_attn
(
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
causal_attention_mask
=
causal_attention_mask
,
output_attentions
=
output_attentions
,
)
hidden_states
=
residual
+
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
layer_norm2
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
outputs
=
(
hidden_states
,)
if
output_attentions
:
outputs
+=
(
attn_weights
,)
return
outputs
class
CLIPPreTrainedModel
(
PreTrainedModel
):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class
=
CLIPConfig
base_model_prefix
=
"clip"
supports_gradient_checkpointing
=
True
_keys_to_ignore_on_load_missing
=
[
r
"position_ids"
]
def
_init_weights
(
self
,
module
):
"""Initialize the weights"""
factor
=
self
.
config
.
initializer_factor
if
isinstance
(
module
,
CLIPTextEmbeddings
):
module
.
token_embedding
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
factor
*
0.02
)
module
.
position_embedding
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
factor
*
0.02
)
if
hasattr
(
module
,
"padding_embedding"
):
module
.
padding_embedding
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
factor
*
0.02
)
elif
isinstance
(
module
,
CLIPVisionEmbeddings
):
factor
=
self
.
config
.
initializer_factor
nn
.
init
.
normal_
(
module
.
class_embedding
,
mean
=
0.0
,
std
=
module
.
embed_dim
**-
0.5
*
factor
)
nn
.
init
.
normal_
(
module
.
patch_embedding
.
weight
,
std
=
module
.
config
.
initializer_range
*
factor
)
nn
.
init
.
normal_
(
module
.
position_embedding
.
weight
,
std
=
module
.
config
.
initializer_range
*
factor
)
elif
isinstance
(
module
,
CLIPAttention
):
factor
=
self
.
config
.
initializer_factor
in_proj_std
=
(
module
.
embed_dim
**-
0.5
)
*
((
2
*
module
.
config
.
num_hidden_layers
)
**
-
0.5
)
*
factor
out_proj_std
=
(
module
.
embed_dim
**-
0.5
)
*
factor
nn
.
init
.
normal_
(
module
.
qkv_proj
.
weight
,
std
=
in_proj_std
)
nn
.
init
.
normal_
(
module
.
out_proj
.
weight
,
std
=
out_proj_std
)
elif
isinstance
(
module
,
CLIPMLP
):
factor
=
self
.
config
.
initializer_factor
in_proj_std
=
(
(
module
.
config
.
hidden_size
**-
0.5
)
*
((
2
*
module
.
config
.
num_hidden_layers
)
**
-
0.5
)
*
factor
)
fc_std
=
(
2
*
module
.
config
.
hidden_size
)
**
-
0.5
*
factor
nn
.
init
.
normal_
(
module
.
fc1
.
weight
,
std
=
fc_std
)
nn
.
init
.
normal_
(
module
.
fc2
.
weight
,
std
=
in_proj_std
)
elif
isinstance
(
module
,
CLIPModel
):
nn
.
init
.
normal_
(
module
.
text_projection
.
weight
,
std
=
module
.
text_embed_dim
**-
0.5
*
self
.
config
.
initializer_factor
,
)
nn
.
init
.
normal_
(
module
.
visual_projection
.
weight
,
std
=
module
.
vision_embed_dim
**-
0.5
*
self
.
config
.
initializer_factor
,
)
if
isinstance
(
module
,
nn
.
LayerNorm
):
module
.
bias
.
data
.
zero_
()
module
.
weight
.
data
.
fill_
(
1.0
)
if
isinstance
(
module
,
nn
.
Linear
)
and
module
.
bias
is
not
None
:
module
.
bias
.
data
.
zero_
()
def
_set_gradient_checkpointing
(
self
,
module
,
value
=
False
):
if
isinstance
(
module
,
CLIPEncoder
):
module
.
gradient_checkpointing
=
value
CLIP_START_DOCSTRING
=
r
"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
CLIP_TEXT_INPUTS_DOCSTRING
=
r
"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
CLIP_VISION_INPUTS_DOCSTRING
=
r
"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
CLIP_INPUTS_DOCSTRING
=
r
"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class
CLIPEncoder
(
nn
.
Module
):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`CLIPEncoderLayer`].
Args:
config: CLIPConfig
"""
def
__init__
(
self
,
config
:
CLIPConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
layers
=
nn
.
ModuleList
([
CLIPEncoderLayer
(
config
)
for
_
in
range
(
config
.
num_hidden_layers
)])
self
.
gradient_checkpointing
=
False
def
forward
(
self
,
inputs_embeds
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
causal_attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
Tuple
,
BaseModelOutput
]:
r
"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Causal mask for the text model. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions
=
output_attentions
if
output_attentions
is
not
None
else
self
.
config
.
output_attentions
output_hidden_states
=
(
output_hidden_states
if
output_hidden_states
is
not
None
else
self
.
config
.
output_hidden_states
)
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
encoder_states
=
()
if
output_hidden_states
else
None
all_attentions
=
()
if
output_attentions
else
None
hidden_states
=
inputs_embeds
for
idx
,
encoder_layer
in
enumerate
(
self
.
layers
):
if
output_hidden_states
:
encoder_states
=
encoder_states
+
(
hidden_states
,)
if
self
.
gradient_checkpointing
and
self
.
training
:
def
create_custom_forward
(
module
):
def
custom_forward
(
*
inputs
):
return
module
(
*
inputs
,
output_attentions
)
return
custom_forward
layer_outputs
=
torch
.
utils
.
checkpoint
.
checkpoint
(
create_custom_forward
(
encoder_layer
),
hidden_states
,
attention_mask
,
causal_attention_mask
,
)
else
:
layer_outputs
=
encoder_layer
(
hidden_states
,
attention_mask
,
causal_attention_mask
,
output_attentions
=
output_attentions
,
)
hidden_states
=
layer_outputs
[
0
]
if
output_attentions
:
all_attentions
=
all_attentions
+
(
layer_outputs
[
1
],)
if
output_hidden_states
:
encoder_states
=
encoder_states
+
(
hidden_states
,)
if
not
return_dict
:
return
tuple
(
v
for
v
in
[
hidden_states
,
encoder_states
,
all_attentions
]
if
v
is
not
None
)
return
BaseModelOutput
(
last_hidden_state
=
hidden_states
,
hidden_states
=
encoder_states
,
attentions
=
all_attentions
)
class
CLIPTextTransformer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
CLIPTextConfig
):
super
().
__init__
()
self
.
config
=
config
embed_dim
=
config
.
hidden_size
self
.
embeddings
=
CLIPTextEmbeddings
(
config
)
self
.
encoder
=
CLIPEncoder
(
config
)
self
.
final_layer_norm
=
nn
.
LayerNorm
(
embed_dim
)
@
add_start_docstrings_to_model_forward
(
CLIP_TEXT_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
BaseModelOutputWithPooling
,
config_class
=
CLIPTextConfig
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
Tuple
,
BaseModelOutputWithPooling
]:
r
"""
Returns:
"""
output_attentions
=
output_attentions
if
output_attentions
is
not
None
else
self
.
config
.
output_attentions
output_hidden_states
=
(
output_hidden_states
if
output_hidden_states
is
not
None
else
self
.
config
.
output_hidden_states
)
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
if
input_ids
is
None
:
raise
ValueError
(
"You have to specify either input_ids"
)
input_shape
=
input_ids
.
size
()
input_ids
=
input_ids
.
view
(
-
1
,
input_shape
[
-
1
])
hidden_states
=
self
.
embeddings
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
attention_mask
=
attention_mask
)
bsz
,
seq_len
=
input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask
=
self
.
_build_causal_attention_mask
(
bsz
,
seq_len
).
to
(
hidden_states
.
device
)
# expand attention_mask
if
attention_mask
is
not
None
:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask
=
_expand_mask
(
attention_mask
,
hidden_states
.
dtype
)
encoder_outputs
=
self
.
encoder
(
inputs_embeds
=
hidden_states
,
attention_mask
=
None
,
causal_attention_mask
=
None
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
)
last_hidden_state
=
encoder_outputs
[
0
]
last_hidden_state
=
self
.
final_layer_norm
(
last_hidden_state
)
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
pooled_output
=
last_hidden_state
[
torch
.
arange
(
last_hidden_state
.
shape
[
0
]),
input_ids
.
argmax
(
dim
=-
1
)]
if
not
return_dict
:
return
(
last_hidden_state
,
pooled_output
)
+
encoder_outputs
[
1
:]
return
BaseModelOutputWithPooling
(
last_hidden_state
=
last_hidden_state
,
pooler_output
=
pooled_output
,
hidden_states
=
encoder_outputs
.
hidden_states
,
attentions
=
encoder_outputs
.
attentions
,
)
def
_build_causal_attention_mask
(
self
,
bsz
,
seq_len
):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask
=
torch
.
empty
(
bsz
,
seq_len
,
seq_len
)
mask
.
fill_
(
torch
.
tensor
(
float
(
"-inf"
)))
mask
.
triu_
(
1
)
# zero out the lower diagonal
mask
=
mask
.
unsqueeze
(
1
)
# expand mask
return
mask
class
CLIPTextModel
(
CLIPPreTrainedModel
):
config_class
=
CLIPTextConfig
def
__init__
(
self
,
config
:
CLIPTextConfig
):
super
().
__init__
(
config
)
self
.
text_model
=
CLIPTextTransformer
(
config
)
# Initialize weights and apply final processing
self
.
post_init
()
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
return
self
.
text_model
.
embeddings
.
token_embedding
def
set_input_embeddings
(
self
,
value
):
self
.
text_model
.
embeddings
.
token_embedding
=
value
@
add_start_docstrings_to_model_forward
(
CLIP_TEXT_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
BaseModelOutputWithPooling
,
config_class
=
CLIPTextConfig
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
Tuple
,
BaseModelOutputWithPooling
]:
r
"""
Returns:
Examples:
```python
>>> from transformers import CLIPTokenizer, CLIPTextModel
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
>>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
```"""
return
self
.
text_model
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
position_ids
=
position_ids
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
)
#####################
# END OF THE CLIP MODEL COPY-PASTE
#####################
def
_extract_into_tensor
(
arr
,
timesteps
,
broadcast_shape
):
...
...
@@ -244,6 +918,6 @@ class GLIDE(DiffusionPipeline):
sampled_prev_image
=
prev_image
+
prev_variance
image
=
sampled_prev_image
image
=
image
[
0
]
.
permute
(
1
,
2
,
0
)
image
=
image
.
permute
(
0
,
2
,
3
,
1
)
return
image
models/vision/glide/run_glide.py
View file @
27039cd3
...
...
@@ -13,9 +13,10 @@ model_id = "fusing/glide-base"
pipeline
=
DiffusionPipeline
.
from_pretrained
(
model_id
)
# run inference (text-conditioned denoising + upscaling)
img
=
pipeline
(
"a c
lip art of a hugging face
"
,
generator
)
img
=
pipeline
(
"a c
rayon drawing of a corgi
"
,
generator
)
# process image to PIL
img
=
img
.
squeeze
(
0
)
img
=
((
img
+
1
)
*
127.5
).
round
().
clamp
(
0
,
255
).
to
(
torch
.
uint8
).
cpu
().
numpy
()
image_pil
=
PIL
.
Image
.
fromarray
(
img
)
...
...
models/vision/latent_diffusion/modeling_latent_diffusion.py
View file @
27039cd3
# pytorch_diffusion + derived encoder decoder
import
math
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.modeling_utils
import
ModelMixin
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert
len
(
timesteps
.
shape
)
==
1
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
return
emb
def
nonlinearity
(
x
):
# swish
return
x
*
torch
.
sigmoid
(
x
)
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
Upsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
x
=
torch
.
nn
.
functional
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
with_conv
:
x
=
self
.
conv
(
x
)
return
x
class
Downsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
# no asymmetric padding in torch conv, must do it ourselves
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
2
,
padding
=
0
)
def
forward
(
self
,
x
):
if
self
.
with_conv
:
pad
=
(
0
,
1
,
0
,
1
)
x
=
torch
.
nn
.
functional
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
x
=
self
.
conv
(
x
)
else
:
x
=
torch
.
nn
.
functional
.
avg_pool2d
(
x
,
kernel_size
=
2
,
stride
=
2
)
return
x
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
*
,
in_channels
,
out_channels
=
None
,
conv_shortcut
=
False
,
dropout
,
temb_channels
=
512
):
super
().
__init__
()
self
.
in_channels
=
in_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
use_conv_shortcut
=
conv_shortcut
self
.
norm1
=
Normalize
(
in_channels
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
temb_channels
>
0
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
self
.
conv_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
else
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
,
temb
):
h
=
x
h
=
self
.
norm1
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv1
(
h
)
if
temb
is
not
None
:
h
=
h
+
self
.
temb_proj
(
nonlinearity
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
norm2
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
x
=
self
.
conv_shortcut
(
x
)
else
:
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
class
AttnBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
norm
=
Normalize
(
in_channels
)
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
):
h_
=
x
h_
=
self
.
norm
(
h_
)
q
=
self
.
q
(
h_
)
k
=
self
.
k
(
h_
)
v
=
self
.
v
(
h_
)
# compute attention
b
,
c
,
h
,
w
=
q
.
shape
q
=
q
.
reshape
(
b
,
c
,
h
*
w
)
q
=
q
.
permute
(
0
,
2
,
1
)
# b,hw,c
k
=
k
.
reshape
(
b
,
c
,
h
*
w
)
# b,c,hw
w_
=
torch
.
bmm
(
q
,
k
)
# b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_
=
w_
*
(
int
(
c
)
**
(
-
0.5
))
w_
=
torch
.
nn
.
functional
.
softmax
(
w_
,
dim
=
2
)
# attend to values
v
=
v
.
reshape
(
b
,
c
,
h
*
w
)
w_
=
w_
.
permute
(
0
,
2
,
1
)
# b,hw,hw (first hw of k, second of q)
h_
=
torch
.
bmm
(
v
,
w_
)
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_
=
h_
.
reshape
(
b
,
c
,
h
,
w
)
h_
=
self
.
proj_out
(
h_
)
return
x
+
h_
class
Model
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
use_timestep
=
True
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
self
.
ch
*
4
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
use_timestep
=
use_timestep
if
self
.
use_timestep
:
# timestep embedding
self
.
temb
=
nn
.
Module
()
self
.
temb
.
dense
=
nn
.
ModuleList
(
[
torch
.
nn
.
Linear
(
self
.
ch
,
self
.
temb_ch
),
torch
.
nn
.
Linear
(
self
.
temb_ch
,
self
.
temb_ch
),
]
)
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
skip_in
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
if
i_block
==
self
.
num_res_blocks
:
skip_in
=
ch
*
in_ch_mult
[
i_level
]
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
+
skip_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
,
t
=
None
):
# assert x.shape[2] == x.shape[3] == self.resolution
if
self
.
use_timestep
:
# timestep embedding
assert
t
is
not
None
temb
=
get_timestep_embedding
(
t
,
self
.
ch
)
temb
=
self
.
temb
.
dense
[
0
](
temb
)
temb
=
nonlinearity
(
temb
)
temb
=
self
.
temb
.
dense
[
1
](
temb
)
else
:
temb
=
None
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
),
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
double_z
=
True
,
**
ignore_kwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
2
*
z_channels
if
double_z
else
z_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb
=
None
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
Decoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
give_pre_end
=
False
,
**
ignorekwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
give_pre_end
=
give_pre_end
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
block_in
=
ch
*
ch_mult
[
self
.
num_resolutions
-
1
]
curr_res
=
resolution
//
2
**
(
self
.
num_resolutions
-
1
)
self
.
z_shape
=
(
1
,
z_channels
,
curr_res
,
curr_res
)
print
(
"Working with z of shape {} = {} dimensions."
.
format
(
self
.
z_shape
,
np
.
prod
(
self
.
z_shape
)))
# z to block_in
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
z_channels
,
block_in
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
z
):
# assert z.shape[1:] == self.z_shape[1:]
self
.
last_z_shape
=
z
.
shape
# timestep embedding
temb
=
None
# z to block_in
h
=
self
.
conv_in
(
z
)
# middle
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
h
,
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# end
if
self
.
give_pre_end
:
return
h
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
VectorQuantizer
(
nn
.
Module
):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def
__init__
(
self
,
n_e
,
e_dim
,
beta
,
remap
=
None
,
unknown_index
=
"random"
,
sane_index_shape
=
False
,
legacy
=
True
):
super
().
__init__
()
self
.
n_e
=
n_e
self
.
e_dim
=
e_dim
self
.
beta
=
beta
self
.
legacy
=
legacy
self
.
embedding
=
nn
.
Embedding
(
self
.
n_e
,
self
.
e_dim
)
self
.
embedding
.
weight
.
data
.
uniform_
(
-
1.0
/
self
.
n_e
,
1.0
/
self
.
n_e
)
self
.
remap
=
remap
if
self
.
remap
is
not
None
:
self
.
register_buffer
(
"used"
,
torch
.
tensor
(
np
.
load
(
self
.
remap
)))
self
.
re_embed
=
self
.
used
.
shape
[
0
]
self
.
unknown_index
=
unknown_index
# "random" or "extra" or integer
if
self
.
unknown_index
==
"extra"
:
self
.
unknown_index
=
self
.
re_embed
self
.
re_embed
=
self
.
re_embed
+
1
print
(
f
"Remapping
{
self
.
n_e
}
indices to
{
self
.
re_embed
}
indices. "
f
"Using
{
self
.
unknown_index
}
for unknown indices."
)
else
:
self
.
re_embed
=
n_e
self
.
sane_index_shape
=
sane_index_shape
def
remap_to_used
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
match
=
(
inds
[:,
:,
None
]
==
used
[
None
,
None
,
...]).
long
()
new
=
match
.
argmax
(
-
1
)
unknown
=
match
.
sum
(
2
)
<
1
if
self
.
unknown_index
==
"random"
:
new
[
unknown
]
=
torch
.
randint
(
0
,
self
.
re_embed
,
size
=
new
[
unknown
].
shape
).
to
(
device
=
new
.
device
)
else
:
new
[
unknown
]
=
self
.
unknown_index
return
new
.
reshape
(
ishape
)
def
unmap_to_all
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
if
self
.
re_embed
>
self
.
used
.
shape
[
0
]:
# extra token
inds
[
inds
>=
self
.
used
.
shape
[
0
]]
=
0
# simply set to zero
back
=
torch
.
gather
(
used
[
None
,
:][
inds
.
shape
[
0
]
*
[
0
],
:],
1
,
inds
)
return
back
.
reshape
(
ishape
)
def
forward
(
self
,
z
,
temp
=
None
,
rescale_logits
=
False
,
return_logits
=
False
):
assert
temp
is
None
or
temp
==
1.0
,
"Only for interface compatible with Gumbel"
assert
rescale_logits
==
False
,
"Only for interface compatible with Gumbel"
assert
return_logits
==
False
,
"Only for interface compatible with Gumbel"
# reshape z -> (batch, height, width, channel) and flatten
z
=
rearrange
(
z
,
"b c h w -> b h w c"
).
contiguous
()
z_flattened
=
z
.
view
(
-
1
,
self
.
e_dim
)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d
=
(
torch
.
sum
(
z_flattened
**
2
,
dim
=
1
,
keepdim
=
True
)
+
torch
.
sum
(
self
.
embedding
.
weight
**
2
,
dim
=
1
)
-
2
*
torch
.
einsum
(
"bd,dn->bn"
,
z_flattened
,
rearrange
(
self
.
embedding
.
weight
,
"n d -> d n"
))
)
min_encoding_indices
=
torch
.
argmin
(
d
,
dim
=
1
)
z_q
=
self
.
embedding
(
min_encoding_indices
).
view
(
z
.
shape
)
perplexity
=
None
min_encodings
=
None
# compute loss for embedding
if
not
self
.
legacy
:
loss
=
self
.
beta
*
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
else
:
loss
=
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
self
.
beta
*
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
# preserve gradients
z_q
=
z
+
(
z_q
-
z
).
detach
()
# reshape back to match original input shape
z_q
=
rearrange
(
z_q
,
"b h w c -> b c h w"
).
contiguous
()
if
self
.
remap
is
not
None
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z
.
shape
[
0
],
-
1
)
# add batch axis
min_encoding_indices
=
self
.
remap_to_used
(
min_encoding_indices
)
min_encoding_indices
=
min_encoding_indices
.
reshape
(
-
1
,
1
)
# flatten
if
self
.
sane_index_shape
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z_q
.
shape
[
0
],
z_q
.
shape
[
2
],
z_q
.
shape
[
3
])
return
z_q
,
loss
,
(
perplexity
,
min_encodings
,
min_encoding_indices
)
def
get_codebook_entry
(
self
,
indices
,
shape
):
# shape specifying (batch, height, width, channel)
if
self
.
remap
is
not
None
:
indices
=
indices
.
reshape
(
shape
[
0
],
-
1
)
# add batch axis
indices
=
self
.
unmap_to_all
(
indices
)
indices
=
indices
.
reshape
(
-
1
)
# flatten again
# get quantized latent vectors
z_q
=
self
.
embedding
(
indices
)
if
shape
is
not
None
:
z_q
=
z_q
.
view
(
shape
)
# reshape back to match original input shape
z_q
=
z_q
.
permute
(
0
,
3
,
1
,
2
).
contiguous
()
return
z_q
class
VQModel
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
ch
,
out_ch
,
num_res_blocks
,
attn_resolutions
,
in_channels
,
resolution
,
z_channels
,
n_embed
,
embed_dim
,
remap
=
None
,
sane_index_shape
=
False
,
# tell vector quantizer to return indices as bhw
ch_mult
=
(
1
,
2
,
4
,
8
),
dropout
=
0.0
,
double_z
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
n_embed
=
n_embed
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
double_z
=
double_z
,
give_pre_end
=
give_pre_end
,
)
self
.
quantize
=
VectorQuantizer
(
n_embed
,
embed_dim
,
beta
=
0.25
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
)
# pass init params to Decoder
self
.
decoder
=
Decoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
h
=
self
.
quant_conv
(
h
)
return
h
def
decode
(
self
,
h
,
force_not_quantize
=
False
):
# also go through quantization layer
if
not
force_not_quantize
:
quant
,
emb_loss
,
info
=
self
.
quantize
(
h
)
else
:
quant
=
h
quant
=
self
.
post_quant_conv
(
quant
)
dec
=
self
.
decoder
(
quant
)
return
dec
src/diffusers/__init__.py
View file @
27039cd3
...
...
@@ -5,11 +5,9 @@
__version__
=
"0.0.1"
from
.modeling_utils
import
ModelMixin
from
.models.clip_text_transformer
import
CLIPTextModel
from
.models.unet
import
UNetModel
from
.models.unet_glide
import
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
from
.models.unet_ldm
import
UNetLDMModel
from
.models.vqvae
import
VQModel
from
.pipeline_utils
import
DiffusionPipeline
from
.schedulers.classifier_free_guidance
import
ClassifierFreeGuidanceScheduler
from
.schedulers.gaussian_ddpm
import
GaussianDDPMScheduler
...
...
src/diffusers/models/__init__.py
View file @
27039cd3
...
...
@@ -16,8 +16,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
.clip_text_transformer
import
CLIPTextModel
from
.unet
import
UNetModel
from
.unet_glide
import
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
from
.unet_ldm
import
UNetLDMModel
from
.vqvae
import
VQModel
src/diffusers/models/clip_text_transformer.py
deleted
100644 → 0
View file @
8841d0d1
# coding=utf-8
# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch CLIP model."""
import
math
from
dataclasses
import
dataclass
from
typing
import
Any
,
Optional
,
Tuple
,
Union
import
torch
import
torch.utils.checkpoint
from
torch
import
nn
from
transformers
import
CLIPConfig
,
CLIPModel
,
CLIPTextConfig
,
CLIPVisionConfig
from
transformers.activations
import
ACT2FN
from
transformers.modeling_outputs
import
BaseModelOutput
,
BaseModelOutputWithPooling
from
transformers.modeling_utils
import
PreTrainedModel
from
transformers.utils
import
(
ModelOutput
,
add_start_docstrings
,
add_start_docstrings_to_model_forward
,
logging
,
replace_return_docstrings
,
)
logger
=
logging
.
get_logger
(
__name__
)
_CHECKPOINT_FOR_DOC
=
"openai/clip-vit-base-patch32"
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST
=
[
"openai/clip-vit-base-patch32"
,
# See all CLIP models at https://huggingface.co/models?filter=clip
]
# Copied from transformers.models.bart.modeling_bart._expand_mask
def
_expand_mask
(
mask
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
tgt_len
:
Optional
[
int
]
=
None
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz
,
src_len
=
mask
.
size
()
tgt_len
=
tgt_len
if
tgt_len
is
not
None
else
src_len
expanded_mask
=
mask
[:,
None
,
None
,
:].
expand
(
bsz
,
1
,
tgt_len
,
src_len
).
to
(
dtype
)
inverted_mask
=
1.0
-
expanded_mask
return
inverted_mask
.
masked_fill
(
inverted_mask
.
to
(
torch
.
bool
),
torch
.
finfo
(
dtype
).
min
)
# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def
contrastive_loss
(
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
nn
.
functional
.
cross_entropy
(
logits
,
torch
.
arange
(
len
(
logits
),
device
=
logits
.
device
))
def
clip_loss
(
similarity
:
torch
.
Tensor
)
->
torch
.
Tensor
:
caption_loss
=
contrastive_loss
(
similarity
)
image_loss
=
contrastive_loss
(
similarity
.
T
)
return
(
caption_loss
+
image_loss
)
/
2.0
@
dataclass
class
CLIPOutput
(
ModelOutput
):
"""
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
Contrastive loss for image-text similarity.
logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
similarity scores.
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
similarity scores.
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
text_model_output(`BaseModelOutputWithPooling`):
The output of the [`CLIPTextModel`].
vision_model_output(`BaseModelOutputWithPooling`):
The output of the [`CLIPVisionModel`].
"""
loss
:
Optional
[
torch
.
FloatTensor
]
=
None
logits_per_image
:
torch
.
FloatTensor
=
None
logits_per_text
:
torch
.
FloatTensor
=
None
text_embeds
:
torch
.
FloatTensor
=
None
image_embeds
:
torch
.
FloatTensor
=
None
text_model_output
:
BaseModelOutputWithPooling
=
None
vision_model_output
:
BaseModelOutputWithPooling
=
None
def
to_tuple
(
self
)
->
Tuple
[
Any
]:
return
tuple
(
self
[
k
]
if
k
not
in
[
"text_model_output"
,
"vision_model_output"
]
else
getattr
(
self
,
k
).
to_tuple
()
for
k
in
self
.
keys
()
)
class
CLIPVisionEmbeddings
(
nn
.
Module
):
def
__init__
(
self
,
config
:
CLIPVisionConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
image_size
=
config
.
image_size
self
.
patch_size
=
config
.
patch_size
self
.
class_embedding
=
nn
.
Parameter
(
torch
.
randn
(
self
.
embed_dim
))
self
.
patch_embedding
=
nn
.
Conv2d
(
in_channels
=
3
,
out_channels
=
self
.
embed_dim
,
kernel_size
=
self
.
patch_size
,
stride
=
self
.
patch_size
,
bias
=
False
)
self
.
num_patches
=
(
self
.
image_size
//
self
.
patch_size
)
**
2
self
.
num_positions
=
self
.
num_patches
+
1
self
.
position_embedding
=
nn
.
Embedding
(
self
.
num_positions
,
self
.
embed_dim
)
self
.
register_buffer
(
"position_ids"
,
torch
.
arange
(
self
.
num_positions
).
expand
((
1
,
-
1
)))
def
forward
(
self
,
pixel_values
:
torch
.
FloatTensor
)
->
torch
.
Tensor
:
batch_size
=
pixel_values
.
shape
[
0
]
patch_embeds
=
self
.
patch_embedding
(
pixel_values
)
# shape = [*, width, grid, grid]
patch_embeds
=
patch_embeds
.
flatten
(
2
).
transpose
(
1
,
2
)
class_embeds
=
self
.
class_embedding
.
expand
(
batch_size
,
1
,
-
1
)
embeddings
=
torch
.
cat
([
class_embeds
,
patch_embeds
],
dim
=
1
)
embeddings
=
embeddings
+
self
.
position_embedding
(
self
.
position_ids
)
return
embeddings
class
CLIPTextEmbeddings
(
nn
.
Module
):
def
__init__
(
self
,
config
:
CLIPTextConfig
):
super
().
__init__
()
embed_dim
=
config
.
hidden_size
self
.
token_embedding
=
nn
.
Embedding
(
config
.
vocab_size
,
embed_dim
)
self
.
position_embedding
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
embed_dim
)
self
.
use_padding_embeddings
=
config
.
use_padding_embeddings
if
self
.
use_padding_embeddings
:
self
.
padding_embedding
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
embed_dim
)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self
.
register_buffer
(
"position_ids"
,
torch
.
arange
(
config
.
max_position_embeddings
).
expand
((
1
,
-
1
)))
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
seq_length
=
input_ids
.
shape
[
-
1
]
if
input_ids
is
not
None
else
inputs_embeds
.
shape
[
-
2
]
if
position_ids
is
None
:
position_ids
=
self
.
position_ids
[:,
:
seq_length
]
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
token_embedding
(
input_ids
)
position_embeddings
=
self
.
position_embedding
(
position_ids
)
embeddings
=
inputs_embeds
+
position_embeddings
if
self
.
use_padding_embeddings
and
attention_mask
is
not
None
:
padding_embeddings
=
self
.
padding_embedding
(
position_ids
)
embeddings
=
torch
.
where
(
attention_mask
.
bool
().
unsqueeze
(
-
1
),
embeddings
,
padding_embeddings
)
return
embeddings
class
CLIPAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
self
,
config
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
head_dim
=
self
.
embed_dim
//
self
.
num_heads
if
self
.
head_dim
*
self
.
num_heads
!=
self
.
embed_dim
:
raise
ValueError
(
f
"embed_dim must be divisible by num_heads (got `embed_dim`:
{
self
.
embed_dim
}
and `num_heads`:"
f
"
{
self
.
num_heads
}
)."
)
self
.
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
self
.
head_dim
))
self
.
qkv_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
embed_dim
*
3
)
self
.
out_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
embed_dim
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
causal_attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
"""Input shape: Batch x Time x Channel"""
bsz
,
tgt_len
,
embed_dim
=
hidden_states
.
size
()
qkv_states
=
self
.
qkv_proj
(
hidden_states
)
qkv_states
=
qkv_states
.
view
(
bsz
,
tgt_len
,
self
.
num_heads
,
-
1
)
query_states
,
key_states
,
value_states
=
torch
.
split
(
qkv_states
,
self
.
head_dim
,
dim
=-
1
)
attn_weights
=
torch
.
einsum
(
"bthc,bshc->bhts"
,
query_states
*
self
.
scale
,
key_states
*
self
.
scale
)
wdtype
=
attn_weights
.
dtype
attn_weights
=
nn
.
functional
.
softmax
(
attn_weights
.
float
(),
dim
=-
1
).
type
(
wdtype
)
attn_output
=
torch
.
einsum
(
"bhts,bshc->bthc"
,
attn_weights
,
value_states
)
attn_output
=
attn_output
.
reshape
(
bsz
,
tgt_len
,
-
1
)
attn_output
=
self
.
out_proj
(
attn_output
)
return
attn_output
,
attn_weights
class
CLIPMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
().
__init__
()
self
.
config
=
config
self
.
activation_fn
=
ACT2FN
[
config
.
hidden_act
]
self
.
fc1
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
intermediate_size
)
self
.
fc2
=
nn
.
Linear
(
config
.
intermediate_size
,
config
.
hidden_size
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
=
self
.
fc1
(
hidden_states
)
hidden_states
=
self
.
activation_fn
(
hidden_states
)
hidden_states
=
self
.
fc2
(
hidden_states
)
return
hidden_states
class
CLIPEncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
CLIPConfig
):
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
self
.
self_attn
=
CLIPAttention
(
config
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
self
.
embed_dim
)
self
.
mlp
=
CLIPMLP
(
config
)
self
.
layer_norm2
=
nn
.
LayerNorm
(
self
.
embed_dim
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
torch
.
Tensor
,
causal_attention_mask
:
torch
.
Tensor
,
output_attentions
:
Optional
[
bool
]
=
False
,
)
->
Tuple
[
torch
.
FloatTensor
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(config.encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual
=
hidden_states
hidden_states
=
self
.
layer_norm1
(
hidden_states
)
hidden_states
,
attn_weights
=
self
.
self_attn
(
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
causal_attention_mask
=
causal_attention_mask
,
output_attentions
=
output_attentions
,
)
hidden_states
=
residual
+
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
layer_norm2
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
outputs
=
(
hidden_states
,)
if
output_attentions
:
outputs
+=
(
attn_weights
,)
return
outputs
class
CLIPPreTrainedModel
(
PreTrainedModel
):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class
=
CLIPConfig
base_model_prefix
=
"clip"
supports_gradient_checkpointing
=
True
_keys_to_ignore_on_load_missing
=
[
r
"position_ids"
]
def
_init_weights
(
self
,
module
):
"""Initialize the weights"""
factor
=
self
.
config
.
initializer_factor
if
isinstance
(
module
,
CLIPTextEmbeddings
):
module
.
token_embedding
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
factor
*
0.02
)
module
.
position_embedding
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
factor
*
0.02
)
if
hasattr
(
module
,
"padding_embedding"
):
module
.
padding_embedding
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
factor
*
0.02
)
elif
isinstance
(
module
,
CLIPVisionEmbeddings
):
factor
=
self
.
config
.
initializer_factor
nn
.
init
.
normal_
(
module
.
class_embedding
,
mean
=
0.0
,
std
=
module
.
embed_dim
**-
0.5
*
factor
)
nn
.
init
.
normal_
(
module
.
patch_embedding
.
weight
,
std
=
module
.
config
.
initializer_range
*
factor
)
nn
.
init
.
normal_
(
module
.
position_embedding
.
weight
,
std
=
module
.
config
.
initializer_range
*
factor
)
elif
isinstance
(
module
,
CLIPAttention
):
factor
=
self
.
config
.
initializer_factor
in_proj_std
=
(
module
.
embed_dim
**-
0.5
)
*
((
2
*
module
.
config
.
num_hidden_layers
)
**
-
0.5
)
*
factor
out_proj_std
=
(
module
.
embed_dim
**-
0.5
)
*
factor
nn
.
init
.
normal_
(
module
.
qkv_proj
.
weight
,
std
=
in_proj_std
)
nn
.
init
.
normal_
(
module
.
out_proj
.
weight
,
std
=
out_proj_std
)
elif
isinstance
(
module
,
CLIPMLP
):
factor
=
self
.
config
.
initializer_factor
in_proj_std
=
(
(
module
.
config
.
hidden_size
**-
0.5
)
*
((
2
*
module
.
config
.
num_hidden_layers
)
**
-
0.5
)
*
factor
)
fc_std
=
(
2
*
module
.
config
.
hidden_size
)
**
-
0.5
*
factor
nn
.
init
.
normal_
(
module
.
fc1
.
weight
,
std
=
fc_std
)
nn
.
init
.
normal_
(
module
.
fc2
.
weight
,
std
=
in_proj_std
)
elif
isinstance
(
module
,
CLIPModel
):
nn
.
init
.
normal_
(
module
.
text_projection
.
weight
,
std
=
module
.
text_embed_dim
**-
0.5
*
self
.
config
.
initializer_factor
,
)
nn
.
init
.
normal_
(
module
.
visual_projection
.
weight
,
std
=
module
.
vision_embed_dim
**-
0.5
*
self
.
config
.
initializer_factor
,
)
if
isinstance
(
module
,
nn
.
LayerNorm
):
module
.
bias
.
data
.
zero_
()
module
.
weight
.
data
.
fill_
(
1.0
)
if
isinstance
(
module
,
nn
.
Linear
)
and
module
.
bias
is
not
None
:
module
.
bias
.
data
.
zero_
()
def
_set_gradient_checkpointing
(
self
,
module
,
value
=
False
):
if
isinstance
(
module
,
CLIPEncoder
):
module
.
gradient_checkpointing
=
value
CLIP_START_DOCSTRING
=
r
"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
CLIP_TEXT_INPUTS_DOCSTRING
=
r
"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
CLIP_VISION_INPUTS_DOCSTRING
=
r
"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
CLIP_INPUTS_DOCSTRING
=
r
"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class
CLIPEncoder
(
nn
.
Module
):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`CLIPEncoderLayer`].
Args:
config: CLIPConfig
"""
def
__init__
(
self
,
config
:
CLIPConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
layers
=
nn
.
ModuleList
([
CLIPEncoderLayer
(
config
)
for
_
in
range
(
config
.
num_hidden_layers
)])
self
.
gradient_checkpointing
=
False
def
forward
(
self
,
inputs_embeds
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
causal_attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
Tuple
,
BaseModelOutput
]:
r
"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Causal mask for the text model. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions
=
output_attentions
if
output_attentions
is
not
None
else
self
.
config
.
output_attentions
output_hidden_states
=
(
output_hidden_states
if
output_hidden_states
is
not
None
else
self
.
config
.
output_hidden_states
)
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
encoder_states
=
()
if
output_hidden_states
else
None
all_attentions
=
()
if
output_attentions
else
None
hidden_states
=
inputs_embeds
for
idx
,
encoder_layer
in
enumerate
(
self
.
layers
):
if
output_hidden_states
:
encoder_states
=
encoder_states
+
(
hidden_states
,)
if
self
.
gradient_checkpointing
and
self
.
training
:
def
create_custom_forward
(
module
):
def
custom_forward
(
*
inputs
):
return
module
(
*
inputs
,
output_attentions
)
return
custom_forward
layer_outputs
=
torch
.
utils
.
checkpoint
.
checkpoint
(
create_custom_forward
(
encoder_layer
),
hidden_states
,
attention_mask
,
causal_attention_mask
,
)
else
:
layer_outputs
=
encoder_layer
(
hidden_states
,
attention_mask
,
causal_attention_mask
,
output_attentions
=
output_attentions
,
)
hidden_states
=
layer_outputs
[
0
]
if
output_attentions
:
all_attentions
=
all_attentions
+
(
layer_outputs
[
1
],)
if
output_hidden_states
:
encoder_states
=
encoder_states
+
(
hidden_states
,)
if
not
return_dict
:
return
tuple
(
v
for
v
in
[
hidden_states
,
encoder_states
,
all_attentions
]
if
v
is
not
None
)
return
BaseModelOutput
(
last_hidden_state
=
hidden_states
,
hidden_states
=
encoder_states
,
attentions
=
all_attentions
)
class
CLIPTextTransformer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
CLIPTextConfig
):
super
().
__init__
()
self
.
config
=
config
embed_dim
=
config
.
hidden_size
self
.
embeddings
=
CLIPTextEmbeddings
(
config
)
self
.
encoder
=
CLIPEncoder
(
config
)
self
.
final_layer_norm
=
nn
.
LayerNorm
(
embed_dim
)
@
add_start_docstrings_to_model_forward
(
CLIP_TEXT_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
BaseModelOutputWithPooling
,
config_class
=
CLIPTextConfig
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
Tuple
,
BaseModelOutputWithPooling
]:
r
"""
Returns:
"""
output_attentions
=
output_attentions
if
output_attentions
is
not
None
else
self
.
config
.
output_attentions
output_hidden_states
=
(
output_hidden_states
if
output_hidden_states
is
not
None
else
self
.
config
.
output_hidden_states
)
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
if
input_ids
is
None
:
raise
ValueError
(
"You have to specify either input_ids"
)
input_shape
=
input_ids
.
size
()
input_ids
=
input_ids
.
view
(
-
1
,
input_shape
[
-
1
])
hidden_states
=
self
.
embeddings
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
attention_mask
=
attention_mask
)
bsz
,
seq_len
=
input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask
=
self
.
_build_causal_attention_mask
(
bsz
,
seq_len
).
to
(
hidden_states
.
device
)
# expand attention_mask
if
attention_mask
is
not
None
:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask
=
_expand_mask
(
attention_mask
,
hidden_states
.
dtype
)
encoder_outputs
=
self
.
encoder
(
inputs_embeds
=
hidden_states
,
attention_mask
=
None
,
causal_attention_mask
=
None
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
)
last_hidden_state
=
encoder_outputs
[
0
]
last_hidden_state
=
self
.
final_layer_norm
(
last_hidden_state
)
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
pooled_output
=
last_hidden_state
[
torch
.
arange
(
last_hidden_state
.
shape
[
0
]),
input_ids
.
argmax
(
dim
=-
1
)]
if
not
return_dict
:
return
(
last_hidden_state
,
pooled_output
)
+
encoder_outputs
[
1
:]
return
BaseModelOutputWithPooling
(
last_hidden_state
=
last_hidden_state
,
pooler_output
=
pooled_output
,
hidden_states
=
encoder_outputs
.
hidden_states
,
attentions
=
encoder_outputs
.
attentions
,
)
def
_build_causal_attention_mask
(
self
,
bsz
,
seq_len
):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask
=
torch
.
empty
(
bsz
,
seq_len
,
seq_len
)
mask
.
fill_
(
torch
.
tensor
(
float
(
"-inf"
)))
mask
.
triu_
(
1
)
# zero out the lower diagonal
mask
=
mask
.
unsqueeze
(
1
)
# expand mask
return
mask
class
CLIPTextModel
(
CLIPPreTrainedModel
):
config_class
=
CLIPTextConfig
def
__init__
(
self
,
config
:
CLIPTextConfig
):
super
().
__init__
(
config
)
self
.
text_model
=
CLIPTextTransformer
(
config
)
# Initialize weights and apply final processing
self
.
post_init
()
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
return
self
.
text_model
.
embeddings
.
token_embedding
def
set_input_embeddings
(
self
,
value
):
self
.
text_model
.
embeddings
.
token_embedding
=
value
@
add_start_docstrings_to_model_forward
(
CLIP_TEXT_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
BaseModelOutputWithPooling
,
config_class
=
CLIPTextConfig
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
Tuple
,
BaseModelOutputWithPooling
]:
r
"""
Returns:
Examples:
```python
>>> from transformers import CLIPTokenizer, CLIPTextModel
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
>>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
```"""
return
self
.
text_model
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
position_ids
=
position_ids
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
)
src/diffusers/models/vqvae.py
deleted
100644 → 0
View file @
8841d0d1
# pytorch_diffusion + derived encoder decoder
import
math
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert
len
(
timesteps
.
shape
)
==
1
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
return
emb
def
nonlinearity
(
x
):
# swish
return
x
*
torch
.
sigmoid
(
x
)
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
Upsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
x
=
torch
.
nn
.
functional
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
with_conv
:
x
=
self
.
conv
(
x
)
return
x
class
Downsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
# no asymmetric padding in torch conv, must do it ourselves
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
2
,
padding
=
0
)
def
forward
(
self
,
x
):
if
self
.
with_conv
:
pad
=
(
0
,
1
,
0
,
1
)
x
=
torch
.
nn
.
functional
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
x
=
self
.
conv
(
x
)
else
:
x
=
torch
.
nn
.
functional
.
avg_pool2d
(
x
,
kernel_size
=
2
,
stride
=
2
)
return
x
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
*
,
in_channels
,
out_channels
=
None
,
conv_shortcut
=
False
,
dropout
,
temb_channels
=
512
):
super
().
__init__
()
self
.
in_channels
=
in_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
use_conv_shortcut
=
conv_shortcut
self
.
norm1
=
Normalize
(
in_channels
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
temb_channels
>
0
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
self
.
conv_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
else
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
,
temb
):
h
=
x
h
=
self
.
norm1
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv1
(
h
)
if
temb
is
not
None
:
h
=
h
+
self
.
temb_proj
(
nonlinearity
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
norm2
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
x
=
self
.
conv_shortcut
(
x
)
else
:
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
class
AttnBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
norm
=
Normalize
(
in_channels
)
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
):
h_
=
x
h_
=
self
.
norm
(
h_
)
q
=
self
.
q
(
h_
)
k
=
self
.
k
(
h_
)
v
=
self
.
v
(
h_
)
# compute attention
b
,
c
,
h
,
w
=
q
.
shape
q
=
q
.
reshape
(
b
,
c
,
h
*
w
)
q
=
q
.
permute
(
0
,
2
,
1
)
# b,hw,c
k
=
k
.
reshape
(
b
,
c
,
h
*
w
)
# b,c,hw
w_
=
torch
.
bmm
(
q
,
k
)
# b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_
=
w_
*
(
int
(
c
)
**
(
-
0.5
))
w_
=
torch
.
nn
.
functional
.
softmax
(
w_
,
dim
=
2
)
# attend to values
v
=
v
.
reshape
(
b
,
c
,
h
*
w
)
w_
=
w_
.
permute
(
0
,
2
,
1
)
# b,hw,hw (first hw of k, second of q)
h_
=
torch
.
bmm
(
v
,
w_
)
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_
=
h_
.
reshape
(
b
,
c
,
h
,
w
)
h_
=
self
.
proj_out
(
h_
)
return
x
+
h_
class
Model
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
use_timestep
=
True
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
self
.
ch
*
4
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
use_timestep
=
use_timestep
if
self
.
use_timestep
:
# timestep embedding
self
.
temb
=
nn
.
Module
()
self
.
temb
.
dense
=
nn
.
ModuleList
(
[
torch
.
nn
.
Linear
(
self
.
ch
,
self
.
temb_ch
),
torch
.
nn
.
Linear
(
self
.
temb_ch
,
self
.
temb_ch
),
]
)
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
skip_in
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
if
i_block
==
self
.
num_res_blocks
:
skip_in
=
ch
*
in_ch_mult
[
i_level
]
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
+
skip_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
,
t
=
None
):
# assert x.shape[2] == x.shape[3] == self.resolution
if
self
.
use_timestep
:
# timestep embedding
assert
t
is
not
None
temb
=
get_timestep_embedding
(
t
,
self
.
ch
)
temb
=
self
.
temb
.
dense
[
0
](
temb
)
temb
=
nonlinearity
(
temb
)
temb
=
self
.
temb
.
dense
[
1
](
temb
)
else
:
temb
=
None
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
),
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
double_z
=
True
,
**
ignore_kwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
2
*
z_channels
if
double_z
else
z_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb
=
None
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
Decoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
give_pre_end
=
False
,
**
ignorekwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
give_pre_end
=
give_pre_end
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
block_in
=
ch
*
ch_mult
[
self
.
num_resolutions
-
1
]
curr_res
=
resolution
//
2
**
(
self
.
num_resolutions
-
1
)
self
.
z_shape
=
(
1
,
z_channels
,
curr_res
,
curr_res
)
print
(
"Working with z of shape {} = {} dimensions."
.
format
(
self
.
z_shape
,
np
.
prod
(
self
.
z_shape
)))
# z to block_in
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
z_channels
,
block_in
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
z
):
# assert z.shape[1:] == self.z_shape[1:]
self
.
last_z_shape
=
z
.
shape
# timestep embedding
temb
=
None
# z to block_in
h
=
self
.
conv_in
(
z
)
# middle
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
h
,
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# end
if
self
.
give_pre_end
:
return
h
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
VectorQuantizer
(
nn
.
Module
):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def
__init__
(
self
,
n_e
,
e_dim
,
beta
,
remap
=
None
,
unknown_index
=
"random"
,
sane_index_shape
=
False
,
legacy
=
True
):
super
().
__init__
()
self
.
n_e
=
n_e
self
.
e_dim
=
e_dim
self
.
beta
=
beta
self
.
legacy
=
legacy
self
.
embedding
=
nn
.
Embedding
(
self
.
n_e
,
self
.
e_dim
)
self
.
embedding
.
weight
.
data
.
uniform_
(
-
1.0
/
self
.
n_e
,
1.0
/
self
.
n_e
)
self
.
remap
=
remap
if
self
.
remap
is
not
None
:
self
.
register_buffer
(
"used"
,
torch
.
tensor
(
np
.
load
(
self
.
remap
)))
self
.
re_embed
=
self
.
used
.
shape
[
0
]
self
.
unknown_index
=
unknown_index
# "random" or "extra" or integer
if
self
.
unknown_index
==
"extra"
:
self
.
unknown_index
=
self
.
re_embed
self
.
re_embed
=
self
.
re_embed
+
1
print
(
f
"Remapping
{
self
.
n_e
}
indices to
{
self
.
re_embed
}
indices. "
f
"Using
{
self
.
unknown_index
}
for unknown indices."
)
else
:
self
.
re_embed
=
n_e
self
.
sane_index_shape
=
sane_index_shape
def
remap_to_used
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
match
=
(
inds
[:,
:,
None
]
==
used
[
None
,
None
,
...]).
long
()
new
=
match
.
argmax
(
-
1
)
unknown
=
match
.
sum
(
2
)
<
1
if
self
.
unknown_index
==
"random"
:
new
[
unknown
]
=
torch
.
randint
(
0
,
self
.
re_embed
,
size
=
new
[
unknown
].
shape
).
to
(
device
=
new
.
device
)
else
:
new
[
unknown
]
=
self
.
unknown_index
return
new
.
reshape
(
ishape
)
def
unmap_to_all
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
if
self
.
re_embed
>
self
.
used
.
shape
[
0
]:
# extra token
inds
[
inds
>=
self
.
used
.
shape
[
0
]]
=
0
# simply set to zero
back
=
torch
.
gather
(
used
[
None
,
:][
inds
.
shape
[
0
]
*
[
0
],
:],
1
,
inds
)
return
back
.
reshape
(
ishape
)
def
forward
(
self
,
z
,
temp
=
None
,
rescale_logits
=
False
,
return_logits
=
False
):
assert
temp
is
None
or
temp
==
1.0
,
"Only for interface compatible with Gumbel"
assert
rescale_logits
==
False
,
"Only for interface compatible with Gumbel"
assert
return_logits
==
False
,
"Only for interface compatible with Gumbel"
# reshape z -> (batch, height, width, channel) and flatten
z
=
rearrange
(
z
,
"b c h w -> b h w c"
).
contiguous
()
z_flattened
=
z
.
view
(
-
1
,
self
.
e_dim
)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d
=
(
torch
.
sum
(
z_flattened
**
2
,
dim
=
1
,
keepdim
=
True
)
+
torch
.
sum
(
self
.
embedding
.
weight
**
2
,
dim
=
1
)
-
2
*
torch
.
einsum
(
"bd,dn->bn"
,
z_flattened
,
rearrange
(
self
.
embedding
.
weight
,
"n d -> d n"
))
)
min_encoding_indices
=
torch
.
argmin
(
d
,
dim
=
1
)
z_q
=
self
.
embedding
(
min_encoding_indices
).
view
(
z
.
shape
)
perplexity
=
None
min_encodings
=
None
# compute loss for embedding
if
not
self
.
legacy
:
loss
=
self
.
beta
*
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
else
:
loss
=
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
self
.
beta
*
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
# preserve gradients
z_q
=
z
+
(
z_q
-
z
).
detach
()
# reshape back to match original input shape
z_q
=
rearrange
(
z_q
,
"b h w c -> b c h w"
).
contiguous
()
if
self
.
remap
is
not
None
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z
.
shape
[
0
],
-
1
)
# add batch axis
min_encoding_indices
=
self
.
remap_to_used
(
min_encoding_indices
)
min_encoding_indices
=
min_encoding_indices
.
reshape
(
-
1
,
1
)
# flatten
if
self
.
sane_index_shape
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z_q
.
shape
[
0
],
z_q
.
shape
[
2
],
z_q
.
shape
[
3
])
return
z_q
,
loss
,
(
perplexity
,
min_encodings
,
min_encoding_indices
)
def
get_codebook_entry
(
self
,
indices
,
shape
):
# shape specifying (batch, height, width, channel)
if
self
.
remap
is
not
None
:
indices
=
indices
.
reshape
(
shape
[
0
],
-
1
)
# add batch axis
indices
=
self
.
unmap_to_all
(
indices
)
indices
=
indices
.
reshape
(
-
1
)
# flatten again
# get quantized latent vectors
z_q
=
self
.
embedding
(
indices
)
if
shape
is
not
None
:
z_q
=
z_q
.
view
(
shape
)
# reshape back to match original input shape
z_q
=
z_q
.
permute
(
0
,
3
,
1
,
2
).
contiguous
()
return
z_q
class
VQModel
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
ch
,
out_ch
,
num_res_blocks
,
attn_resolutions
,
in_channels
,
resolution
,
z_channels
,
n_embed
,
embed_dim
,
remap
=
None
,
sane_index_shape
=
False
,
# tell vector quantizer to return indices as bhw
ch_mult
=
(
1
,
2
,
4
,
8
),
dropout
=
0.0
,
double_z
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
n_embed
=
n_embed
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
double_z
=
double_z
,
give_pre_end
=
give_pre_end
,
)
self
.
quantize
=
VectorQuantizer
(
n_embed
,
embed_dim
,
beta
=
0.25
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
)
# pass init params to Decoder
self
.
decoder
=
Decoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
h
=
self
.
quant_conv
(
h
)
return
h
def
decode
(
self
,
h
,
force_not_quantize
=
False
):
# also go through quantization layer
if
not
force_not_quantize
:
quant
,
emb_loss
,
info
=
self
.
quantize
(
h
)
else
:
quant
=
h
quant
=
self
.
post_quant_conv
(
quant
)
dec
=
self
.
decoder
(
quant
)
return
dec
src/diffusers/pipeline_utils.py
View file @
27039cd3
...
...
@@ -34,13 +34,13 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES
=
{
"diffusers"
:
{
"ModelMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"CLIPTextModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
# TODO (Anton): move to transformers
"GaussianDDPMScheduler"
:
[
"save_config"
,
"from_config"
],
"ClassifierFreeGuidanceScheduler"
:
[
"save_config"
,
"from_config"
],
"GlideDDIMScheduler"
:
[
"save_config"
,
"from_config"
],
},
"transformers"
:
{
"PreTrainedTokenizer"
:
[
"save_pretrained"
,
"from_pretrained"
],
"PreTrainedModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
},
}
...
...
@@ -83,24 +83,25 @@ class DiffusionPipeline(ConfigMixin):
model_index_dict
.
pop
(
"_diffusers_version"
)
model_index_dict
.
pop
(
"_module"
)
for
name
,
(
library_name
,
class_name
)
in
model_index_dict
.
items
():
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
# TODO: Suraj
if
library_name
==
self
.
__module__
:
library_name
=
self
library
=
importlib
.
import_module
(
library_name
)
class_obj
=
getattr
(
library
,
class_name
)
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
for
pipeline_component_name
in
model_index_dict
.
keys
():
sub_model
=
getattr
(
self
,
pipeline_component_name
)
model_cls
=
sub_model
.
__class__
save_method_name
=
None
for
class_name
,
class_candidate
in
class_candidates
.
items
():
if
issubclass
(
class_obj
,
class_candidate
):
save_method_name
=
importable_classes
[
class_name
][
0
]
save_method
=
getattr
(
getattr
(
self
,
name
),
save_method_name
)
save_method
(
os
.
path
.
join
(
save_directory
,
name
))
# search for the model's base class in LOADABLE_CLASSES
for
library_name
,
library_classes
in
LOADABLE_CLASSES
.
items
():
library
=
importlib
.
import_module
(
library_name
)
for
base_class
,
save_load_methods
in
library_classes
.
items
():
class_candidate
=
getattr
(
library
,
base_class
)
if
issubclass
(
model_cls
,
class_candidate
):
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
save_method_name
=
save_load_methods
[
0
]
break
if
save_method_name
is
not
None
:
break
save_method
=
getattr
(
sub_model
,
save_method_name
)
save_method
(
os
.
path
.
join
(
save_directory
,
pipeline_component_name
))
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
Optional
[
Union
[
str
,
os
.
PathLike
]],
**
kwargs
):
...
...
@@ -112,7 +113,8 @@ class DiffusionPipeline(ConfigMixin):
proxies
=
kwargs
.
pop
(
"proxies"
,
None
)
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
cached_folder
=
snapshot_download
(
...
...
@@ -128,11 +130,12 @@ class DiffusionPipeline(ConfigMixin):
config_dict
=
cls
.
get_config_dict
(
cached_folder
)
module
=
config_dict
[
"_module"
]
# 2. Get class name and module candidates to load custom models
class_name_
=
config_dict
[
"_class_name"
]
module_candidate
=
config_dict
[
"_module"
]
module_candidate_name
=
module_candidate
.
replace
(
".py"
,
""
)
# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
if
cls
!=
DiffusionPipeline
:
pipeline_class
=
cls
...
...
@@ -146,6 +149,7 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs
=
{}
# 4. Load each module in the pipeline
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
# if the model is not in diffusers or transformers, we need to load it from the hub
# assumes that it's a subclass of ModelMixin
...
...
@@ -155,6 +159,7 @@ class DiffusionPipeline(ConfigMixin):
importable_classes
=
ALL_IMPORTABLE_CLASSES
class_candidates
=
{
c
:
class_obj
for
c
in
ALL_IMPORTABLE_CLASSES
.
keys
()}
else
:
# else we just import it from the library.
library
=
importlib
.
import_module
(
library_name
)
class_obj
=
getattr
(
library
,
class_name
)
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
...
...
@@ -167,12 +172,15 @@ class DiffusionPipeline(ConfigMixin):
load_method
=
getattr
(
class_obj
,
load_method_name
)
# check if the module is in a subdirectory
if
os
.
path
.
isdir
(
os
.
path
.
join
(
cached_folder
,
name
)):
loaded_sub_model
=
load_method
(
os
.
path
.
join
(
cached_folder
,
name
))
else
:
# else load from the root directory
loaded_sub_model
=
load_method
(
cached_folder
)
init_kwargs
[
name
]
=
loaded_sub_model
# UNet(...), # DiffusionSchedule(...)
# 5. Instantiate the pipeline
model
=
pipeline_class
(
**
init_kwargs
)
return
model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment