Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
505f2d74
Unverified
Commit
505f2d74
authored
Aug 20, 2020
by
Patrick von Platen
Committed by
GitHub
Aug 20, 2020
Browse files
[Tests] fix attention masks in Tests (#6621)
* fix distilbert * fix typo
parent
c9454507
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
35 additions
and
31 deletions
+35
-31
tests/test_modeling_albert.py
tests/test_modeling_albert.py
+2
-2
tests/test_modeling_bert.py
tests/test_modeling_bert.py
+2
-2
tests/test_modeling_common.py
tests/test_modeling_common.py
+7
-3
tests/test_modeling_ctrl.py
tests/test_modeling_ctrl.py
+2
-2
tests/test_modeling_distilbert.py
tests/test_modeling_distilbert.py
+2
-2
tests/test_modeling_dpr.py
tests/test_modeling_dpr.py
+2
-2
tests/test_modeling_electra.py
tests/test_modeling_electra.py
+2
-2
tests/test_modeling_flaubert.py
tests/test_modeling_flaubert.py
+2
-2
tests/test_modeling_gpt2.py
tests/test_modeling_gpt2.py
+2
-2
tests/test_modeling_longformer.py
tests/test_modeling_longformer.py
+2
-2
tests/test_modeling_mobilebert.py
tests/test_modeling_mobilebert.py
+2
-2
tests/test_modeling_reformer.py
tests/test_modeling_reformer.py
+2
-2
tests/test_modeling_roberta.py
tests/test_modeling_roberta.py
+2
-2
tests/test_modeling_xlm.py
tests/test_modeling_xlm.py
+2
-2
tests/test_modeling_xlnet.py
tests/test_modeling_xlnet.py
+2
-2
No files found.
tests/test_modeling_albert.py
View file @
505f2d74
...
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
...
...
@@ -71,7 +71,7 @@ class AlbertModelTester:
input_mask
=
None
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_bert.py
View file @
505f2d74
...
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
...
...
@@ -93,7 +93,7 @@ class BertModelTester:
input_mask
=
None
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_common.py
View file @
505f2d74
...
...
@@ -704,9 +704,6 @@ class ModelTesterMixin:
recursive_check
(
tuple_iterable_value
,
dict_iterable_value
)
elif
tuple_object
is
None
:
return
elif
torch
.
isinf
(
tuple_object
).
any
()
and
torch
.
isinf
(
dict_object
).
any
():
# TODO: (Lysandre) - maybe take a look if that's ok here
return
else
:
self
.
assertTrue
(
torch
.
allclose
(
tuple_object
,
dict_object
,
atol
=
1e-5
),
...
...
@@ -937,6 +934,13 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
return
torch
.
tensor
(
data
=
values
,
dtype
=
torch
.
long
,
device
=
torch_device
).
view
(
shape
).
contiguous
()
def
random_attention_mask
(
shape
,
rng
=
None
,
name
=
None
):
attn_mask
=
ids_tensor
(
shape
,
vocab_size
=
2
,
rng
=
None
,
name
=
None
)
# make sure that at least one token is attended to for each batch
attn_mask
[:,
-
1
]
=
1
return
attn_mask
def
floats_tensor
(
shape
,
scale
=
1.0
,
rng
=
None
,
name
=
None
):
"""Creates a random float32 tensor"""
if
rng
is
None
:
...
...
tests/test_modeling_ctrl.py
View file @
505f2d74
...
...
@@ -19,7 +19,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
...
...
@@ -60,7 +60,7 @@ class CTRLModelTester:
input_mask
=
None
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_distilbert.py
View file @
505f2d74
...
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
...
...
@@ -89,7 +89,7 @@ if is_torch_available():
input_mask
=
None
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
sequence_labels
=
None
token_labels
=
None
...
...
tests/test_modeling_dpr.py
View file @
505f2d74
...
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
...
...
@@ -88,7 +88,7 @@ class DPRModelTester:
input_mask
=
None
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_electra.py
View file @
505f2d74
...
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
...
...
@@ -69,7 +69,7 @@ class ElectraModelTester:
input_mask
=
None
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_flaubert.py
View file @
505f2d74
...
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
...
...
@@ -72,7 +72,7 @@ class FlaubertModelTester(object):
def
prepare_config_and_inputs
(
self
):
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
2
).
float
(
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
input_lengths
=
None
if
self
.
use_input_lengths
:
...
...
tests/test_modeling_gpt2.py
View file @
505f2d74
...
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
...
...
@@ -92,7 +92,7 @@ class GPT2ModelTester:
input_mask
=
None
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_longformer.py
View file @
505f2d74
...
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
...
...
@@ -82,7 +82,7 @@ class LongformerModelTester:
input_mask
=
None
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_mobilebert.py
View file @
505f2d74
...
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
...
...
@@ -94,7 +94,7 @@ class MobileBertModelTester:
input_mask
=
None
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_reformer.py
View file @
505f2d74
...
...
@@ -19,7 +19,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_multigpu
,
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
...
...
@@ -133,7 +133,7 @@ class ReformerModelTester:
input_mask
=
None
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
choice_labels
=
None
if
self
.
use_labels
:
...
...
tests/test_modeling_roberta.py
View file @
505f2d74
...
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
...
...
@@ -71,7 +71,7 @@ class RobertaModelTester:
input_mask
=
None
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_xlm.py
View file @
505f2d74
...
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
...
...
@@ -73,7 +73,7 @@ class XLMModelTester:
def
prepare_config_and_inputs
(
self
):
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
2
).
float
(
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
input_lengths
=
None
if
self
.
use_input_lengths
:
...
...
tests/test_modeling_xlnet.py
View file @
505f2d74
...
...
@@ -21,7 +21,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
...
...
@@ -100,7 +100,7 @@ class XLNetModelTester:
input_ids_1
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids_2
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
segment_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
type_vocab_size
)
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
2
).
float
(
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
input_ids_q
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
+
1
],
self
.
vocab_size
)
perm_mask
=
torch
.
zeros
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment