Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9151e649
Unverified
Commit
9151e649
authored
Oct 21, 2022
by
Sylvain Gugger
Committed by
GitHub
Oct 21, 2022
Browse files
Make public versions of private tensor utils (#19775)
* Make public versions of private utils * I need sleep
parent
3aaabaa2
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
80 additions
and
49 deletions
+80
-49
src/transformers/feature_extraction_sequence_utils.py
src/transformers/feature_extraction_sequence_utils.py
+3
-4
src/transformers/feature_extraction_utils.py
src/transformers/feature_extraction_utils.py
+6
-4
src/transformers/image_utils.py
src/transformers/image_utils.py
+9
-14
src/transformers/models/luke/tokenization_luke.py
src/transformers/models/luke/tokenization_luke.py
+3
-5
src/transformers/models/mluke/tokenization_mluke.py
src/transformers/models/mluke/tokenization_mluke.py
+3
-5
src/transformers/tokenization_utils_base.py
src/transformers/tokenization_utils_base.py
+10
-11
src/transformers/utils/__init__.py
src/transformers/utils/__init__.py
+5
-0
src/transformers/utils/generic.py
src/transformers/utils/generic.py
+41
-6
No files found.
src/transformers/feature_extraction_sequence_utils.py
View file @
9151e649
...
...
@@ -20,8 +20,7 @@ from typing import Dict, List, Optional, Union
import
numpy
as
np
from
.feature_extraction_utils
import
BatchFeature
,
FeatureExtractionMixin
from
.utils
import
PaddingStrategy
,
TensorType
,
is_tf_available
,
is_torch_available
,
logging
,
to_numpy
from
.utils.generic
import
_is_tensorflow
,
_is_torch
from
.utils
import
PaddingStrategy
,
TensorType
,
is_tf_tensor
,
is_torch_tensor
,
logging
,
to_numpy
logger
=
logging
.
get_logger
(
__name__
)
...
...
@@ -160,9 +159,9 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
first_element
=
required_input
[
index
][
0
]
if
return_tensors
is
None
:
if
is_tf_
available
()
and
_is_
tensor
flow
(
first_element
):
if
is_tf_tensor
(
first_element
):
return_tensors
=
"tf"
elif
is_torch_
available
()
and
_is_torch
(
first_element
):
elif
is_torch_
tensor
(
first_element
):
return_tensors
=
"pt"
elif
isinstance
(
first_element
,
(
int
,
float
,
list
,
tuple
,
np
.
ndarray
)):
return_tensors
=
"np"
...
...
src/transformers/feature_extraction_utils.py
View file @
9151e649
...
...
@@ -33,14 +33,16 @@ from .utils import (
copy_func
,
download_url
,
is_flax_available
,
is_jax_tensor
,
is_numpy_array
,
is_offline_mode
,
is_remote_url
,
is_tf_available
,
is_torch_available
,
is_torch_device
,
logging
,
torch_required
,
)
from
.utils.generic
import
_is_jax
,
_is_numpy
,
_is_torch_device
if
TYPE_CHECKING
:
...
...
@@ -150,10 +152,10 @@ class BatchFeature(UserDict):
import
jax.numpy
as
jnp
# noqa: F811
as_tensor
=
jnp
.
array
is_tensor
=
_
is_jax
is_tensor
=
is_jax
_tensor
else
:
as_tensor
=
np
.
asarray
is_tensor
=
_
is_numpy
is_tensor
=
is_numpy
_array
# Do the tensor conversion in batch
for
key
,
value
in
self
.
items
():
...
...
@@ -188,7 +190,7 @@ class BatchFeature(UserDict):
# This check catches things like APEX blindly calling "to" on all inputs to a module
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
if
isinstance
(
device
,
str
)
or
_
is_torch_device
(
device
)
or
isinstance
(
device
,
int
):
if
isinstance
(
device
,
str
)
or
is_torch_device
(
device
)
or
isinstance
(
device
,
int
):
self
.
data
=
{
k
:
v
.
to
(
device
=
device
)
for
k
,
v
in
self
.
data
.
items
()}
else
:
logger
.
warning
(
f
"Attempting to cast a BatchFeature to type
{
str
(
device
)
}
. This is not supported."
)
...
...
src/transformers/image_utils.py
View file @
9151e649
...
...
@@ -21,14 +21,21 @@ from packaging import version
import
requests
from
.utils
import
is_flax_available
,
is_tf_available
,
is_torch_available
,
is_vision_available
from
.utils
import
(
ExplicitEnum
,
is_jax_tensor
,
is_tf_tensor
,
is_torch_available
,
is_torch_tensor
,
is_vision_available
,
to_numpy
,
)
from
.utils.constants
import
(
# noqa: F401
IMAGENET_DEFAULT_MEAN
,
IMAGENET_DEFAULT_STD
,
IMAGENET_STANDARD_MEAN
,
IMAGENET_STANDARD_STD
,
)
from
.utils.generic
import
ExplicitEnum
,
_is_jax
,
_is_tensorflow
,
_is_torch
,
to_numpy
if
is_vision_available
():
...
...
@@ -55,18 +62,6 @@ class ChannelDimension(ExplicitEnum):
LAST
=
"channels_last"
def
is_torch_tensor
(
obj
):
return
_is_torch
(
obj
)
if
is_torch_available
()
else
False
def
is_tf_tensor
(
obj
):
return
_is_tensorflow
(
obj
)
if
is_tf_available
()
else
False
def
is_jax_tensor
(
obj
):
return
_is_jax
(
obj
)
if
is_flax_available
()
else
False
def
is_valid_image
(
img
):
return
(
isinstance
(
img
,
(
PIL
.
Image
.
Image
,
np
.
ndarray
))
...
...
src/transformers/models/luke/tokenization_luke.py
View file @
9151e649
...
...
@@ -33,11 +33,9 @@ from ...tokenization_utils_base import (
TextInput
,
TextInputPair
,
TruncationStrategy
,
_is_tensorflow
,
_is_torch
,
to_py_obj
,
)
from
...utils
import
add_end_docstrings
,
is_tf_
available
,
is_torch_
available
,
logging
from
...utils
import
add_end_docstrings
,
is_tf_
tensor
,
is_torch_
tensor
,
logging
logger
=
logging
.
get_logger
(
__name__
)
...
...
@@ -1174,9 +1172,9 @@ class LukeTokenizer(RobertaTokenizer):
first_element
=
required_input
[
index
][
0
]
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if
not
isinstance
(
first_element
,
(
int
,
list
,
tuple
)):
if
is_tf_
available
()
and
_is_
tensor
flow
(
first_element
):
if
is_tf_tensor
(
first_element
):
return_tensors
=
"tf"
if
return_tensors
is
None
else
return_tensors
elif
is_torch_
available
()
and
_is_torch
(
first_element
):
elif
is_torch_
tensor
(
first_element
):
return_tensors
=
"pt"
if
return_tensors
is
None
else
return_tensors
elif
isinstance
(
first_element
,
np
.
ndarray
):
return_tensors
=
"np"
if
return_tensors
is
None
else
return_tensors
...
...
src/transformers/models/mluke/tokenization_mluke.py
View file @
9151e649
...
...
@@ -37,11 +37,9 @@ from ...tokenization_utils_base import (
TextInput
,
TextInputPair
,
TruncationStrategy
,
_is_tensorflow
,
_is_torch
,
to_py_obj
,
)
from
...utils
import
add_end_docstrings
,
is_tf_
available
,
is_torch_
available
,
logging
from
...utils
import
add_end_docstrings
,
is_tf_
tensor
,
is_torch_
tensor
,
logging
logger
=
logging
.
get_logger
(
__name__
)
...
...
@@ -1287,9 +1285,9 @@ class MLukeTokenizer(PreTrainedTokenizer):
first_element
=
required_input
[
index
][
0
]
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if
not
isinstance
(
first_element
,
(
int
,
list
,
tuple
)):
if
is_tf_
available
()
and
_is_
tensor
flow
(
first_element
):
if
is_tf_tensor
(
first_element
):
return_tensors
=
"tf"
if
return_tensors
is
None
else
return_tensors
elif
is_torch_
available
()
and
_is_torch
(
first_element
):
elif
is_torch_
tensor
(
first_element
):
return_tensors
=
"pt"
if
return_tensors
is
None
else
return_tensors
elif
isinstance
(
first_element
,
np
.
ndarray
):
return_tensors
=
"np"
if
return_tensors
is
None
else
return_tensors
...
...
src/transformers/tokenization_utils_base.py
View file @
9151e649
...
...
@@ -45,16 +45,20 @@ from .utils import (
download_url
,
extract_commit_hash
,
is_flax_available
,
is_jax_tensor
,
is_numpy_array
,
is_offline_mode
,
is_remote_url
,
is_tf_available
,
is_tf_tensor
,
is_tokenizers_available
,
is_torch_available
,
is_torch_device
,
is_torch_tensor
,
logging
,
to_py_obj
,
torch_required
,
)
from
.utils.generic
import
_is_jax
,
_is_numpy
,
_is_tensorflow
,
_is_torch
,
_is_torch_device
if
TYPE_CHECKING
:
...
...
@@ -696,15 +700,10 @@ class BatchEncoding(UserDict):
import
jax.numpy
as
jnp
# noqa: F811
as_tensor
=
jnp
.
array
is_tensor
=
_
is_jax
is_tensor
=
is_jax
_tensor
else
:
as_tensor
=
np
.
asarray
is_tensor
=
_is_numpy
# (mfuntowicz: This code is unreachable)
# else:
# raise ImportError(
# f"Unable to convert output to tensors format {tensor_type}"
# )
is_tensor
=
is_numpy_array
# Do the tensor conversion in batch
for
key
,
value
in
self
.
items
():
...
...
@@ -753,7 +752,7 @@ class BatchEncoding(UserDict):
# This check catches things like APEX blindly calling "to" on all inputs to a module
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
if
isinstance
(
device
,
str
)
or
_
is_torch_device
(
device
)
or
isinstance
(
device
,
int
):
if
isinstance
(
device
,
str
)
or
is_torch_device
(
device
)
or
isinstance
(
device
,
int
):
self
.
data
=
{
k
:
v
.
to
(
device
=
device
)
for
k
,
v
in
self
.
data
.
items
()}
else
:
logger
.
warning
(
f
"Attempting to cast a BatchEncoding to type
{
str
(
device
)
}
. This is not supported."
)
...
...
@@ -2925,9 +2924,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
break
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if
not
isinstance
(
first_element
,
(
int
,
list
,
tuple
)):
if
is_tf_
available
()
and
_is_
tensor
flow
(
first_element
):
if
is_tf_tensor
(
first_element
):
return_tensors
=
"tf"
if
return_tensors
is
None
else
return_tensors
elif
is_torch_
available
()
and
_is_torch
(
first_element
):
elif
is_torch_
tensor
(
first_element
):
return_tensors
=
"pt"
if
return_tensors
is
None
else
return_tensors
elif
isinstance
(
first_element
,
np
.
ndarray
):
return_tensors
=
"np"
if
return_tensors
is
None
else
return_tensors
...
...
src/transformers/utils/__init__.py
View file @
9151e649
...
...
@@ -40,7 +40,12 @@ from .generic import (
cached_property
,
find_labels
,
flatten_dict
,
is_jax_tensor
,
is_numpy_array
,
is_tensor
,
is_tf_tensor
,
is_torch_device
,
is_torch_tensor
,
to_numpy
,
to_py_obj
,
working_or_temp_dir
,
...
...
src/transformers/utils/generic.py
View file @
9151e649
...
...
@@ -83,30 +83,65 @@ def _is_numpy(x):
return
isinstance
(
x
,
np
.
ndarray
)
def
is_numpy_array
(
x
):
"""
Tests if `x` is a numpy array or not.
"""
return
_is_numpy
(
x
)
def
_is_torch
(
x
):
import
torch
return
isinstance
(
x
,
torch
.
Tensor
)
def
is_torch_tensor
(
x
):
"""
Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed.
"""
return
False
if
not
is_torch_available
()
else
_is_torch
(
x
)
def
_is_torch_device
(
x
):
import
torch
return
isinstance
(
x
,
torch
.
device
)
def
is_torch_device
(
x
):
"""
Tests if `x` is a torch device or not. Safe to call even if torch is not installed.
"""
return
False
if
not
is_torch_available
()
else
_is_torch_device
(
x
)
def
_is_tensorflow
(
x
):
import
tensorflow
as
tf
return
isinstance
(
x
,
tf
.
Tensor
)
def
is_tf_tensor
(
x
):
"""
Tests if `x` is a tensorflow tensor or not. Safe to call even if tensorflow is not installed.
"""
return
False
if
not
is_tf_available
()
else
_is_tensorflow
(
x
)
def
_is_jax
(
x
):
import
jax.numpy
as
jnp
# noqa: F811
return
isinstance
(
x
,
jnp
.
ndarray
)
def
is_jax_tensor
(
x
):
"""
Tests if `x` is a Jax tensor or not. Safe to call even if jax is not installed.
"""
return
False
if
not
is_flax_available
()
else
_is_jax
(
x
)
def
to_py_obj
(
obj
):
"""
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
...
...
@@ -115,11 +150,11 @@ def to_py_obj(obj):
return
{
k
:
to_py_obj
(
v
)
for
k
,
v
in
obj
.
items
()}
elif
isinstance
(
obj
,
(
list
,
tuple
)):
return
[
to_py_obj
(
o
)
for
o
in
obj
]
elif
is_tf_
available
()
and
_is_
tensor
flow
(
obj
):
elif
is_tf_tensor
(
obj
):
return
obj
.
numpy
().
tolist
()
elif
is_torch_
available
()
and
_is_torch
(
obj
):
elif
is_torch_
tensor
(
obj
):
return
obj
.
detach
().
cpu
().
tolist
()
elif
is_
fl
ax_
available
()
and
_is_jax
(
obj
):
elif
is_
j
ax_
tensor
(
obj
):
return
np
.
asarray
(
obj
).
tolist
()
elif
isinstance
(
obj
,
(
np
.
ndarray
,
np
.
number
)):
# tolist also works on 0d np arrays
return
obj
.
tolist
()
...
...
@@ -135,11 +170,11 @@ def to_numpy(obj):
return
{
k
:
to_numpy
(
v
)
for
k
,
v
in
obj
.
items
()}
elif
isinstance
(
obj
,
(
list
,
tuple
)):
return
np
.
array
(
obj
)
elif
is_tf_
available
()
and
_is_
tensor
flow
(
obj
):
elif
is_tf_tensor
(
obj
):
return
obj
.
numpy
()
elif
is_torch_
available
()
and
_is_torch
(
obj
):
elif
is_torch_
tensor
(
obj
):
return
obj
.
detach
().
cpu
().
numpy
()
elif
is_
fl
ax_
available
()
and
_is_jax
(
obj
):
elif
is_
j
ax_
tensor
(
obj
):
return
np
.
asarray
(
obj
)
else
:
return
obj
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment