chenpangpang / transformers · Commit 8d4bb020 (unverified)

Refactor FLAX tests (#9034)

Authored Dec 10, 2020 by Sylvain Gugger, committed by GitHub on Dec 10, 2020
Parent: 1310e1a7
Changes: 4 files
Showing 4 changed files, with 298 additions and 114 deletions:

  src/transformers/tokenization_utils_base.py   +1    -0
  tests/test_modeling_flax_bert.py              +85   -57
  tests/test_modeling_flax_common.py            +127  -0
  tests/test_modeling_flax_roberta.py           +85   -57
src/transformers/tokenization_utils_base.py  (+1, -0)

@@ -50,6 +50,7 @@ if is_tf_available():
...
if is_torch_available():
    import torch

if is_flax_available():
    import jax.numpy as jnp
...
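These guarded imports matter because the tokenizers can return JAX arrays when flax/jax are installed, which is exactly what the tests changed in this commit rely on. A minimal sketch of that usage, not part of this hunk, assuming flax/jax are installed and the bert-base-cased checkpoint can be downloaded (the call mirrors the ones in the test files below):

from transformers import BertTokenizerFast, TensorType

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
# return_tensors=TensorType.JAX converts the encoded ids to jax.numpy arrays,
# which is why tokenization_utils_base.py needs the is_flax_available() guard above.
encodings = tokenizer("This is a simple input", return_tensors=TensorType.JAX)
print(type(encodings["input_ids"]))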
tests/test_modeling_flax_bert.py  (+85, -57)

@@ -14,70 +14,98 @@
 import unittest

-from numpy import ndarray
-
-from transformers import BertTokenizerFast, TensorType, is_flax_available, is_torch_available
-from transformers.testing_utils import require_flax, require_torch
+from transformers import BertConfig, is_flax_available
+from transformers.testing_utils import require_flax
+
+from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask


 if is_flax_available():
-    import os
-
-    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8
-
-    import jax
     from transformers.models.bert.modeling_flax_bert import FlaxBertModel

-if is_torch_available():
-    import torch
-    from transformers.models.bert.modeling_bert import BertModel
+
+class FlaxBertModelTester(unittest.TestCase):
+    def __init__(
+        self,
+        parent,
+        batch_size=13,
+        seq_length=7,
+        is_training=True,
+        use_attention_mask=True,
+        use_token_type_ids=True,
+        use_labels=True,
+        vocab_size=99,
+        hidden_size=32,
+        num_hidden_layers=5,
+        num_attention_heads=4,
+        intermediate_size=37,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=16,
+        type_sequence_label_size=2,
+        initializer_range=0.02,
+    ):
+        self.parent = parent
+        self.batch_size = batch_size
+        self.seq_length = seq_length
+        self.is_training = is_training
+        self.use_attention_mask = use_attention_mask
+        self.use_token_type_ids = use_token_type_ids
+        self.use_labels = use_labels
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.type_sequence_label_size = type_sequence_label_size
+        self.initializer_range = initializer_range
+
+    def prepare_config_and_inputs(self):
+        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+        attention_mask = None
+        if self.use_attention_mask:
+            attention_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+        token_type_ids = None
+        if self.use_token_type_ids:
+            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+
+        config = BertConfig(
+            vocab_size=self.vocab_size,
+            hidden_size=self.hidden_size,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            hidden_act=self.hidden_act,
+            hidden_dropout_prob=self.hidden_dropout_prob,
+            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+            max_position_embeddings=self.max_position_embeddings,
+            type_vocab_size=self.type_vocab_size,
+            is_decoder=False,
+            initializer_range=self.initializer_range,
+        )
+
+        return config, input_ids, token_type_ids, attention_mask
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, input_ids, token_type_ids, attention_mask = config_and_inputs
+        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
+        return config, inputs_dict


 @require_flax
-@require_torch
-class FlaxBertModelTest(unittest.TestCase):
-    def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
-        diff = (a - b).sum()
-        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol})")
-
-    def test_from_pytorch(self):
-        with torch.no_grad():
-            with self.subTest("bert-base-cased"):
-                tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
-                fx_model = FlaxBertModel.from_pretrained("bert-base-cased")
-                pt_model = BertModel.from_pretrained("bert-base-cased")
-
-                # Check for simple input
-                pt_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.PYTORCH)
-                fx_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.JAX)
-                pt_outputs = pt_model(**pt_inputs).to_tuple()
-                fx_outputs = fx_model(**fx_inputs)
-
-                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
-
-                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
-
-    def test_multiple_sequences(self):
-        tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
-        model = FlaxBertModel.from_pretrained("bert-base-cased")
-
-        sequences = ["this is an example sentence", "this is another", "and a third one"]
-        encodings = tokenizer(sequences, return_tensors=TensorType.JAX, padding=True, truncation=True)
-
-        @jax.jit
-        def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
-            return model(input_ids, attention_mask, token_type_ids)
-
-        with self.subTest("JIT Disabled"):
-            with jax.disable_jit():
-                tokens, pooled = model_jitted(**encodings)
-
-                self.assertEqual(tokens.shape, (3, 7, 768))
-                self.assertEqual(pooled.shape, (3, 768))
-
-        with self.subTest("JIT Enabled"):
-            jitted_tokens, jitted_pooled = model_jitted(**encodings)
-
-            self.assertEqual(jitted_tokens.shape, (3, 7, 768))
-            self.assertEqual(jitted_pooled.shape, (3, 768))
+class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (FlaxBertModel,) if is_flax_available() else ()
+
+    def setUp(self):
+        self.model_tester = FlaxBertModelTester(self)
tests/test_modeling_flax_common.py  (new file, mode 100644, +127)

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np

import transformers
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import require_flax, require_torch


if is_flax_available():
    import os

    import jax
    import jax.numpy as jnp
    from flax.traverse_util import unflatten_dict

    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8

if is_torch_available():
    import torch


def ids_tensor(shape, vocab_size, rng=None):
    """Creates a random int32 tensor of the shape within the vocab size."""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    output = np.array(values, dtype=jnp.int32).reshape(shape)

    return output


def random_attention_mask(shape, rng=None):
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
    # make sure that at least one token is attended to for each batch
    attn_mask[:, -1] = 1
    return attn_mask


def convert_pt_model_to_flax(pt_model, config, flax_model_cls):
    state = pt_model.state_dict()
    state = {k: v.numpy() for k, v in state.items()}
    state = flax_model_cls.convert_from_pytorch(state, config)
    state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
    return flax_model_cls(config, state, dtype=jnp.float32)


@require_flax
class FlaxModelTesterMixin:
    model_tester = None
    all_model_classes = ()

    def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
        diff = np.abs((a - b)).sum()
        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")

    @require_torch
    def test_equivalence_flax_pytorch(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                pt_model = pt_model_class(config).eval()
                fx_model = convert_pt_model_to_flax(pt_model, config, model_class)

                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                fx_outputs = fx_model(**inputs_dict)
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)

    @require_torch
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # TODO later: have some way to initialize easily a Flax model from config, for now I go through PT
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                pt_model = pt_model_class(config).eval()
                model = convert_pt_model_to_flax(pt_model, config, model_class)

                @jax.jit
                def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
                    return model(input_ids, attention_mask, token_type_ids)

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**inputs_dict)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**inputs_dict)

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)
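A small usage sketch, not part of the commit, showing what the two input helpers above produce. It assumes jax is installed and that the helpers are importable (the import path is an assumption, e.g. a REPL started inside the tests directory):

from test_modeling_flax_common import ids_tensor, random_attention_mask

input_ids = ids_tensor([2, 5], vocab_size=99)   # int32 array of shape (2, 5), values in [0, 98]
attention_mask = random_attention_mask([2, 5])  # 0/1 mask whose last column is forced to 1
print(input_ids.shape, attention_mask[:, -1])   # -> (2, 5) [1 1]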
tests/test_modeling_flax_roberta.py  (+85, -57)

@@ -14,70 +14,98 @@
 import unittest

-from numpy import ndarray
-
-from transformers import RobertaTokenizerFast, TensorType, is_flax_available, is_torch_available
-from transformers.testing_utils import require_flax, require_torch
+from transformers import RobertaConfig, is_flax_available
+from transformers.testing_utils import require_flax
+
+from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask


 if is_flax_available():
-    import os
-
-    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8
-
-    import jax
     from transformers.models.roberta.modeling_flax_roberta import FlaxRobertaModel

-if is_torch_available():
-    import torch
-    from transformers.models.roberta.modeling_roberta import RobertaModel
+
+class FlaxRobertaModelTester(unittest.TestCase):
+    def __init__(
+        self,
+        parent,
+        batch_size=13,
+        seq_length=7,
+        is_training=True,
+        use_attention_mask=True,
+        use_token_type_ids=True,
+        use_labels=True,
+        vocab_size=99,
+        hidden_size=32,
+        num_hidden_layers=5,
+        num_attention_heads=4,
+        intermediate_size=37,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=16,
+        type_sequence_label_size=2,
+        initializer_range=0.02,
+    ):
+        self.parent = parent
+        self.batch_size = batch_size
+        self.seq_length = seq_length
+        self.is_training = is_training
+        self.use_attention_mask = use_attention_mask
+        self.use_token_type_ids = use_token_type_ids
+        self.use_labels = use_labels
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.type_sequence_label_size = type_sequence_label_size
+        self.initializer_range = initializer_range
+
+    def prepare_config_and_inputs(self):
+        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+        attention_mask = None
+        if self.use_attention_mask:
+            attention_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+        token_type_ids = None
+        if self.use_token_type_ids:
+            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+
+        config = RobertaConfig(
+            vocab_size=self.vocab_size,
+            hidden_size=self.hidden_size,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            hidden_act=self.hidden_act,
+            hidden_dropout_prob=self.hidden_dropout_prob,
+            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+            max_position_embeddings=self.max_position_embeddings,
+            type_vocab_size=self.type_vocab_size,
+            is_decoder=False,
+            initializer_range=self.initializer_range,
+        )
+
+        return config, input_ids, token_type_ids, attention_mask
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, input_ids, token_type_ids, attention_mask = config_and_inputs
+        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
+        return config, inputs_dict


 @require_flax
-@require_torch
-class FlaxRobertaModelTest(unittest.TestCase):
-    def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
-        diff = (a - b).sum()
-        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol})")
-
-    def test_from_pytorch(self):
-        with torch.no_grad():
-            with self.subTest("roberta-base"):
-                tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
-                fx_model = FlaxRobertaModel.from_pretrained("roberta-base")
-                pt_model = RobertaModel.from_pretrained("roberta-base")
-
-                # Check for simple input
-                pt_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.PYTORCH)
-                fx_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.JAX)
-                pt_outputs = pt_model(**pt_inputs)
-                fx_outputs = fx_model(**fx_inputs)
-
-                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
-
-                for fx_output, pt_output in zip(fx_outputs, pt_outputs.to_tuple()):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
-
-    def test_multiple_sequences(self):
-        tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
-        model = FlaxRobertaModel.from_pretrained("roberta-base")
-
-        sequences = ["this is an example sentence", "this is another", "and a third one"]
-        encodings = tokenizer(sequences, return_tensors=TensorType.JAX, padding=True, truncation=True)
-
-        @jax.jit
-        def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
-            return model(input_ids, attention_mask, token_type_ids)
-
-        with self.subTest("JIT Disabled"):
-            with jax.disable_jit():
-                tokens, pooled = model_jitted(**encodings)
-
-                self.assertEqual(tokens.shape, (3, 7, 768))
-                self.assertEqual(pooled.shape, (3, 768))
-
-        with self.subTest("JIT Enabled"):
-            jitted_tokens, jitted_pooled = model_jitted(**encodings)
-
-            self.assertEqual(jitted_tokens.shape, (3, 7, 768))
-            self.assertEqual(jitted_pooled.shape, (3, 768))
+class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (FlaxRobertaModel,) if is_flax_available() else ()
+
+    def setUp(self):
+        self.model_tester = FlaxRobertaModelTester(self)
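A hedged note on trying the refactored tests locally (the invocation is an assumption, not part of the commit): with jax, flax, and torch installed, the new mixin-based tests can be collected with pytest, for example from Python:

import pytest

# Runs the two refactored Flax test modules; "-v" lists each test as it runs.
# Paths assume the working directory is the repository root.
exit_code = pytest.main(["tests/test_modeling_flax_bert.py", "tests/test_modeling_flax_roberta.py", "-v"])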