Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c9bce181
Commit
c9bce181
authored
Aug 28, 2019
by
thomwolf
Browse files
fixing model to add torchscript, embedding resizing, head pruning and masking + tests
parent
62df4ba5
Changes
3
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
253 additions
and
138 deletions
+253
-138
pytorch_transformers/modeling_bert.py
pytorch_transformers/modeling_bert.py
+1
-1
pytorch_transformers/modeling_dilbert.py
pytorch_transformers/modeling_dilbert.py
+244
-127
pytorch_transformers/tests/modeling_dilbert_test.py
pytorch_transformers/tests/modeling_dilbert_test.py
+8
-10
No files found.
pytorch_transformers/modeling_bert.py
View file @
c9bce181
...
...
@@ -449,7 +449,7 @@ class BertEncoder(nn.Module):
outputs
=
outputs
+
(
all_hidden_states
,)
if
self
.
output_attentions
:
outputs
=
outputs
+
(
all_attentions
,)
return
outputs
#
outputs, (
hidden states), (attentions)
return
outputs
#
last-layer hidden state, (all
hidden states), (
all
attentions)
class
BertPooler
(
nn
.
Module
):
...
...
pytorch_transformers/modeling_dilbert.py
View file @
c9bce181
This diff is collapsed.
Click to expand it.
pytorch_transformers/tests/modeling_dilbert_test.py
View file @
c9bce181
...
...
@@ -31,10 +31,10 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
all_model_classes
=
(
DilBertModel
,
DilBertForMaskedLM
,
DilBertForQuestionAnswering
,
DilBertForSequenceClassification
)
test_pruning
=
Fals
e
test_torchscript
=
Fals
e
test_resize_embeddings
=
Fals
e
test_head_masking
=
Fals
e
test_pruning
=
Tru
e
test_torchscript
=
Tru
e
test_resize_embeddings
=
Tru
e
test_head_masking
=
Tru
e
class
DilBertModelTester
(
object
):
...
...
@@ -122,22 +122,20 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
def
create_and_check_dilbert_model
(
self
,
config
,
input_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
DilBertModel
(
config
=
config
)
model
.
eval
()
sequence_output
,
pooled_output
=
model
(
input_ids
,
input_mask
)
sequence_output
,
pooled_output
=
model
(
input_ids
)
(
sequence_output
,
)
=
model
(
input_ids
,
input_mask
)
(
sequence_output
,
)
=
model
(
input_ids
)
result
=
{
"sequence_output"
:
sequence_output
,
"pooled_output"
:
pooled_output
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"sequence_output"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"pooled_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
def
create_and_check_dilbert_for_masked_lm
(
self
,
config
,
input_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
DilBertForMaskedLM
(
config
=
config
)
model
.
eval
()
loss
,
prediction_scores
=
model
(
input_ids
,
input_mask
,
token_labels
)
loss
,
prediction_scores
=
model
(
input_ids
,
attention_mask
=
input_mask
,
masked_lm_labels
=
token_labels
)
result
=
{
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment