Unverified Commit 2c83b3c3 authored by zhiheng-huang, committed by GitHub

Support various BERT relative position embeddings (2nd) (#8276)



* Support BERT relative position embeddings

* Fix typo in README.md

* Address review comment

* Fix failing tests

* [tiny] Fix style_doc.py check by adding an empty line to configuration_bert.py

* make fix copies

* fix configs of electra and albert and fix longformer

* remove copy statement from longformer

* fix albert

* fix electra

* Add bert variants forward tests for various position embeddings

* [tiny] Fix style for test_modeling_bert.py

* improve docstring

* [tiny] improve docstring and remove unnecessary dependency

* [tiny] Remove unused import

* re-add to ALBERT

* make embeddings work for ALBERT

* add test for albert
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 9e71aa2f
@@ -159,6 +159,81 @@ Larger batch size may improve the performance while costing more memory.
}
```
#### Fine-tuning BERT on SQuAD1.0 with relative position embeddings
The following examples show how to fine-tune BERT models with different relative position embeddings. The
`bert-base-uncased` model was pre-trained with default absolute position embeddings. We provide the following
pre-trained models, which were trained on the same data (BooksCorpus and English Wikipedia) as `bert-base-uncased`
but with different relative position embeddings (a minimal loading example is sketched after this list).
* `zhiheng-huang/bert-base-uncased-embedding-relative-key`, trained from scratch with the relative embedding proposed
by Shaw et al., [Self-Attention with Relative Position Representations](https://arxiv.org/abs/1803.02155)
* `zhiheng-huang/bert-base-uncased-embedding-relative-key-query`, trained from scratch with relative embedding
method 4 from Huang et al., [Improve Transformer Models with Better Relative Position Embeddings](https://arxiv.org/abs/2009.13658)
* `zhiheng-huang/bert-large-uncased-whole-word-masking-embedding-relative-key-query`, fine-tuned from
`bert-large-uncased-whole-word-masking` for 3 additional epochs with relative embedding method 4 from Huang et al.,
[Improve Transformer Models with Better Relative Position Embeddings](https://arxiv.org/abs/2009.13658)
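For a quick sanity check of how these checkpoints and the new `position_embedding_type` configuration option are meant
to be used, here is a minimal sketch; it is standard `transformers` usage, the checkpoint names are the ones listed
above, and the hosted configs are assumed to carry the matching `position_embedding_type` value.
```python
from transformers import AutoModel, BertConfig, BertModel

# Load a checkpoint pre-trained with relative position embeddings; its hosted config
# is assumed to set position_embedding_type="relative_key_query" accordingly.
model = AutoModel.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key-query")

# Alternatively, enable relative position embeddings on a freshly initialized BERT.
config = BertConfig(position_embedding_type="relative_key")  # or "relative_key_query" / "absolute"
model = BertModel(config)
```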
##### Base models fine-tuning
```bash
export SQUAD_DIR=/path/to/SQUAD
output_dir=relative_squad
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m torch.distributed.launch --nproc_per_node=8 ./examples/question-answering/run_squad.py \
--model_type bert \
--model_name_or_path zhiheng-huang/bert-base-uncased-embedding-relative-key-query \
--do_train \
--do_eval \
--do_lower_case \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 512 \
--doc_stride 128 \
--output_dir ${output_dir} \
--per_gpu_eval_batch_size=60 \
--per_gpu_train_batch_size=6
```
Training with the above command leads to the following results, boosting the default BERT F1 score from 88.52 to 90.54:
```bash
'exact': 83.6802270577105, 'f1': 90.54772098174814
```
Changing `max_seq_length` from 512 to 384 in the above command leads to an F1 score of 90.34. Replacing the model
`zhiheng-huang/bert-base-uncased-embedding-relative-key-query` with
`zhiheng-huang/bert-base-uncased-embedding-relative-key` leads to an F1 score of 89.51. Training on a single GPU
instead of 8 GPUs leads to an F1 score of 90.71.
##### Large models fine-tuning
```bash
export SQUAD_DIR=/path/to/SQUAD
output_dir=relative_squad
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m torch.distributed.launch --nproc_per_node=8 ./examples/question-answering/run_squad.py \
--model_type bert \
--model_name_or_path zhiheng-huang/bert-large-uncased-whole-word-masking-embedding-relative-key-query \
--do_train \
--do_eval \
--do_lower_case \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 512 \
--doc_stride 128 \
--output_dir ${output_dir} \
--per_gpu_eval_batch_size=6 \
--per_gpu_train_batch_size=2 \
--gradient_accumulation_steps 3
```
Training with the above command leads to an F1 score of 93.52, which is slightly better than the F1 score of 93.15 for
`bert-large-uncased-whole-word-masking`.
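Once fine-tuning finishes, the checkpoint written to `--output_dir` can be exercised with the question-answering
pipeline. The snippet below is only an illustrative quick check (the path and the question/context strings are made up,
not part of the example scripts):
```python
from transformers import pipeline

# Load the fine-tuned checkpoint saved to --output_dir above (path is illustrative).
qa = pipeline("question-answering", model="relative_squad", tokenizer="relative_squad")
result = qa(
    question="Which embeddings does method 4 combine?",
    context="Method 4 combines relative key and relative query embeddings in self-attention.",
)
print(result["answer"], result["score"])
```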
## SQuAD with the Tensorflow Trainer
```bash
......
@@ -78,6 +78,13 @@ class AlbertConfig(PretrainedConfig):
The epsilon used by the layer normalization layers.
classifier_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for attached classifiers.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__.
Examples::
@@ -119,6 +126,7 @@ class AlbertConfig(PretrainedConfig):
initializer_range=0.02,
layer_norm_eps=1e-12,
classifier_dropout_prob=0.1,
position_embedding_type="absolute",
pad_token_id=0,
bos_token_id=2,
eos_token_id=3,
@@ -142,3 +150,4 @@ class AlbertConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.classifier_dropout_prob = classifier_dropout_prob
self.position_embedding_type = position_embedding_type
@@ -214,6 +214,7 @@ class AlbertEmbeddings(nn.Module):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = config.position_embedding_type
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
@@ -232,10 +233,12 @@ class AlbertEmbeddings(nn.Module):
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
@@ -265,6 +268,11 @@ class AlbertAttention(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pruned_heads = set()
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -289,10 +297,10 @@ class AlbertAttention(nn.Module):
self.all_head_size = self.attention_head_size * self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, input_ids, attention_mask=None, head_mask=None, output_attentions=False):
mixed_query_layer = self.query(input_ids)
mixed_key_layer = self.key(input_ids)
mixed_value_layer = self.value(input_ids)
def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
@@ -301,10 +309,27 @@ class AlbertAttention(nn.Module):
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
@@ -330,7 +355,7 @@ class AlbertAttention(nn.Module):
projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
projected_context_layer_dropout = self.output_dropout(projected_context_layer)
layernormed_context_layer = self.LayerNorm(input_ids + projected_context_layer_dropout)
layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
......
@@ -91,6 +91,13 @@ class BertConfig(PretrainedConfig):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__.
Examples::
@@ -123,6 +130,7 @@ class BertConfig(PretrainedConfig):
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
position_embedding_type="absolute",
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -140,3 +148,4 @@ class BertConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.position_embedding_type = position_embedding_type
@@ -178,6 +178,7 @@ class BertEmbeddings(nn.Module):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = config.position_embedding_type
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
@@ -195,10 +196,12 @@ class BertEmbeddings(nn.Module):
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
@@ -222,6 +225,10 @@ class BertSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -256,6 +263,23 @@ class BertSelfAttention(nn.Module):
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
......
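For readers who want to see the new scoring terms in isolation, below is a small self-contained sketch (with toy
tensor shapes, not tied to any particular model and not the exact module code above) of the relative position score
computation this commit adds to `BertSelfAttention` and the modules copied from it:
```python
import torch

# Toy shapes: batch=2, heads=4, seq_len=5, head_dim=8 (illustrative values only).
batch, num_heads, seq_len, head_dim = 2, 4, 5, 8
max_position_embeddings = 512

query_layer = torch.randn(batch, num_heads, seq_len, head_dim)
key_layer = torch.randn(batch, num_heads, seq_len, head_dim)
distance_embedding = torch.nn.Embedding(2 * max_position_embeddings - 1, head_dim)

# Signed distance between each query position l and key position r, shifted to be non-negative.
position_ids_l = torch.arange(seq_len).view(-1, 1)
position_ids_r = torch.arange(seq_len).view(1, -1)
distance = position_ids_l - position_ids_r                                          # (seq_len, seq_len)
positional_embedding = distance_embedding(distance + max_position_embeddings - 1)   # (seq_len, seq_len, head_dim)

# Content-content term, as in standard BERT attention.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

# "relative_key" (Shaw et al.): add a query . relative-embedding term.
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query

# "relative_key_query" (Huang et al., method 4) additionally adds a key . relative-embedding term.
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_key

attention_scores = attention_scores / (head_dim ** 0.5)
```
Only the extra einsum terms differ between the two relative modes; masking, softmax, and dropout are unchanged.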
@@ -54,6 +54,13 @@ class BertGenerationConfig(PretrainedConfig):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__.
Examples::
@@ -87,6 +94,7 @@ class BertGenerationConfig(PretrainedConfig):
bos_token_id=2,
eos_token_id=1,
gradient_checkpointing=False,
position_embedding_type="absolute",
**kwargs
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -103,3 +111,4 @@ class BertGenerationConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.position_embedding_type = position_embedding_type
@@ -97,6 +97,13 @@ class ElectraConfig(PretrainedConfig):
Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
The dropout ratio to be used after the projection and activation.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__.
Examples::
@@ -133,6 +140,7 @@ class ElectraConfig(PretrainedConfig):
summary_activation="gelu",
summary_last_dropout=0.1,
pad_token_id=0,
position_embedding_type="absolute",
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -155,3 +163,4 @@ class ElectraConfig(PretrainedConfig):
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_last_dropout = summary_last_dropout
self.position_embedding_type = position_embedding_type
@@ -165,6 +165,7 @@ class ElectraEmbeddings(nn.Module):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = config.position_embedding_type
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
@@ -183,10 +184,12 @@ class ElectraEmbeddings(nn.Module):
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
@@ -211,6 +214,10 @@ class ElectraSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -245,6 +252,23 @@ class ElectraSelfAttention(nn.Module):
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in ElectraModel forward() function)
......
@@ -146,6 +146,10 @@ class LayoutLMSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -180,6 +184,23 @@ class LayoutLMSelfAttention(nn.Module):
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in LayoutLMModel forward() function)
......
@@ -446,7 +446,6 @@ class LongformerEmbeddings(nn.Module):
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
"""
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
@@ -461,7 +460,6 @@ class LongformerEmbeddings(nn.Module):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
# End copy
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
@@ -475,7 +473,6 @@ class LongformerEmbeddings(nn.Module):
else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
if input_ids is not None:
input_shape = input_ids.size()
else:
......
@@ -83,6 +83,7 @@ class RobertaEmbeddings(nn.Module):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = config.position_embedding_type
# End copy
self.padding_idx = config.pad_token_id
@@ -114,10 +115,12 @@ class RobertaEmbeddings(nn.Module):
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
@@ -159,6 +162,10 @@ class RobertaSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -193,6 +200,23 @@ class RobertaSelfAttention(nn.Module):
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
......
@@ -106,7 +106,7 @@ class AlbertModelTester:
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_albert_model(
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = AlbertModel(config=config)
@@ -118,7 +118,7 @@ class AlbertModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_albert_for_pretraining(
def create_and_check_for_pretraining(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = AlbertForPreTraining(config=config)
@@ -134,7 +134,7 @@ class AlbertModelTester:
self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertEqual(result.sop_logits.shape, (self.batch_size, config.num_labels))
def create_and_check_albert_for_masked_lm(
def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = AlbertForMaskedLM(config=config)
@@ -143,7 +143,7 @@ class AlbertModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_albert_for_question_answering(
def create_and_check_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = AlbertForQuestionAnswering(config=config)
@@ -159,7 +159,7 @@ class AlbertModelTester:
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_albert_for_sequence_classification(
def create_and_check_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
@@ -169,7 +169,7 @@ class AlbertModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_albert_for_token_classification(
def create_and_check_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
@@ -179,7 +179,7 @@ class AlbertModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_albert_for_multiple_choice(
def create_and_check_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
@@ -250,29 +250,35 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
def test_config(self):
self.config_tester.run_common_tests()
def test_albert_model(self):
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_albert_model(*config_and_inputs)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_pretraining(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_albert_for_pretraining(*config_and_inputs)
self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_albert_for_masked_lm(*config_and_inputs)
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_albert_for_multiple_choice(*config_and_inputs)
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_albert_for_question_answering(*config_and_inputs)
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_albert_for_sequence_classification(*config_and_inputs)
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
......
@@ -404,6 +404,12 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_as_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
@@ -477,3 +483,43 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = BertModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_torch
class BertModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head_absolute_embedding(self):
model = BertModel.from_pretrained("bert-base-uncased")
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = torch.Size((1, 11, 768))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[[-0.0483, 0.1188, -0.0313], [-0.0606, 0.1435, 0.0199], [-0.0235, 0.1519, 0.0175]]]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
@slow
def test_inference_no_head_relative_embedding_key(self):
model = BertModel.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key")
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = torch.Size((1, 11, 768))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[[0.3492, 0.4126, -0.1484], [0.2274, -0.0549, 0.1623], [0.5889, 0.6797, -0.0189]]]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
@slow
def test_inference_no_head_relative_embedding_key_query(self):
model = BertModel.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key-query")
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = torch.Size((1, 11, 768))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor([[[1.1677, 0.5129, 0.9524], [0.6659, 0.5958, 0.6688], [1.1714, 0.1764, 0.6266]]])
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
@@ -309,6 +309,12 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_model(*config_and_inputs)
def test_electra_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_electra_model(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_masked_lm(*config_and_inputs)
......
@@ -199,6 +199,12 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
......
@@ -17,7 +17,7 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
@@ -295,6 +295,12 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_as_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
@@ -395,8 +401,6 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
@require_sentencepiece
@require_tokenizers
@require_torch
class RobertaModelIntegrationTest(unittest.TestCase):
@slow
......