Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
447de34d
Commit
447de34d
authored
Sep 23, 2019
by
thomwolf
Browse files
tests for distilbert and roberta
parent
68a3e022
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
273 additions
and
31 deletions
+273
-31
pytorch_transformers/configuration_distilbert.py
pytorch_transformers/configuration_distilbert.py
+1
-1
pytorch_transformers/modeling_tf_distilbert.py
pytorch_transformers/modeling_tf_distilbert.py
+50
-30
pytorch_transformers/tests/modeling_tf_distilbert_test.py
pytorch_transformers/tests/modeling_tf_distilbert_test.py
+222
-0
No files found.
pytorch_transformers/configuration_distilbert.py
View file @
447de34d
...
@@ -37,7 +37,7 @@ class DistilBertConfig(PretrainedConfig):
...
@@ -37,7 +37,7 @@ class DistilBertConfig(PretrainedConfig):
def
__init__
(
self
,
def
__init__
(
self
,
vocab_size_or_config_json_file
=
30522
,
vocab_size_or_config_json_file
=
30522
,
max_position_embeddings
=
512
,
max_position_embeddings
=
512
,
sinusoidal_pos_embds
=
Tru
e
,
sinusoidal_pos_embds
=
Fals
e
,
n_layers
=
6
,
n_layers
=
6
,
n_heads
=
12
,
n_heads
=
12
,
dim
=
768
,
dim
=
768
,
...
...
pytorch_transformers/modeling_tf_distilbert.py
View file @
447de34d
...
@@ -79,9 +79,9 @@ class TFEmbeddings(tf.keras.layers.Layer):
...
@@ -79,9 +79,9 @@ class TFEmbeddings(tf.keras.layers.Layer):
super
(
TFEmbeddings
,
self
).
__init__
(
**
kwargs
)
super
(
TFEmbeddings
,
self
).
__init__
(
**
kwargs
)
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
self
.
dim
=
config
.
dim
self
.
dim
=
config
.
dim
self
.
word_embeddings
=
TFSharedEmbeddings
(,
name
=
'word_embeddings'
)
# padding_idx=0)
self
.
word_embeddings
=
TFSharedEmbeddings
(
config
.
vocab_size
,
config
.
dim
,
name
=
'word_embeddings'
)
# padding_idx=0)
self
.
position_embeddings
=
tf
.
keras
.
layers
.
Embedding
(
config
.
max_position_embeddings
,
config
.
dim
,
name
=
'position_embeddings'
)
self
.
position_embeddings
=
tf
.
keras
.
layers
.
Embedding
(
config
.
max_position_embeddings
,
config
.
dim
,
name
=
'position_embeddings'
)
if
config
.
sinusoidal_
embedding
s
:
if
config
.
sinusoidal_
pos_embd
s
:
raise
NotImplementedError
raise
NotImplementedError
self
.
LayerNorm
=
tf
.
keras
.
layers
.
LayerNormalization
(
epsilon
=
1e-12
,
name
=
"LayerNorm"
)
self
.
LayerNorm
=
tf
.
keras
.
layers
.
LayerNormalization
(
epsilon
=
1e-12
,
name
=
"LayerNorm"
)
...
@@ -94,9 +94,9 @@ class TFEmbeddings(tf.keras.layers.Layer):
...
@@ -94,9 +94,9 @@ class TFEmbeddings(tf.keras.layers.Layer):
# arbitrarily, and works well.
# arbitrarily, and works well.
self
.
word_embeddings
=
self
.
add_weight
(
self
.
word_embeddings
=
self
.
add_weight
(
"weight"
,
"weight"
,
shape
=
[
self
.
vocab_size
,
self
.
hidden_size
],
shape
=
[
self
.
vocab_size
,
self
.
dim
],
initializer
=
tf
.
random_normal_initializer
(
initializer
=
tf
.
random_normal_initializer
(
mean
=
0.
,
stddev
=
self
.
hidden_size
**-
0.5
))
mean
=
0.
,
stddev
=
self
.
dim
**-
0.5
))
super
(
TFEmbeddings
,
self
).
build
(
input_shape
)
super
(
TFEmbeddings
,
self
).
build
(
input_shape
)
def
call
(
self
,
inputs
,
mode
=
"embedding"
,
training
=
False
):
def
call
(
self
,
inputs
,
mode
=
"embedding"
,
training
=
False
):
...
@@ -133,13 +133,17 @@ class TFEmbeddings(tf.keras.layers.Layer):
...
@@ -133,13 +133,17 @@ class TFEmbeddings(tf.keras.layers.Layer):
embeddings: torch.tensor(bs, max_seq_length, dim)
embeddings: torch.tensor(bs, max_seq_length, dim)
The embedded tokens (plus position embeddings, no token_type embeddings)
The embedded tokens (plus position embeddings, no token_type embeddings)
"""
"""
input_ids
,
position_ids
=
inputs
if
not
isinstance
(
inputs
,
(
tuple
,
list
)):
input_ids
=
inputs
position_ids
=
None
else
:
input_ids
,
position_ids
=
inputs
seq_length
=
tf
.
shape
(
input_ids
)[
1
]
seq_length
=
tf
.
shape
(
input_ids
)[
1
]
if
position_ids
is
None
:
if
position_ids
is
None
:
position_ids
=
tf
.
range
(
seq_length
,
dtype
=
tf
.
int32
)[
tf
.
newaxis
,
:]
position_ids
=
tf
.
range
(
seq_length
,
dtype
=
tf
.
int32
)[
tf
.
newaxis
,
:]
word
s
_embeddings
=
tf
.
gather
(
self
.
word_embeddings
,
input_ids
)
word_embeddings
=
tf
.
gather
(
self
.
word_embeddings
,
input_ids
)
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
# (bs, max_seq_length, dim)
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
# (bs, max_seq_length, dim)
embeddings
=
word_embeddings
+
position_embeddings
# (bs, max_seq_length, dim)
embeddings
=
word_embeddings
+
position_embeddings
# (bs, max_seq_length, dim)
...
@@ -157,7 +161,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
...
@@ -157,7 +161,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
batch_size
=
tf
.
shape
(
inputs
)[
0
]
batch_size
=
tf
.
shape
(
inputs
)[
0
]
length
=
tf
.
shape
(
inputs
)[
1
]
length
=
tf
.
shape
(
inputs
)[
1
]
x
=
tf
.
reshape
(
inputs
,
[
-
1
,
self
.
hidden_size
])
x
=
tf
.
reshape
(
inputs
,
[
-
1
,
self
.
dim
])
logits
=
tf
.
matmul
(
x
,
self
.
word_embeddings
,
transpose_b
=
True
)
logits
=
tf
.
matmul
(
x
,
self
.
word_embeddings
,
transpose_b
=
True
)
return
tf
.
reshape
(
logits
,
[
batch_size
,
length
,
self
.
vocab_size
])
return
tf
.
reshape
(
logits
,
[
batch_size
,
length
,
self
.
vocab_size
])
...
@@ -169,7 +173,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
...
@@ -169,7 +173,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
self
.
n_heads
=
config
.
n_heads
self
.
n_heads
=
config
.
n_heads
self
.
dim
=
config
.
dim
self
.
dim
=
config
.
dim
self
.
dropout
=
nn
.
Dropout
(
p
=
config
.
attention_dropout
)
self
.
dropout
=
tf
.
keras
.
layers
.
Dropout
(
config
.
attention_dropout
)
self
.
output_attentions
=
config
.
output_attentions
self
.
output_attentions
=
config
.
output_attentions
assert
self
.
dim
%
self
.
n_heads
==
0
assert
self
.
dim
%
self
.
n_heads
==
0
...
@@ -210,7 +214,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
...
@@ -210,7 +214,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
assert
2
<=
len
(
tf
.
shape
(
mask
))
<=
3
assert
2
<=
len
(
tf
.
shape
(
mask
))
<=
3
causal
=
(
len
(
tf
.
shape
(
mask
))
==
3
)
causal
=
(
len
(
tf
.
shape
(
mask
))
==
3
)
mask_resh
p
=
[
bs
,
1
,
1
,
k_length
]
mask_resh
ape
=
[
bs
,
1
,
1
,
k_length
]
def
shape
(
x
):
def
shape
(
x
):
""" separate heads """
""" separate heads """
...
@@ -327,7 +331,7 @@ class TFTransformer(tf.keras.layers.Layer):
...
@@ -327,7 +331,7 @@ class TFTransformer(tf.keras.layers.Layer):
self
.
layer
=
[
TFTransformerBlock
(
config
,
name
=
'layer_._{}'
.
format
(
i
))
self
.
layer
=
[
TFTransformerBlock
(
config
,
name
=
'layer_._{}'
.
format
(
i
))
for
i
in
range
(
config
.
n_layers
)]
for
i
in
range
(
config
.
n_layers
)]
def
forward
(
self
,
inputs
,
training
=
False
):
def
call
(
self
,
inputs
,
training
=
False
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -403,6 +407,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
...
@@ -403,6 +407,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
"""
"""
def
__init__
(
self
,
config
,
**
kwargs
):
def
__init__
(
self
,
config
,
**
kwargs
):
super
(
TFDistilBertMainLayer
,
self
).
__init__
(
**
kwargs
)
super
(
TFDistilBertMainLayer
,
self
).
__init__
(
**
kwargs
)
self
.
num_hidden_layers
=
config
.
num_hidden_layers
self
.
embeddings
=
TFEmbeddings
(
config
,
name
=
"embeddings"
)
# Embeddings
self
.
embeddings
=
TFEmbeddings
(
config
,
name
=
"embeddings"
)
# Embeddings
self
.
transformer
=
TFTransformer
(
config
,
name
=
"transformer"
)
# Encoder
self
.
transformer
=
TFTransformer
(
config
,
name
=
"transformer"
)
# Encoder
...
@@ -430,6 +435,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
...
@@ -430,6 +435,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
if
attention_mask
is
None
:
if
attention_mask
is
None
:
attention_mask
=
tf
.
ones
(
shape_list
(
input_ids
))
# (bs, seq_length)
attention_mask
=
tf
.
ones
(
shape_list
(
input_ids
))
# (bs, seq_length)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
tf
.
float32
)
# Prepare head mask if needed
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# 1.0 in head_mask indicate we keep the head
...
@@ -439,15 +445,12 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
...
@@ -439,15 +445,12 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
if
head_mask
is
not
None
:
if
head_mask
is
not
None
:
raise
NotImplementedError
raise
NotImplementedError
else
:
else
:
head_mask
=
[
None
]
*
self
.
config
.
num_hidden_layers
head_mask
=
[
None
]
*
self
.
num_hidden_layers
embedding_output
=
self
.
embeddings
(
input_ids
)
# (bs, seq_length, dim)
embedding_output
=
self
.
embeddings
(
input_ids
)
# (bs, seq_length, dim)
tfmr_output
=
self
.
transformer
([
embedding_output
,
attention_mask
,
head_mask
],
training
=
training
)
tfmr_output
=
self
.
transformer
([
embedding_output
,
attention_mask
,
head_mask
],
training
=
training
)
hidden_state
=
tfmr_output
[
0
]
return
tfmr_output
# last-layer hidden-state, (all hidden_states), (all attentions)
output
=
(
hidden_state
,
)
+
tfmr_output
[
1
:]
return
output
# last-layer hidden-state, (all hidden_states), (all attentions)
### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
...
@@ -503,7 +506,7 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
...
@@ -503,7 +506,7 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
@
add_start_docstrings
(
"The bare DistilBERT encoder/transformer outputing raw hidden-states without any specific head on top."
,
@
add_start_docstrings
(
"The bare DistilBERT encoder/transformer outputing raw hidden-states without any specific head on top."
,
DISTILBERT_START_DOCSTRING
,
DISTILBERT_INPUTS_DOCSTRING
)
DISTILBERT_START_DOCSTRING
,
DISTILBERT_INPUTS_DOCSTRING
)
class
DistilBertModel
(
DistilBertPreTrainedModel
):
class
TF
DistilBertModel
(
TF
DistilBertPreTrainedModel
):
r
"""
r
"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
...
@@ -526,7 +529,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
...
@@ -526,7 +529,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
"""
"""
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
DistilBertModel
,
self
).
__init__
(
config
,
*
inputs
,
**
kwargs
)
super
(
TF
DistilBertModel
,
self
).
__init__
(
config
,
*
inputs
,
**
kwargs
)
self
.
distilbert
=
TFDistilBertMainLayer
(
config
,
name
=
"distilbert"
)
# Embeddings
self
.
distilbert
=
TFDistilBertMainLayer
(
config
,
name
=
"distilbert"
)
# Embeddings
def
call
(
self
,
inputs
,
training
=
False
):
def
call
(
self
,
inputs
,
training
=
False
):
...
@@ -534,6 +537,28 @@ class DistilBertModel(DistilBertPreTrainedModel):
...
@@ -534,6 +537,28 @@ class DistilBertModel(DistilBertPreTrainedModel):
return
outputs
return
outputs
class
TFDistilBertLMHead
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
config
,
input_embeddings
,
**
kwargs
):
super
(
TFDistilBertLMHead
,
self
).
__init__
(
**
kwargs
)
self
.
vocab_size
=
config
.
vocab_size
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self
.
input_embeddings
=
input_embeddings
def
build
(
self
,
input_shape
):
self
.
bias
=
self
.
add_weight
(
shape
=
(
self
.
vocab_size
,),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'bias'
)
super
(
TFDistilBertLMHead
,
self
).
build
(
input_shape
)
def
call
(
self
,
hidden_states
):
hidden_states
=
self
.
input_embeddings
(
hidden_states
,
mode
=
"linear"
)
hidden_states
=
hidden_states
+
self
.
bias
return
hidden_states
@
add_start_docstrings
(
"""DistilBert Model with a `masked language modeling` head on top. """
,
@
add_start_docstrings
(
"""DistilBert Model with a `masked language modeling` head on top. """
,
DISTILBERT_START_DOCSTRING
,
DISTILBERT_INPUTS_DOCSTRING
)
DISTILBERT_START_DOCSTRING
,
DISTILBERT_INPUTS_DOCSTRING
)
class
TFDistilBertForMaskedLM
(
TFDistilBertPreTrainedModel
):
class
TFDistilBertForMaskedLM
(
TFDistilBertPreTrainedModel
):
...
@@ -570,27 +595,22 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
...
@@ -570,27 +595,22 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
super
(
TFDistilBertForMaskedLM
,
self
).
__init__
(
config
,
*
inputs
,
**
kwargs
)
super
(
TFDistilBertForMaskedLM
,
self
).
__init__
(
config
,
*
inputs
,
**
kwargs
)
self
.
output_attentions
=
config
.
output_attentions
self
.
output_attentions
=
config
.
output_attentions
self
.
output_hidden_states
=
config
.
output_hidden_states
self
.
output_hidden_states
=
config
.
output_hidden_states
self
.
vocab_size
=
config
.
vocab_size
self
.
distilbert
=
TFDistilBertMainLayer
(
config
,
name
=
"distilbert"
)
self
.
distilbert
=
TFDistilBertMainLayer
(
config
,
name
=
"distilbert"
)
self
.
vocab_transform
=
tf
.
keras
.
layers
.
Dense
(
config
.
dim
,
name
=
"vocab_transform"
)
self
.
vocab_transform
=
tf
.
keras
.
layers
.
Dense
(
config
.
dim
,
name
=
"vocab_transform"
)
self
.
act
=
tf
.
keras
.
layers
.
Activation
(
gelu
)
self
.
act
=
tf
.
keras
.
layers
.
Activation
(
gelu
)
self
.
vocab_layer_norm
=
tf
.
keras
.
layers
.
LayerNormalization
(
epsilon
=
1e-12
,
name
=
"vocab_layer_norm"
)
self
.
vocab_layer_norm
=
tf
.
keras
.
layers
.
LayerNormalization
(
epsilon
=
1e-12
,
name
=
"vocab_layer_norm"
)
self
.
vocab_projector_weight
=
self
.
distilbert
.
embeddings
self
.
vocab_projector
=
TFDistilBertLMHead
(
config
,
self
.
distilbert
.
embeddings
,
name
=
"vocab_projector"
)
def
build
(
self
,
input_shape
):
self
.
vocab_projector_bias
=
self
.
add_weight
(
shape
=
(
self
.
vocab_size
,),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'vocab_projector_._bias'
)
super
(
TFDistilBertForMaskedLM
,
self
).
build
(
input_shape
)
def
call
(
self
,
inputs
,
training
=
False
):
def
call
(
self
,
inputs
,
training
=
False
):
dlbrt_output
=
self
.
distilbert
(
inputs
,
training
=
training
)
dlbrt_output
=
self
.
distilbert
(
inputs
,
training
=
training
)
hidden_states
=
dlbrt_output
[
0
]
# (bs, seq_length, dim)
prediction_logits
=
self
.
vocab_transform
(
hidden_states
)
# (bs, seq_length, dim)
hidden_states
=
dlbrt_output
[
0
]
# (bs, seq_length, dim)
prediction_logits
=
self
.
act
(
prediction_logits
)
# (bs, seq_length, dim)
prediction_logits
=
self
.
vocab_transform
(
hidden_states
)
# (bs, seq_length, dim)
prediction_logits
=
self
.
vocab_layer_norm
(
prediction_logits
)
# (bs, seq_length, dim)
prediction_logits
=
self
.
act
(
prediction_logits
)
# (bs, seq_length, dim)
prediction_logits
=
self
.
vocab_projector_weight
(
prediction_logits
,
mode
=
'linear'
)
+
self
.
vocab_projector_bias
prediction_logits
=
self
.
vocab_layer_norm
(
prediction_logits
)
# (bs, seq_length, dim)
prediction_logits
=
self
.
vocab_projector
(
prediction_logits
)
outputs
=
(
prediction_logits
,
)
+
dlbrt_output
[
1
:]
outputs
=
(
prediction_logits
,
)
+
dlbrt_output
[
1
:]
...
...
pytorch_transformers/tests/modeling_tf_distilbert_test.py
0 → 100644
View file @
447de34d
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
unittest
import
pytest
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
pytorch_transformers
import
DistilBertConfig
,
is_tf_available
if
is_tf_available
():
import
tensorflow
as
tf
from
pytorch_transformers.modeling_tf_distilbert
import
(
TFDistilBertModel
,
TFDistilBertForMaskedLM
,
TFDistilBertForQuestionAnswering
,
TFDistilBertForSequenceClassification
)
else
:
pytestmark
=
pytest
.
mark
.
skip
(
"Require TensorFlow"
)
class
TFDistilBertModelTest
(
TFCommonTestCases
.
TFCommonModelTester
):
all_model_classes
=
(
TFDistilBertModel
,
TFDistilBertForMaskedLM
,
TFDistilBertForQuestionAnswering
,
TFDistilBertForSequenceClassification
)
if
is_tf_available
()
else
None
test_pruning
=
True
test_torchscript
=
True
test_resize_embeddings
=
True
test_head_masking
=
True
class
TFDistilBertModelTester
(
object
):
def
__init__
(
self
,
parent
,
batch_size
=
13
,
seq_length
=
7
,
is_training
=
True
,
use_input_mask
=
True
,
use_token_type_ids
=
False
,
use_labels
=
True
,
vocab_size
=
99
,
hidden_size
=
32
,
num_hidden_layers
=
5
,
num_attention_heads
=
4
,
intermediate_size
=
37
,
hidden_act
=
"gelu"
,
hidden_dropout_prob
=
0.1
,
attention_probs_dropout_prob
=
0.1
,
max_position_embeddings
=
512
,
type_vocab_size
=
16
,
type_sequence_label_size
=
2
,
initializer_range
=
0.02
,
num_labels
=
3
,
num_choices
=
4
,
scope
=
None
,
):
self
.
parent
=
parent
self
.
batch_size
=
batch_size
self
.
seq_length
=
seq_length
self
.
is_training
=
is_training
self
.
use_input_mask
=
use_input_mask
self
.
use_token_type_ids
=
use_token_type_ids
self
.
use_labels
=
use_labels
self
.
vocab_size
=
vocab_size
self
.
hidden_size
=
hidden_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
self
.
intermediate_size
=
intermediate_size
self
.
hidden_act
=
hidden_act
self
.
hidden_dropout_prob
=
hidden_dropout_prob
self
.
attention_probs_dropout_prob
=
attention_probs_dropout_prob
self
.
max_position_embeddings
=
max_position_embeddings
self
.
type_vocab_size
=
type_vocab_size
self
.
type_sequence_label_size
=
type_sequence_label_size
self
.
initializer_range
=
initializer_range
self
.
num_labels
=
num_labels
self
.
num_choices
=
num_choices
self
.
scope
=
scope
def
prepare_config_and_inputs
(
self
):
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_mask
=
None
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
vocab_size
=
2
)
sequence_labels
=
None
token_labels
=
None
choice_labels
=
None
if
self
.
use_labels
:
sequence_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
type_sequence_label_size
)
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choice_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
config
=
DistilBertConfig
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
dim
=
self
.
hidden_size
,
n_layers
=
self
.
num_hidden_layers
,
n_heads
=
self
.
num_attention_heads
,
hidden_dim
=
self
.
intermediate_size
,
hidden_act
=
self
.
hidden_act
,
dropout
=
self
.
hidden_dropout_prob
,
attention_dropout
=
self
.
attention_probs_dropout_prob
,
max_position_embeddings
=
self
.
max_position_embeddings
,
initializer_range
=
self
.
initializer_range
)
return
config
,
input_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
create_and_check_distilbert_model
(
self
,
config
,
input_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
TFDistilBertModel
(
config
=
config
)
inputs
=
{
'input_ids'
:
input_ids
,
'attention_mask'
:
input_mask
}
outputs
=
model
(
inputs
)
sequence_output
=
outputs
[
0
]
inputs
=
[
input_ids
,
input_mask
]
(
sequence_output
,)
=
model
(
inputs
)
result
=
{
"sequence_output"
:
sequence_output
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"sequence_output"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
def
create_and_check_distilbert_for_masked_lm
(
self
,
config
,
input_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
TFDistilBertForMaskedLM
(
config
=
config
)
inputs
=
{
'input_ids'
:
input_ids
,
'attention_mask'
:
input_mask
}
(
prediction_scores
,)
=
model
(
inputs
)
result
=
{
"prediction_scores"
:
prediction_scores
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"prediction_scores"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
def
create_and_check_distilbert_for_question_answering
(
self
,
config
,
input_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
TFDistilBertForQuestionAnswering
(
config
=
config
)
inputs
=
{
'input_ids'
:
input_ids
,
'attention_mask'
:
input_mask
}
start_logits
,
end_logits
=
model
(
inputs
)
result
=
{
"start_logits"
:
start_logits
.
numpy
(),
"end_logits"
:
end_logits
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
])
def
create_and_check_distilbert_for_sequence_classification
(
self
,
config
,
input_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
config
.
num_labels
=
self
.
num_labels
model
=
TFDistilBertForSequenceClassification
(
config
)
inputs
=
{
'input_ids'
:
input_ids
,
'attention_mask'
:
input_mask
}
(
logits
,)
=
model
(
inputs
)
result
=
{
"logits"
:
logits
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
shape
),
[
self
.
batch_size
,
self
.
num_labels
])
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids
,
'attention_mask'
:
input_mask
}
return
config
,
inputs_dict
def
setUp
(
self
):
self
.
model_tester
=
TFDistilBertModelTest
.
TFDistilBertModelTester
(
self
)
self
.
config_tester
=
ConfigTester
(
self
,
config_class
=
DistilBertConfig
,
dim
=
37
)
def
test_config
(
self
):
self
.
config_tester
.
run_common_tests
()
def
test_distilbert_model
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_distilbert_model
(
*
config_and_inputs
)
def
test_for_masked_lm
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_distilbert_for_masked_lm
(
*
config_and_inputs
)
def
test_for_question_answering
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_distilbert_for_question_answering
(
*
config_and_inputs
)
def
test_for_sequence_classification
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_distilbert_for_sequence_classification
(
*
config_and_inputs
)
# @pytest.mark.slow
# def test_model_from_pretrained(self):
# cache_dir = "/tmp/pytorch_transformers_test/"
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
# shutil.rmtree(cache_dir)
# self.assertIsNotNone(model)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment