Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
cc5c061e
Unverified
Commit
cc5c061e
authored
Jun 25, 2022
by
Joao Gante
Committed by
GitHub
Jun 25, 2022
Browse files
CLI: handle multimodal inputs (#17839)
parent
e8eb699e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
64 additions
and
40 deletions
+64
-40
src/transformers/commands/pt_to_tf.py
src/transformers/commands/pt_to_tf.py
+64
-40
No files found.
src/transformers/commands/pt_to_tf.py
View file @
cc5c061e
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
inspect
import
os
import
os
from
argparse
import
ArgumentParser
,
Namespace
from
argparse
import
ArgumentParser
,
Namespace
from
importlib
import
import_module
from
importlib
import
import_module
...
@@ -22,7 +23,17 @@ from packaging import version
...
@@ -22,7 +23,17 @@ from packaging import version
import
huggingface_hub
import
huggingface_hub
from
..
import
AutoConfig
,
AutoFeatureExtractor
,
AutoTokenizer
,
is_tf_available
,
is_torch_available
from
..
import
(
FEATURE_EXTRACTOR_MAPPING
,
PROCESSOR_MAPPING
,
TOKENIZER_MAPPING
,
AutoConfig
,
AutoFeatureExtractor
,
AutoProcessor
,
AutoTokenizer
,
is_tf_available
,
is_torch_available
,
)
from
..utils
import
logging
from
..utils
import
logging
from
.
import
BaseTransformersCLICommand
from
.
import
BaseTransformersCLICommand
...
@@ -161,31 +172,58 @@ class PTtoTFCommand(BaseTransformersCLICommand):
...
@@ -161,31 +172,58 @@ class PTtoTFCommand(BaseTransformersCLICommand):
self
.
_push
=
push
self
.
_push
=
push
self
.
_extra_commit_description
=
extra_commit_description
self
.
_extra_commit_description
=
extra_commit_description
def
get_text_inputs
(
self
):
def
get_inputs
(
self
,
pt_model
,
config
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
_local_dir
)
"""
sample_text
=
[
"Hi there!"
,
"I am a batch with more than one row and different input lengths."
]
Returns the right inputs for the model, based on its signature.
if
tokenizer
.
pad_token
is
None
:
"""
tokenizer
.
pad_token
=
tokenizer
.
eos_token
pt_input
=
tokenizer
(
sample_text
,
return_tensors
=
"pt"
,
padding
=
True
,
truncation
=
True
)
tf_input
=
tokenizer
(
sample_text
,
return_tensors
=
"tf"
,
padding
=
True
,
truncation
=
True
)
return
pt_input
,
tf_input
def
get_audio_inputs
(
self
):
def
_get_audio_input
():
processor
=
AutoFeatureExtractor
.
from_pretrained
(
self
.
_local_dir
)
ds
=
load_dataset
(
"hf-internal-testing/librispeech_asr_dummy"
,
"clean"
,
split
=
"validation"
)
num_samples
=
2
speech_samples
=
ds
.
sort
(
"id"
).
select
(
range
(
2
))[:
2
][
"audio"
]
ds
=
load_dataset
(
"hf-internal-testing/librispeech_asr_dummy"
,
"clean"
,
split
=
"validation"
)
raw_samples
=
[
x
[
"array"
]
for
x
in
speech_samples
]
speech_samples
=
ds
.
sort
(
"id"
).
select
(
range
(
num_samples
))[:
num_samples
][
"audio"
]
return
raw_samples
raw_samples
=
[
x
[
"array"
]
for
x
in
speech_samples
]
pt_input
=
processor
(
raw_samples
,
return_tensors
=
"pt"
,
padding
=
True
)
model_forward_signature
=
set
(
inspect
.
signature
(
pt_model
.
forward
).
parameters
.
keys
())
tf_input
=
processor
(
raw_samples
,
return_tensors
=
"tf"
,
padding
=
True
)
processor_inputs
=
{}
return
pt_input
,
tf_input
if
"input_ids"
in
model_forward_signature
:
processor_inputs
.
update
(
{
"text"
:
[
"Hi there!"
,
"I am a batch with more than one row and different input lengths."
],
"padding"
:
True
,
"truncation"
:
True
,
}
)
if
"pixel_values"
in
model_forward_signature
:
sample_images
=
load_dataset
(
"cifar10"
,
"plain_text"
,
split
=
"test"
)[:
2
][
"img"
]
processor_inputs
.
update
({
"images"
:
sample_images
})
if
"input_features"
in
model_forward_signature
:
processor_inputs
.
update
({
"raw_speech"
:
_get_audio_input
(),
"padding"
:
True
})
if
"input_values"
in
model_forward_signature
:
# Wav2Vec2 audio input
processor_inputs
.
update
({
"raw_speech"
:
_get_audio_input
(),
"padding"
:
True
})
model_config_class
=
type
(
pt_model
.
config
)
if
model_config_class
in
PROCESSOR_MAPPING
:
processor
=
AutoProcessor
.
from_pretrained
(
self
.
_local_dir
)
if
model_config_class
in
TOKENIZER_MAPPING
and
processor
.
tokenizer
.
pad_token
is
None
:
processor
.
tokenizer
.
pad_token
=
processor
.
tokenizer
.
eos_token
elif
model_config_class
in
FEATURE_EXTRACTOR_MAPPING
:
processor
=
AutoFeatureExtractor
.
from_pretrained
(
self
.
_local_dir
)
elif
model_config_class
in
TOKENIZER_MAPPING
:
processor
=
AutoTokenizer
.
from_pretrained
(
self
.
_local_dir
)
if
processor
.
pad_token
is
None
:
processor
.
pad_token
=
processor
.
eos_token
else
:
raise
ValueError
(
f
"Unknown data processing type (model config type:
{
model_config_class
}
)"
)
pt_input
=
processor
(
**
processor_inputs
,
return_tensors
=
"pt"
)
tf_input
=
processor
(
**
processor_inputs
,
return_tensors
=
"tf"
)
# Extra input requirements, in addition to the input modality
if
config
.
is_encoder_decoder
or
(
hasattr
(
pt_model
,
"encoder"
)
and
hasattr
(
pt_model
,
"decoder"
)):
decoder_input_ids
=
np
.
asarray
([[
1
],
[
1
]],
dtype
=
int
)
*
(
pt_model
.
config
.
decoder_start_token_id
or
0
)
pt_input
.
update
({
"decoder_input_ids"
:
torch
.
tensor
(
decoder_input_ids
)})
tf_input
.
update
({
"decoder_input_ids"
:
tf
.
convert_to_tensor
(
decoder_input_ids
)})
def
get_image_inputs
(
self
):
feature_extractor
=
AutoFeatureExtractor
.
from_pretrained
(
self
.
_local_dir
)
num_samples
=
2
ds
=
load_dataset
(
"cifar10"
,
"plain_text"
,
split
=
"test"
)[:
num_samples
][
"img"
]
pt_input
=
feature_extractor
(
images
=
ds
,
return_tensors
=
"pt"
)
tf_input
=
feature_extractor
(
images
=
ds
,
return_tensors
=
"tf"
)
return
pt_input
,
tf_input
return
pt_input
,
tf_input
def
run
(
self
):
def
run
(
self
):
...
@@ -218,24 +256,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
...
@@ -218,24 +256,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
except
AttributeError
:
except
AttributeError
:
raise
AttributeError
(
f
"The TensorFlow equivalent of
{
architectures
[
0
]
}
doesn't exist in transformers."
)
raise
AttributeError
(
f
"The TensorFlow equivalent of
{
architectures
[
0
]
}
doesn't exist in transformers."
)
# Load models and acquire a basic input
for its modality
.
# Load models and acquire a basic input
compatible with the model
.
pt_model
=
pt_class
.
from_pretrained
(
self
.
_local_dir
)
pt_model
=
pt_class
.
from_pretrained
(
self
.
_local_dir
)
main_input_name
=
pt_model
.
main_input_name
if
main_input_name
==
"input_ids"
:
pt_input
,
tf_input
=
self
.
get_text_inputs
()
elif
main_input_name
==
"pixel_values"
:
pt_input
,
tf_input
=
self
.
get_image_inputs
()
elif
main_input_name
==
"input_features"
:
pt_input
,
tf_input
=
self
.
get_audio_inputs
()
else
:
raise
ValueError
(
f
"Can't detect the model modality (`main_input_name` =
{
main_input_name
}
)"
)
tf_from_pt_model
=
tf_class
.
from_pretrained
(
self
.
_local_dir
,
from_pt
=
True
)
tf_from_pt_model
=
tf_class
.
from_pretrained
(
self
.
_local_dir
,
from_pt
=
True
)
pt_input
,
tf_input
=
self
.
get_inputs
(
pt_model
,
config
)
# Extra input requirements, in addition to the input modality
if
config
.
is_encoder_decoder
or
(
hasattr
(
pt_model
,
"encoder"
)
and
hasattr
(
pt_model
,
"decoder"
)):
decoder_input_ids
=
np
.
asarray
([[
1
],
[
1
]],
dtype
=
int
)
*
pt_model
.
config
.
decoder_start_token_id
pt_input
.
update
({
"decoder_input_ids"
:
torch
.
tensor
(
decoder_input_ids
)})
tf_input
.
update
({
"decoder_input_ids"
:
tf
.
convert_to_tensor
(
decoder_input_ids
)})
# Confirms that cross loading PT weights into TF worked.
# Confirms that cross loading PT weights into TF worked.
crossload_differences
=
self
.
find_pt_tf_differences
(
pt_model
,
pt_input
,
tf_from_pt_model
,
tf_input
)
crossload_differences
=
self
.
find_pt_tf_differences
(
pt_model
,
pt_input
,
tf_from_pt_model
,
tf_input
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment