Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6c134444
Unverified
Commit
6c134444
authored
Jun 20, 2023
by
Sanchit Gandhi
Committed by
GitHub
Jun 20, 2023
Browse files
[Whisper] Make tests faster (#24105)
parent
f924df3c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
157 additions
and
4 deletions
+157
-4
tests/models/whisper/test_modeling_whisper.py
tests/models/whisper/test_modeling_whisper.py
+157
-4
No files found.
tests/models/whisper/test_modeling_whisper.py
View file @
6c134444
...
...
@@ -95,7 +95,7 @@ class WhisperModelTester:
self
,
parent
,
batch_size
=
2
,
seq_length
=
150
0
,
seq_length
=
6
0
,
is_training
=
True
,
use_labels
=
False
,
vocab_size
=
200
,
...
...
@@ -107,7 +107,7 @@ class WhisperModelTester:
hidden_dropout_prob
=
0.1
,
attention_probs_dropout_prob
=
0.1
,
max_position_embeddings
=
20
,
max_source_positions
=
75
0
,
max_source_positions
=
3
0
,
max_target_positions
=
40
,
bos_token_id
=
98
,
eos_token_id
=
98
,
...
...
@@ -1538,7 +1538,7 @@ class WhisperEncoderModelTester:
self
,
parent
,
batch_size
=
2
,
seq_length
=
300
0
,
seq_length
=
6
0
,
is_training
=
True
,
use_labels
=
True
,
hidden_size
=
16
,
...
...
@@ -1549,7 +1549,7 @@ class WhisperEncoderModelTester:
hidden_dropout_prob
=
0.1
,
attention_probs_dropout_prob
=
0.1
,
max_position_embeddings
=
20
,
max_source_positions
=
150
0
,
max_source_positions
=
3
0
,
num_mel_bins
=
80
,
num_conv_layers
=
1
,
suppress_tokens
=
None
,
...
...
@@ -1731,3 +1731,156 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
# WhisperEncoder cannot resize token embeddings since it has no tokens embeddings
def
test_resize_tokens_embeddings
(
self
):
pass
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
    """Check that a Flax model loaded with PyTorch weights matches the PyTorch model.

    For each class in ``self.all_model_classes`` that has a ``Flax``-prefixed
    counterpart in ``transformers``, the PyTorch weights are transferred to the
    Flax model in two ways and the outputs are compared each time with
    ``self.check_pt_flax_outputs``:

    1. in memory, through ``convert_pytorch_state_dict_to_flax``;
    2. through disk, via ``pt_model.save_pretrained`` followed by
       ``fx_model_class.from_pretrained(..., from_pt=True)``.

    Classes without a Flax counterpart are skipped individually; the remaining
    classes are still checked.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    # Flax init shape: leading dimension of 1, trailing dimensions taken from the prepared features.
    init_shape = (1,) + inputs_dict["input_features"].shape[1:]

    for model_class in self.all_model_classes:
        with self.subTest(model_class.__name__):
            fx_model_class_name = "Flax" + model_class.__name__

            if not hasattr(transformers, fx_model_class_name):
                # No flax model exists for this class. Skip only this class:
                # a `return` here would silently drop every class that follows.
                continue

            # Output all for aggressive testing
            config.output_hidden_states = True
            config.output_attentions = self.has_attentions

            fx_model_class = getattr(transformers, fx_model_class_name)

            # load PyTorch class
            pt_model = model_class(config).eval()
            # Flax models don't use the `use_cache` option and cache is not returned as a default.
            # So we disable `use_cache` here for PyTorch model.
            pt_model.config.use_cache = False

            # load Flax class
            fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)

            # make sure only flax inputs are forwarded that actually exist in function args
            fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

            # prepare inputs
            pt_inputs = self._prepare_for_class(inputs_dict, model_class)

            # remove function args that don't exist in Flax
            pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

            # send pytorch inputs to the correct device
            pt_inputs = {
                k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
            }

            # convert inputs to Flax (non-tensor entries are dropped)
            fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}

            # in-memory transfer of the PyTorch weights into the Flax model
            fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
            fx_model.params = fx_state

            # send pytorch model to the correct device
            pt_model.to(torch_device)

            with torch.no_grad():
                pt_outputs = pt_model(**pt_inputs)
            fx_outputs = fx_model(**fx_inputs)

            # both frameworks must return the same non-None output fields, in the same order
            fx_keys = tuple(k for k, v in fx_outputs.items() if v is not None)
            pt_keys = tuple(k for k, v in pt_outputs.items() if v is not None)

            self.assertEqual(fx_keys, pt_keys)
            self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)

            # round trip through disk: save the PyTorch checkpoint, reload it as a Flax model
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)

            fx_outputs_loaded = fx_model_loaded(**fx_inputs)

            # `pt_outputs` is unchanged, so `pt_keys` computed above is still the reference
            fx_keys = tuple(k for k, v in fx_outputs_loaded.items() if v is not None)

            self.assertEqual(fx_keys, pt_keys)
            self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
    """Check that a PyTorch model loaded with Flax weights matches the Flax model.

    For each class in ``self.all_model_classes`` that has a ``Flax``-prefixed
    counterpart in ``transformers``, the Flax weights are transferred to the
    PyTorch model in two ways and the outputs are compared each time with
    ``self.check_pt_flax_outputs``:

    1. in memory, through ``load_flax_weights_in_pytorch_model``;
    2. through disk, via ``fx_model.save_pretrained`` followed by
       ``model_class.from_pretrained(..., from_flax=True)``.

    Classes without a Flax counterpart are skipped individually; the remaining
    classes are still checked.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    # Flax init shape: leading dimension of 1, trailing dimensions taken from the prepared features.
    init_shape = (1,) + inputs_dict["input_features"].shape[1:]

    for model_class in self.all_model_classes:
        with self.subTest(model_class.__name__):
            fx_model_class_name = "Flax" + model_class.__name__

            if not hasattr(transformers, fx_model_class_name):
                # No flax model exists for this class. Skip only this class:
                # a `return` here would silently drop every class that follows.
                continue

            # Output all for aggressive testing
            config.output_hidden_states = True
            config.output_attentions = self.has_attentions

            fx_model_class = getattr(transformers, fx_model_class_name)

            # load PyTorch class
            pt_model = model_class(config).eval()
            # Flax models don't use the `use_cache` option and cache is not returned as a default.
            # So we disable `use_cache` here for PyTorch model.
            pt_model.config.use_cache = False

            # load Flax class
            fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)

            # make sure only flax inputs are forwarded that actually exist in function args
            fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

            # prepare inputs
            pt_inputs = self._prepare_for_class(inputs_dict, model_class)

            # remove function args that don't exist in Flax
            pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

            # send pytorch inputs to the correct device
            pt_inputs = {
                k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
            }

            # convert inputs to Flax (non-tensor entries are dropped)
            fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}

            # in-memory transfer of the Flax parameters into the PyTorch model
            pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

            # make sure weights are tied in PyTorch
            pt_model.tie_weights()

            # send pytorch model to the correct device
            pt_model.to(torch_device)

            with torch.no_grad():
                pt_outputs = pt_model(**pt_inputs)
            fx_outputs = fx_model(**fx_inputs)

            # both frameworks must return the same non-None output fields, in the same order
            fx_keys = tuple(k for k, v in fx_outputs.items() if v is not None)
            pt_keys = tuple(k for k, v in pt_outputs.items() if v is not None)

            self.assertEqual(fx_keys, pt_keys)
            self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)

            # round trip through disk: save the Flax checkpoint, reload it as a PyTorch model
            with tempfile.TemporaryDirectory() as tmpdirname:
                fx_model.save_pretrained(tmpdirname)
                pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)

            # send pytorch model to the correct device
            pt_model_loaded.to(torch_device)
            pt_model_loaded.eval()

            with torch.no_grad():
                pt_outputs_loaded = pt_model_loaded(**pt_inputs)

            # `fx_outputs` is unchanged, so `fx_keys` computed above is still the reference
            pt_keys = tuple(k for k, v in pt_outputs_loaded.items() if v is not None)

            self.assertEqual(fx_keys, pt_keys)
            self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment