[Whisper] Make tests faster (#24105)

Unverified commit 6c134444 in chenpangpang/transformers, authored Jun 20, 2023 by Sanchit Gandhi, committed by GitHub on Jun 20, 2023.
Parent: f924df3c
Showing 1 changed file with 157 additions and 4 deletions.
tests/models/whisper/test_modeling_whisper.py  (+157 −4)
@@ -95,7 +95,7 @@ class WhisperModelTester:
         self,
         parent,
         batch_size=2,
-        seq_length=1500,
+        seq_length=60,
         is_training=True,
         use_labels=False,
         vocab_size=200,
@@ -107,7 +107,7 @@ class WhisperModelTester:
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
         max_position_embeddings=20,
-        max_source_positions=750,
+        max_source_positions=30,
         max_target_positions=40,
         bos_token_id=98,
         eos_token_id=98,
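Note (illustration, not part of the diff): the two reductions above go together. Whisper's encoder halves the mel-frame length with its stride-2 convolution and then expects the result to match max_source_positions, so the tester's seq_length has to stay at 2 * max_source_positions (1500/750 before this commit, 60/30 after). Below is a minimal sketch of that constraint using the standard transformers API; the layer counts, head counts, and token ids are illustrative assumptions, not values taken from the commit.

import torch
from transformers import WhisperConfig, WhisperModel

# Tiny config mirroring the tester values touched by this commit; everything
# else here is an illustrative assumption.
config = WhisperConfig(
    vocab_size=200,
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=4,
    decoder_attention_heads=4,
    num_mel_bins=80,
    max_source_positions=30,   # encoder positions after the stride-2 conv
    max_target_positions=40,
    pad_token_id=0,
    bos_token_id=98,
    eos_token_id=98,
)
model = WhisperModel(config).eval()

# The encoder expects exactly 2 * max_source_positions mel frames
# (the second conv halves the length), hence seq_length = 60 for
# max_source_positions = 30.
input_features = torch.randn(2, config.num_mel_bins, 2 * config.max_source_positions)
decoder_input_ids = torch.ones(2, 4, dtype=torch.long)

with torch.no_grad():
    out = model(input_features=input_features, decoder_input_ids=decoder_input_ids)
print(out.last_hidden_state.shape)  # torch.Size([2, 4, 16])

Shrinking these two values shortens every dummy forward pass the tester runs, without touching the modeling code, which is the speed-up the commit title refers to.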
@@ -1538,7 +1538,7 @@ class WhisperEncoderModelTester:
         self,
         parent,
         batch_size=2,
-        seq_length=3000,
+        seq_length=60,
         is_training=True,
         use_labels=True,
         hidden_size=16,
@@ -1549,7 +1549,7 @@ class WhisperEncoderModelTester:
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
         max_position_embeddings=20,
-        max_source_positions=1500,
+        max_source_positions=30,
         num_mel_bins=80,
         num_conv_layers=1,
         suppress_tokens=None,
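Note (illustration, not part of the diff): the encoder-only tester gets the same treatment, dropping from the full 30-second window (3000 frames, 1500 positions) to 60 frames and 30 positions. For a sense of scale, Whisper's feature extractor emits 100 mel frames per second of 16 kHz audio (hop length of 160 samples), so the old and new tester inputs correspond roughly to the durations below; this is plain arithmetic, no model involved.

# Rough duration of the tester inputs, assuming Whisper's standard feature
# extractor settings (16 kHz sampling rate, hop length of 160 samples).
frames_per_second = 16_000 / 160        # 100 mel frames per second
print(3000 / frames_per_second, "s")    # 30.0 s -- old encoder-tester seq_length (the full Whisper window)
print(60 / frames_per_second, "s")      # 0.6 s  -- new seq_length shared by both testers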
@@ -1731,3 +1731,156 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
     # WhisperEncoder cannot resize token embeddings since it has no tokens embeddings
     def test_resize_tokens_embeddings(self):
         pass
+
+    @is_pt_flax_cross_test
+    def test_equivalence_pt_to_flax(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        init_shape = (1,) + inputs_dict["input_features"].shape[1:]
+
+        for model_class in self.all_model_classes:
+            with self.subTest(model_class.__name__):
+                fx_model_class_name = "Flax" + model_class.__name__
+
+                if not hasattr(transformers, fx_model_class_name):
+                    # no flax model exists for this class
+                    return
+
+                # Output all for aggressive testing
+                config.output_hidden_states = True
+                config.output_attentions = self.has_attentions
+
+                fx_model_class = getattr(transformers, fx_model_class_name)
+
+                # load PyTorch class
+                pt_model = model_class(config).eval()
+                # Flax models don't use the `use_cache` option and cache is not returned as a default.
+                # So we disable `use_cache` here for PyTorch model.
+                pt_model.config.use_cache = False
+
+                # load Flax class
+                fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
+
+                # make sure only flax inputs are forward that actually exist in function args
+                fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
+
+                # prepare inputs
+                pt_inputs = self._prepare_for_class(inputs_dict, model_class)
+
+                # remove function args that don't exist in Flax
+                pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
+
+                # send pytorch inputs to the correct device
+                pt_inputs = {
+                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
+                }
+
+                # convert inputs to Flax
+                fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
+
+                fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
+                fx_model.params = fx_state
+
+                # send pytorch model to the correct device
+                pt_model.to(torch_device)
+
+                with torch.no_grad():
+                    pt_outputs = pt_model(**pt_inputs)
+                fx_outputs = fx_model(**fx_inputs)
+
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    pt_model.save_pretrained(tmpdirname)
+                    fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)
+
+                fx_outputs_loaded = fx_model_loaded(**fx_inputs)
+
+                fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
+
+    @is_pt_flax_cross_test
+    def test_equivalence_flax_to_pt(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        init_shape = (1,) + inputs_dict["input_features"].shape[1:]
+
+        for model_class in self.all_model_classes:
+            with self.subTest(model_class.__name__):
+                fx_model_class_name = "Flax" + model_class.__name__
+
+                if not hasattr(transformers, fx_model_class_name):
+                    # no flax model exists for this class
+                    return
+
+                # Output all for aggressive testing
+                config.output_hidden_states = True
+                config.output_attentions = self.has_attentions
+
+                fx_model_class = getattr(transformers, fx_model_class_name)
+
+                # load PyTorch class
+                pt_model = model_class(config).eval()
+                # Flax models don't use the `use_cache` option and cache is not returned as a default.
+                # So we disable `use_cache` here for PyTorch model.
+                pt_model.config.use_cache = False
+
+                # load Flax class
+                fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
+
+                # make sure only flax inputs are forward that actually exist in function args
+                fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
+
+                # prepare inputs
+                pt_inputs = self._prepare_for_class(inputs_dict, model_class)
+
+                # remove function args that don't exist in Flax
+                pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
+
+                # send pytorch inputs to the correct device
+                pt_inputs = {
+                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
+                }
+
+                # convert inputs to Flax
+                fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
+
+                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
+
+                # make sure weights are tied in PyTorch
+                pt_model.tie_weights()
+
+                # send pytorch model to the correct device
+                pt_model.to(torch_device)
+
+                with torch.no_grad():
+                    pt_outputs = pt_model(**pt_inputs)
+                fx_outputs = fx_model(**fx_inputs)
+
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    fx_model.save_pretrained(tmpdirname)
+                    pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
+
+                # send pytorch model to the correct device
+                pt_model_loaded.to(torch_device)
+                pt_model_loaded.eval()
+
+                with torch.no_grad():
+                    pt_outputs_loaded = pt_model_loaded(**pt_inputs)
+
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
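Note (illustration, not part of the diff): the two tests added above override the generic PT/Flax equivalence checks so that the Flax models are initialised with an input_shape derived from the tester's input_features ((1,) + inputs_dict["input_features"].shape[1:]), and they lean on transformers' convert_pytorch_state_dict_to_flax / load_flax_weights_in_pytorch_model helpers. Below is a minimal standalone sketch of the PyTorch-to-Flax direction using the same conversion helper, assuming jax, flax, and the Flax Whisper audio-classification class are available in your install; the tiny config values are illustrative assumptions, not values from the commit.

# A standalone PT -> Flax round trip in the spirit of the new tests
# (a sketch, not the test itself).
import numpy as np
import torch
import jax.numpy as jnp

from transformers import WhisperConfig, WhisperForAudioClassification, FlaxWhisperForAudioClassification
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax

config = WhisperConfig(
    d_model=16,
    encoder_layers=2,
    encoder_attention_heads=4,
    num_mel_bins=80,
    max_source_positions=30,
    num_labels=3,  # hypothetical label count, just to give the classification head a size
)

pt_model = WhisperForAudioClassification(config).eval()

# Flax Whisper modules are initialised from a dummy feature shape:
# (batch, num_mel_bins, 2 * max_source_positions), mirroring init_shape in the tests.
init_shape = (1, config.num_mel_bins, 2 * config.max_source_positions)
fx_model = FlaxWhisperForAudioClassification(config, input_shape=init_shape, dtype=jnp.float32)
fx_model.params = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)

# Run the same random features through both frameworks; the logits should match closely.
features = torch.randn(2, config.num_mel_bins, 2 * config.max_source_positions)
with torch.no_grad():
    pt_logits = pt_model(input_features=features).logits
fx_logits = fx_model(input_features=np.array(features)).logits

print(np.abs(pt_logits.numpy() - np.asarray(fx_logits)).max())  # expect a tiny numerical difference

The reverse direction mirrors this via load_flax_weights_in_pytorch_model(pt_model, fx_model.params), as in test_equivalence_flax_to_pt above. Both tests are gated by the is_pt_flax_cross_test decorator, so they only run when PT/Flax cross-tests are enabled in the test environment.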