Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c87bbe1f
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e81cb010f8d68ef7317f66ae727e52098f8ae1ab"
Unverified
Commit
c87bbe1f
authored
Feb 20, 2023
by
Sylvain Gugger
Browse files
Fix quality
parent
011cc17a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
5 deletions
+6
-5
src/transformers/models/whisper/modeling_flax_whisper.py
src/transformers/models/whisper/modeling_flax_whisper.py
+3
-3
tests/models/whisper/test_modeling_flax_whisper.py
tests/models/whisper/test_modeling_flax_whisper.py
+2
-2
tests/models/whisper/test_modeling_whisper.py
tests/models/whisper/test_modeling_whisper.py
+1
-0
No files found.
src/transformers/models/whisper/modeling_flax_whisper.py
View file @
c87bbe1f
...
@@ -821,7 +821,7 @@ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
...
@@ -821,7 +821,7 @@ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
seed
:
int
=
0
,
seed
:
int
=
0
,
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
_do_init
:
bool
=
True
,
_do_init
:
bool
=
True
,
**
kwargs
**
kwargs
,
):
):
module
=
self
.
module_class
(
config
=
config
,
dtype
=
dtype
,
**
kwargs
)
module
=
self
.
module_class
(
config
=
config
,
dtype
=
dtype
,
**
kwargs
)
super
().
__init__
(
config
,
module
,
input_shape
=
input_shape
,
seed
=
seed
,
dtype
=
dtype
,
_do_init
=
_do_init
)
super
().
__init__
(
config
,
module
,
input_shape
=
input_shape
,
seed
=
seed
,
dtype
=
dtype
,
_do_init
=
_do_init
)
...
@@ -1348,7 +1348,7 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
...
@@ -1348,7 +1348,7 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
task
=
None
,
task
=
None
,
language
=
None
,
language
=
None
,
is_multilingual
=
None
,
is_multilingual
=
None
,
**
kwargs
**
kwargs
,
):
):
if
generation_config
is
None
:
if
generation_config
is
None
:
generation_config
=
self
.
generation_config
generation_config
=
self
.
generation_config
...
@@ -1411,7 +1411,7 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
...
@@ -1411,7 +1411,7 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
attention_mask
:
Optional
[
jnp
.
DeviceArray
]
=
None
,
attention_mask
:
Optional
[
jnp
.
DeviceArray
]
=
None
,
decoder_attention_mask
:
Optional
[
jnp
.
DeviceArray
]
=
None
,
decoder_attention_mask
:
Optional
[
jnp
.
DeviceArray
]
=
None
,
encoder_outputs
=
None
,
encoder_outputs
=
None
,
**
kwargs
**
kwargs
,
):
):
# initializing the cache
# initializing the cache
batch_size
,
seq_length
=
decoder_input_ids
.
shape
batch_size
,
seq_length
=
decoder_input_ids
.
shape
...
...
tests/models/whisper/test_modeling_flax_whisper.py
View file @
c87bbe1f
...
@@ -34,11 +34,11 @@ if is_datasets_available():
...
@@ -34,11 +34,11 @@ if is_datasets_available():
from
datasets
import
load_dataset
from
datasets
import
load_dataset
if
is_flax_available
():
if
is_flax_available
():
import
numpy
as
np
import
jax
import
jax
import
numpy
as
np
from
flax.core.frozen_dict
import
unfreeze
from
flax.core.frozen_dict
import
unfreeze
from
flax.traverse_util
import
flatten_dict
from
flax.traverse_util
import
flatten_dict
from
transformers
import
(
from
transformers
import
(
FLAX_MODEL_MAPPING
,
FLAX_MODEL_MAPPING
,
FlaxWhisperForConditionalGeneration
,
FlaxWhisperForConditionalGeneration
,
...
...
tests/models/whisper/test_modeling_whisper.py
View file @
c87bbe1f
...
@@ -51,6 +51,7 @@ if is_torch_available():
...
@@ -51,6 +51,7 @@ if is_torch_available():
if
is_flax_available
():
if
is_flax_available
():
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
transformers.modeling_flax_pytorch_utils
import
(
from
transformers.modeling_flax_pytorch_utils
import
(
convert_pytorch_state_dict_to_flax
,
convert_pytorch_state_dict_to_flax
,
load_flax_weights_in_pytorch_model
,
load_flax_weights_in_pytorch_model
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment