Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
16469fed
Unverified
Commit
16469fed
authored
Apr 16, 2020
by
Sam Shleifer
Committed by
GitHub
Apr 16, 2020
Browse files
[PretrainedTokenizer] Factor out tensor conversion method (#3777)
parent
80a16945
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
27 deletions
+30
-27
src/transformers/tokenization_utils.py
src/transformers/tokenization_utils.py
+30
-27
No files found.
src/transformers/tokenization_utils.py
View file @
16469fed
...
...
@@ -517,7 +517,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
self
.
max_len
=
max_len
if
max_len
is
not
None
else
int
(
1e12
)
# Padding side is right by default and over
-
riden in subclasses. If specified in the kwargs, it is changed.
# Padding side is right by default and overrid
d
en in subclasses. If specified in the kwargs, it is changed.
self
.
padding_side
=
kwargs
.
pop
(
"padding_side"
,
self
.
padding_side
)
self
.
model_input_names
=
kwargs
.
pop
(
"model_input_names"
,
self
.
model_input_names
)
...
...
@@ -1447,6 +1447,10 @@ class PreTrainedTokenizer(SpecialTokensMixin):
if
return_tensors
is
not
None
:
self
.
convert_to_tensors_
(
batch_outputs
,
return_tensors
)
return
BatchEncoding
(
batch_outputs
)
def
convert_to_tensors_
(
self
,
batch_outputs
:
dict
,
return_tensors
:
str
)
->
None
:
# Do the tensor conversion in batch
for
key
,
value
in
batch_outputs
.
items
():
if
return_tensors
==
"tf"
and
is_tf_available
():
...
...
@@ -1467,6 +1471,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
raise
ValueError
(
self
.
NO_PAD_TOKEN_FOR_BATCH_MSG
)
else
:
raise
elif
return_tensors
is
not
None
:
logger
.
warning
(
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available."
.
format
(
...
...
@@ -1474,8 +1479,6 @@ class PreTrainedTokenizer(SpecialTokensMixin):
)
)
return
BatchEncoding
(
batch_outputs
)
def
prepare_for_model
(
self
,
ids
:
List
[
int
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment