Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
505f2d74
Unverified
Commit
505f2d74
authored
Aug 20, 2020
by
Patrick von Platen
Committed by
GitHub
Aug 20, 2020
Browse files
[Tests] fix attention masks in Tests (#6621)
* fix distilbert * fix typo
parent
c9454507
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
35 additions
and
31 deletions
+35
-31
tests/test_modeling_albert.py
tests/test_modeling_albert.py
+2
-2
tests/test_modeling_bert.py
tests/test_modeling_bert.py
+2
-2
tests/test_modeling_common.py
tests/test_modeling_common.py
+7
-3
tests/test_modeling_ctrl.py
tests/test_modeling_ctrl.py
+2
-2
tests/test_modeling_distilbert.py
tests/test_modeling_distilbert.py
+2
-2
tests/test_modeling_dpr.py
tests/test_modeling_dpr.py
+2
-2
tests/test_modeling_electra.py
tests/test_modeling_electra.py
+2
-2
tests/test_modeling_flaubert.py
tests/test_modeling_flaubert.py
+2
-2
tests/test_modeling_gpt2.py
tests/test_modeling_gpt2.py
+2
-2
tests/test_modeling_longformer.py
tests/test_modeling_longformer.py
+2
-2
tests/test_modeling_mobilebert.py
tests/test_modeling_mobilebert.py
+2
-2
tests/test_modeling_reformer.py
tests/test_modeling_reformer.py
+2
-2
tests/test_modeling_roberta.py
tests/test_modeling_roberta.py
+2
-2
tests/test_modeling_xlm.py
tests/test_modeling_xlm.py
+2
-2
tests/test_modeling_xlnet.py
tests/test_modeling_xlnet.py
+2
-2
No files found.
tests/test_modeling_albert.py
View file @
505f2d74
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
if
is_torch_available
():
...
@@ -71,7 +71,7 @@ class AlbertModelTester:
...
@@ -71,7 +71,7 @@ class AlbertModelTester:
input_mask
=
None
input_mask
=
None
if
self
.
use_input_mask
:
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
token_type_ids
=
None
if
self
.
use_token_type_ids
:
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_bert.py
View file @
505f2d74
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
if
is_torch_available
():
...
@@ -93,7 +93,7 @@ class BertModelTester:
...
@@ -93,7 +93,7 @@ class BertModelTester:
input_mask
=
None
input_mask
=
None
if
self
.
use_input_mask
:
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
token_type_ids
=
None
if
self
.
use_token_type_ids
:
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_common.py
View file @
505f2d74
...
@@ -704,9 +704,6 @@ class ModelTesterMixin:
...
@@ -704,9 +704,6 @@ class ModelTesterMixin:
recursive_check
(
tuple_iterable_value
,
dict_iterable_value
)
recursive_check
(
tuple_iterable_value
,
dict_iterable_value
)
elif
tuple_object
is
None
:
elif
tuple_object
is
None
:
return
return
elif
torch
.
isinf
(
tuple_object
).
any
()
and
torch
.
isinf
(
dict_object
).
any
():
# TODO: (Lysandre) - maybe take a look if that's ok here
return
else
:
else
:
self
.
assertTrue
(
self
.
assertTrue
(
torch
.
allclose
(
tuple_object
,
dict_object
,
atol
=
1e-5
),
torch
.
allclose
(
tuple_object
,
dict_object
,
atol
=
1e-5
),
...
@@ -937,6 +934,13 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
...
@@ -937,6 +934,13 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
return
torch
.
tensor
(
data
=
values
,
dtype
=
torch
.
long
,
device
=
torch_device
).
view
(
shape
).
contiguous
()
return
torch
.
tensor
(
data
=
values
,
dtype
=
torch
.
long
,
device
=
torch_device
).
view
(
shape
).
contiguous
()
def
random_attention_mask
(
shape
,
rng
=
None
,
name
=
None
):
attn_mask
=
ids_tensor
(
shape
,
vocab_size
=
2
,
rng
=
None
,
name
=
None
)
# make sure that at least one token is attended to for each batch
attn_mask
[:,
-
1
]
=
1
return
attn_mask
def
floats_tensor
(
shape
,
scale
=
1.0
,
rng
=
None
,
name
=
None
):
def
floats_tensor
(
shape
,
scale
=
1.0
,
rng
=
None
,
name
=
None
):
"""Creates a random float32 tensor"""
"""Creates a random float32 tensor"""
if
rng
is
None
:
if
rng
is
None
:
...
...
tests/test_modeling_ctrl.py
View file @
505f2d74
...
@@ -19,7 +19,7 @@ from transformers import is_torch_available
...
@@ -19,7 +19,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
if
is_torch_available
():
...
@@ -60,7 +60,7 @@ class CTRLModelTester:
...
@@ -60,7 +60,7 @@ class CTRLModelTester:
input_mask
=
None
input_mask
=
None
if
self
.
use_input_mask
:
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
token_type_ids
=
None
if
self
.
use_token_type_ids
:
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_distilbert.py
View file @
505f2d74
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
if
is_torch_available
():
...
@@ -89,7 +89,7 @@ if is_torch_available():
...
@@ -89,7 +89,7 @@ if is_torch_available():
input_mask
=
None
input_mask
=
None
if
self
.
use_input_mask
:
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
sequence_labels
=
None
sequence_labels
=
None
token_labels
=
None
token_labels
=
None
...
...
tests/test_modeling_dpr.py
View file @
505f2d74
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
if
is_torch_available
():
...
@@ -88,7 +88,7 @@ class DPRModelTester:
...
@@ -88,7 +88,7 @@ class DPRModelTester:
input_mask
=
None
input_mask
=
None
if
self
.
use_input_mask
:
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
token_type_ids
=
None
if
self
.
use_token_type_ids
:
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_electra.py
View file @
505f2d74
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
if
is_torch_available
():
...
@@ -69,7 +69,7 @@ class ElectraModelTester:
...
@@ -69,7 +69,7 @@ class ElectraModelTester:
input_mask
=
None
input_mask
=
None
if
self
.
use_input_mask
:
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
token_type_ids
=
None
if
self
.
use_token_type_ids
:
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_flaubert.py
View file @
505f2d74
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
if
is_torch_available
():
...
@@ -72,7 +72,7 @@ class FlaubertModelTester(object):
...
@@ -72,7 +72,7 @@ class FlaubertModelTester(object):
def
prepare_config_and_inputs
(
self
):
def
prepare_config_and_inputs
(
self
):
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
2
).
float
(
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
input_lengths
=
None
input_lengths
=
None
if
self
.
use_input_lengths
:
if
self
.
use_input_lengths
:
...
...
tests/test_modeling_gpt2.py
View file @
505f2d74
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
if
is_torch_available
():
...
@@ -92,7 +92,7 @@ class GPT2ModelTester:
...
@@ -92,7 +92,7 @@ class GPT2ModelTester:
input_mask
=
None
input_mask
=
None
if
self
.
use_input_mask
:
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
token_type_ids
=
None
if
self
.
use_token_type_ids
:
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_longformer.py
View file @
505f2d74
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
if
is_torch_available
():
...
@@ -82,7 +82,7 @@ class LongformerModelTester:
...
@@ -82,7 +82,7 @@ class LongformerModelTester:
input_mask
=
None
input_mask
=
None
if
self
.
use_input_mask
:
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
token_type_ids
=
None
if
self
.
use_token_type_ids
:
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_mobilebert.py
View file @
505f2d74
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
if
is_torch_available
():
...
@@ -94,7 +94,7 @@ class MobileBertModelTester:
...
@@ -94,7 +94,7 @@ class MobileBertModelTester:
input_mask
=
None
input_mask
=
None
if
self
.
use_input_mask
:
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
token_type_ids
=
None
if
self
.
use_token_type_ids
:
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_reformer.py
View file @
505f2d74
...
@@ -19,7 +19,7 @@ from transformers import is_torch_available
...
@@ -19,7 +19,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_multigpu
,
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_multigpu
,
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
if
is_torch_available
():
...
@@ -133,7 +133,7 @@ class ReformerModelTester:
...
@@ -133,7 +133,7 @@ class ReformerModelTester:
input_mask
=
None
input_mask
=
None
if
self
.
use_input_mask
:
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
choice_labels
=
None
choice_labels
=
None
if
self
.
use_labels
:
if
self
.
use_labels
:
...
...
tests/test_modeling_roberta.py
View file @
505f2d74
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
floats_tensor
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
if
is_torch_available
():
...
@@ -71,7 +71,7 @@ class RobertaModelTester:
...
@@ -71,7 +71,7 @@ class RobertaModelTester:
input_mask
=
None
input_mask
=
None
if
self
.
use_input_mask
:
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
vocab_size
=
2
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
token_type_ids
=
None
token_type_ids
=
None
if
self
.
use_token_type_ids
:
if
self
.
use_token_type_ids
:
...
...
tests/test_modeling_xlm.py
View file @
505f2d74
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
...
@@ -20,7 +20,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
if
is_torch_available
():
...
@@ -73,7 +73,7 @@ class XLMModelTester:
...
@@ -73,7 +73,7 @@ class XLMModelTester:
def
prepare_config_and_inputs
(
self
):
def
prepare_config_and_inputs
(
self
):
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
2
).
float
(
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
input_lengths
=
None
input_lengths
=
None
if
self
.
use_input_lengths
:
if
self
.
use_input_lengths
:
...
...
tests/test_modeling_xlnet.py
View file @
505f2d74
...
@@ -21,7 +21,7 @@ from transformers import is_torch_available
...
@@ -21,7 +21,7 @@ from transformers import is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
if
is_torch_available
():
if
is_torch_available
():
...
@@ -100,7 +100,7 @@ class XLNetModelTester:
...
@@ -100,7 +100,7 @@ class XLNetModelTester:
input_ids_1
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids_1
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids_2
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids_2
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
segment_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
type_vocab_size
)
segment_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
type_vocab_size
)
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
]
,
2
).
float
(
)
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
input_ids_q
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
+
1
],
self
.
vocab_size
)
input_ids_q
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
+
1
],
self
.
vocab_size
)
perm_mask
=
torch
.
zeros
(
perm_mask
=
torch
.
zeros
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment