Project: renzhc / diffusers_dcu · Commits

Commit fbb103de, authored Jun 27, 2022 by patil-suraj
    add the bert model in latent diffusion pipeline
Parent: 45a09beb

Showing 1 changed file with 541 additions and 1 deletion:
    src/diffusers/pipelines/pipeline_latent_diffusion.py  (+541, -1)
# pytorch_diffusion + derived encoder decoder
import math
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
import tqdm

try:
    from transformers.activations import ACT2FN
    from transformers.configuration_utils import PretrainedConfig
    from transformers.modeling_outputs import BaseModelOutput
    from transformers.modeling_utils import PreTrainedModel
    from transformers.utils import logging
except ImportError:
    raise ImportError("Please install the transformers.")

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from ..pipeline_utils import DiffusionPipeline


################################################################################
# Code for the text transformer model
################################################################################
""" PyTorch LDMBERT model."""


logger = logging.get_logger(__name__)

LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "ldm-bert",
    # See all LDMBert models at https://huggingface.co/models?filter=ldmbert
]

LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "ldm-bert": "https://huggingface.co/ldm-bert/resolve/main/config.json",
}


""" LDMBERT model configuration"""


class LDMBertConfig(PretrainedConfig):
    model_type = "ldmbert"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=30522,
        max_position_embeddings=77,
        encoder_layers=32,
        encoder_ffn_dim=5120,
        encoder_attention_heads=8,
        head_dim=64,
        encoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=1280,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        use_cache=True,
        pad_token_id=0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.head_dim = head_dim
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

        super().__init__(pad_token_id=pad_token_id, **kwargs)
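
An illustrative sketch (editor's addition, not part of the commit): the defaults above appear to describe the LDM text encoder (32 layers, d_model=1280, 8 heads of width 64, a 77-token context), and `attribute_map` lets the generic `hidden_size` / `num_attention_heads` names resolve to `d_model` / `encoder_attention_heads`:

    config = LDMBertConfig()
    assert config.hidden_size == config.d_model == 1280               # resolved via attribute_map
    assert config.num_attention_heads == config.encoder_attention_heads == 8
    assert config.max_position_embeddings == 77                       # 77-token text context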
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
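
A small worked example (editor's illustration) of the mask expansion: a batch of two length-4 sequences where the second has one padding position. Kept positions become 0.0 and padded key positions become the most negative representable value, so they vanish after the softmax inside the attention:

    mask = torch.tensor([[1, 1, 1, 1],
                         [1, 1, 1, 0]])
    expanded = _expand_mask(mask, dtype=torch.float32)
    assert expanded.shape == (2, 1, 4, 4)                            # [bsz, 1, tgt_seq_len, src_seq_len]
    assert expanded[0].max() == 0.0                                  # no padding -> additive mask is all zeros
    assert expanded[1, 0, 0, 3] == torch.finfo(torch.float32).min    # padded key column is masked out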
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert
class LDMBertAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        head_dim: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = head_dim
        self.inner_dim = head_dim * num_heads

        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
        self.out_proj = nn.Linear(self.inner_dim, embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
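
A shape-check sketch (editor's illustration), run in the context of this module. Note that, unlike stock BERT attention, the projections map `embed_dim` to `num_heads * head_dim` (512 with the config defaults) and `out_proj` maps back, so the head width is decoupled from the model width:

    attn = LDMBertAttention(embed_dim=1280, num_heads=8, head_dim=64)
    x = torch.randn(2, 77, 1280)                      # (batch, seq_len, embed_dim)
    out, weights, past = attn(x, output_attentions=True)
    assert out.shape == (2, 77, 1280)
    assert weights.shape == (2, 8, 77, 77)            # (batch, heads, tgt_len, src_len)
    assert past is None                               # encoder self-attention keeps no cache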
class LDMBertEncoderLayer(nn.Module):
    def __init__(self, config: LDMBertConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = LDMBertAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            head_dim=config.head_dim,
            dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        layer_head_mask: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
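
A minimal sketch (editor's illustration) of one pre-norm encoder layer, using a small, hypothetical config; the defaults above would instead build a 1280-wide layer with a 5120-dim FFN. The residual connections keep the output shape equal to the input shape:

    small_config = LDMBertConfig(d_model=64, encoder_ffn_dim=128, encoder_attention_heads=4, head_dim=16)
    layer = LDMBertEncoderLayer(small_config)
    x = torch.randn(2, 77, 64)
    (out,) = layer(x, attention_mask=None, layer_head_mask=None)
    assert out.shape == x.shape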
# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert
class LDMBertPreTrainedModel(PreTrainedModel):
    config_class = LDMBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (LDMBertEncoder,)):
            module.gradient_checkpointing = value

    @property
    def dummy_inputs(self):
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
        }
        return dummy_inputs
class LDMBertEncoder(LDMBertPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`LDMBertEncoderLayer`].

    Args:
        config: LDMBertConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: LDMBertConfig):
        super().__init__(config)

        self.dropout = config.dropout

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings

        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim)
        self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim)
        self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm(embed_dim)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated
                vectors than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        seq_len = input_shape[1]
        if position_ids is None:
            position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1))
        embed_pos = self.embed_positions(position_ids)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                    (head_mask[idx] if head_mask is not None else None),
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
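
A minimal sketch (editor's illustration) of the full encoder on dummy token ids, with a reduced, hypothetical config so it runs quickly; with `return_dict=True` the result is a `BaseModelOutput`:

    tiny_config = LDMBertConfig(
        vocab_size=1000, d_model=64, encoder_ffn_dim=128, encoder_layers=2,
        encoder_attention_heads=4, head_dim=16, max_position_embeddings=77,
    )
    encoder = LDMBertEncoder(tiny_config)
    input_ids = torch.randint(0, 1000, (2, 77))
    attention_mask = torch.ones(2, 77, dtype=torch.long)
    out = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
    assert out.last_hidden_state.shape == (2, 77, 64)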
class LDMBertModel(LDMBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.model = LDMBertEncoder(config)
        self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        return sequence_output
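
A minimal usage sketch (editor's illustration): `LDMBertModel` wraps the encoder and returns only the last hidden state, which the latent diffusion pipeline can feed to the UNet as cross-attention conditioning. A reduced, hypothetical config keeps the sketch light; the real text encoder uses the defaults above (32 layers, d_model=1280), and the random ids here stand in for tokenizer output:

    cfg = LDMBertConfig(
        vocab_size=1000, d_model=64, encoder_ffn_dim=128, encoder_layers=2,
        encoder_attention_heads=4, head_dim=16,
    )
    text_encoder = LDMBertModel(cfg)
    token_ids = torch.randint(0, 1000, (1, 77))       # stand-in for tokenizer output
    text_embeddings = text_encoder(token_ids)
    assert text_embeddings.shape == (1, 77, 64)       # (batch, seq_len, d_model)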
def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    ...