chenpangpang / transformers · commit cc5c061e (unverified)
Authored Jun 25, 2022 by Joao Gante; committed via GitHub on Jun 25, 2022
Parent: e8eb699e

CLI: handle multimodal inputs (#17839)

1 changed file with 64 additions and 40 deletions:
src/transformers/commands/pt_to_tf.py (+64, -40)
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
 from argparse import ArgumentParser, Namespace
 from importlib import import_module
@@ -22,7 +23,17 @@ from packaging import version
 import huggingface_hub

-from .. import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
+from .. import (
+    FEATURE_EXTRACTOR_MAPPING,
+    PROCESSOR_MAPPING,
+    TOKENIZER_MAPPING,
+    AutoConfig,
+    AutoFeatureExtractor,
+    AutoProcessor,
+    AutoTokenizer,
+    is_tf_available,
+    is_torch_available,
+)
 from ..utils import logging
 from . import BaseTransformersCLICommand
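Aside (not part of the diff): the newly imported FEATURE_EXTRACTOR_MAPPING, PROCESSOR_MAPPING and TOKENIZER_MAPPING are the auto-class lookup tables keyed by config class; the command uses them further down to decide which preprocessor to load. A minimal standalone sketch of that lookup, with an illustrative checkpoint name:

from transformers import (
    FEATURE_EXTRACTOR_MAPPING,
    PROCESSOR_MAPPING,
    TOKENIZER_MAPPING,
    AutoConfig,
)

config = AutoConfig.from_pretrained("bert-base-uncased")  # illustrative checkpoint, not from the commit
config_class = type(config)

# Same precedence the command uses: full processor > feature extractor > tokenizer
if config_class in PROCESSOR_MAPPING:
    print("multimodal: load AutoProcessor")
elif config_class in FEATURE_EXTRACTOR_MAPPING:
    print("image/audio: load AutoFeatureExtractor")
elif config_class in TOKENIZER_MAPPING:
    print("text: load AutoTokenizer")
else:
    print("unknown preprocessing type")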
@@ -161,31 +172,58 @@ class PTtoTFCommand(BaseTransformersCLICommand):
         self._push = push
         self._extra_commit_description = extra_commit_description

-    def get_text_inputs(self):
-        tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
-        sample_text = ["Hi there!", "I am a batch with more than one row and different input lengths."]
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-        pt_input = tokenizer(sample_text, return_tensors="pt", padding=True, truncation=True)
-        tf_input = tokenizer(sample_text, return_tensors="tf", padding=True, truncation=True)
-        return pt_input, tf_input
-
-    def get_audio_inputs(self):
-        processor = AutoFeatureExtractor.from_pretrained(self._local_dir)
-        num_samples = 2
-        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
-        raw_samples = [x["array"] for x in speech_samples]
-        pt_input = processor(raw_samples, return_tensors="pt", padding=True)
-        tf_input = processor(raw_samples, return_tensors="tf", padding=True)
-        return pt_input, tf_input
-
-    def get_image_inputs(self):
-        feature_extractor = AutoFeatureExtractor.from_pretrained(self._local_dir)
-        num_samples = 2
-        ds = load_dataset("cifar10", "plain_text", split="test")[:num_samples]["img"]
-        pt_input = feature_extractor(images=ds, return_tensors="pt")
-        tf_input = feature_extractor(images=ds, return_tensors="tf")
-        return pt_input, tf_input
+    def get_inputs(self, pt_model, config):
+        """
+        Returns the right inputs for the model, based on its signature.
+        """
+
+        def _get_audio_input():
+            ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+            speech_samples = ds.sort("id").select(range(2))[:2]["audio"]
+            raw_samples = [x["array"] for x in speech_samples]
+            return raw_samples
+
+        model_forward_signature = set(inspect.signature(pt_model.forward).parameters.keys())
+        processor_inputs = {}
+        if "input_ids" in model_forward_signature:
+            processor_inputs.update(
+                {
+                    "text": ["Hi there!", "I am a batch with more than one row and different input lengths."],
+                    "padding": True,
+                    "truncation": True,
+                }
+            )
+        if "pixel_values" in model_forward_signature:
+            sample_images = load_dataset("cifar10", "plain_text", split="test")[:2]["img"]
+            processor_inputs.update({"images": sample_images})
+        if "input_features" in model_forward_signature:
+            processor_inputs.update({"raw_speech": _get_audio_input(), "padding": True})
+        if "input_values" in model_forward_signature:  # Wav2Vec2 audio input
+            processor_inputs.update({"raw_speech": _get_audio_input(), "padding": True})
+
+        model_config_class = type(pt_model.config)
+        if model_config_class in PROCESSOR_MAPPING:
+            processor = AutoProcessor.from_pretrained(self._local_dir)
+            if model_config_class in TOKENIZER_MAPPING and processor.tokenizer.pad_token is None:
+                processor.tokenizer.pad_token = processor.tokenizer.eos_token
+        elif model_config_class in FEATURE_EXTRACTOR_MAPPING:
+            processor = AutoFeatureExtractor.from_pretrained(self._local_dir)
+        elif model_config_class in TOKENIZER_MAPPING:
+            processor = AutoTokenizer.from_pretrained(self._local_dir)
+            if processor.pad_token is None:
+                processor.pad_token = processor.eos_token
+        else:
+            raise ValueError(f"Unknown data processing type (model config type: {model_config_class})")
+
+        pt_input = processor(**processor_inputs, return_tensors="pt")
+        tf_input = processor(**processor_inputs, return_tensors="tf")
+
+        # Extra input requirements, in addition to the input modality
+        if config.is_encoder_decoder or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder")):
+            decoder_input_ids = np.asarray([[1], [1]], dtype=int) * (pt_model.config.decoder_start_token_id or 0)
+            pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)})
+            tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
+
+        return pt_input, tf_input

     def run(self):
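Aside (not part of the diff): the new get_inputs replaces the three per-modality helpers with a single method that inspects the PyTorch model's forward() signature to decide which dummy inputs to build. A minimal standalone sketch of that signature check, using an illustrative checkpoint name; the modality labels are assumptions for the example only:

import inspect

from transformers import AutoModel

model = AutoModel.from_pretrained("distilbert-base-uncased")  # illustrative checkpoint
forward_params = set(inspect.signature(model.forward).parameters.keys())

# Map forward() parameter names to input modalities, as the diff does
modalities = []
if "input_ids" in forward_params:
    modalities.append("text")
if "pixel_values" in forward_params:
    modalities.append("image")
if "input_features" in forward_params or "input_values" in forward_params:
    modalities.append("audio")
print(modalities)  # a text-only model is expected to print ['text']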
@@ -218,24 +256,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
         except AttributeError:
             raise AttributeError(f"The TensorFlow equivalent of {architectures[0]} doesn't exist in transformers.")

-        # Load models and acquire a basic input for its modality.
+        # Load models and acquire a basic input compatible with the model.
         pt_model = pt_class.from_pretrained(self._local_dir)
-        main_input_name = pt_model.main_input_name
-        if main_input_name == "input_ids":
-            pt_input, tf_input = self.get_text_inputs()
-        elif main_input_name == "pixel_values":
-            pt_input, tf_input = self.get_image_inputs()
-        elif main_input_name == "input_features":
-            pt_input, tf_input = self.get_audio_inputs()
-        else:
-            raise ValueError(f"Can't detect the model modality (`main_input_name` = {main_input_name})")
         tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
-
-        # Extra input requirements, in addition to the input modality
-        if config.is_encoder_decoder or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder")):
-            decoder_input_ids = np.asarray([[1], [1]], dtype=int) * pt_model.config.decoder_start_token_id
-            pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)})
-            tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
+        pt_input, tf_input = self.get_inputs(pt_model, config)

         # Confirms that cross loading PT weights into TF worked.
         crossload_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_from_pt_model, tf_input)
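Aside (not part of the diff): once get_inputs returns matching PyTorch and TensorFlow batches, run() compares the two models' outputs to confirm that cross-loading the PT weights into TF worked. A minimal standalone sketch of that kind of check, using an illustrative checkpoint and a text-only input:

import numpy as np
import torch
from transformers import AutoTokenizer, BertModel, TFBertModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
pt_model = BertModel.from_pretrained("bert-base-uncased")
tf_model = TFBertModel.from_pretrained("bert-base-uncased", from_pt=True)  # cross-load PT weights

pt_batch = tokenizer(["Hi there!"], return_tensors="pt")
tf_batch = tokenizer(["Hi there!"], return_tensors="tf")

with torch.no_grad():
    pt_hidden = pt_model(**pt_batch).last_hidden_state.numpy()
tf_hidden = tf_model(**tf_batch).last_hidden_state.numpy()

# A small maximum difference suggests the conversion preserved the weights.
print("max abs difference:", float(np.max(np.abs(pt_hidden - tf_hidden))))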