Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
01ae3b87
Unverified
Commit
01ae3b87
authored
Apr 18, 2024
by
Yih-Dar
Committed by
GitHub
Apr 18, 2024
Browse files
Avoid `jnp` import in `utils/generic.py` (#30322)
fix Co-authored-by:
ydshieh
<
ydshieh@users.noreply.github.com
>
parent
60d5f8f9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
4 deletions
+8
-4
src/transformers/utils/generic.py
src/transformers/utils/generic.py
+8
-4
No files found.
src/transformers/utils/generic.py
View file @
01ae3b87
...
@@ -38,10 +38,6 @@ from .import_utils import (
...
@@ -38,10 +38,6 @@ from .import_utils import (
)
)
if
is_flax_available
():
import
jax.numpy
as
jnp
class
cached_property
(
property
):
class
cached_property
(
property
):
"""
"""
Descriptor that mimics @property but caches output in member variable.
Descriptor that mimics @property but caches output in member variable.
...
@@ -624,6 +620,8 @@ def transpose(array, axes=None):
...
@@ -624,6 +620,8 @@ def transpose(array, axes=None):
return
tf
.
transpose
(
array
,
perm
=
axes
)
return
tf
.
transpose
(
array
,
perm
=
axes
)
elif
is_jax_tensor
(
array
):
elif
is_jax_tensor
(
array
):
import
jax.numpy
as
jnp
return
jnp
.
transpose
(
array
,
axes
=
axes
)
return
jnp
.
transpose
(
array
,
axes
=
axes
)
else
:
else
:
raise
ValueError
(
f
"Type not supported for transpose:
{
type
(
array
)
}
."
)
raise
ValueError
(
f
"Type not supported for transpose:
{
type
(
array
)
}
."
)
...
@@ -643,6 +641,8 @@ def reshape(array, newshape):
...
@@ -643,6 +641,8 @@ def reshape(array, newshape):
return
tf
.
reshape
(
array
,
newshape
)
return
tf
.
reshape
(
array
,
newshape
)
elif
is_jax_tensor
(
array
):
elif
is_jax_tensor
(
array
):
import
jax.numpy
as
jnp
return
jnp
.
reshape
(
array
,
newshape
)
return
jnp
.
reshape
(
array
,
newshape
)
else
:
else
:
raise
ValueError
(
f
"Type not supported for reshape:
{
type
(
array
)
}
."
)
raise
ValueError
(
f
"Type not supported for reshape:
{
type
(
array
)
}
."
)
...
@@ -662,6 +662,8 @@ def squeeze(array, axis=None):
...
@@ -662,6 +662,8 @@ def squeeze(array, axis=None):
return
tf
.
squeeze
(
array
,
axis
=
axis
)
return
tf
.
squeeze
(
array
,
axis
=
axis
)
elif
is_jax_tensor
(
array
):
elif
is_jax_tensor
(
array
):
import
jax.numpy
as
jnp
return
jnp
.
squeeze
(
array
,
axis
=
axis
)
return
jnp
.
squeeze
(
array
,
axis
=
axis
)
else
:
else
:
raise
ValueError
(
f
"Type not supported for squeeze:
{
type
(
array
)
}
."
)
raise
ValueError
(
f
"Type not supported for squeeze:
{
type
(
array
)
}
."
)
...
@@ -681,6 +683,8 @@ def expand_dims(array, axis):
...
@@ -681,6 +683,8 @@ def expand_dims(array, axis):
return
tf
.
expand_dims
(
array
,
axis
=
axis
)
return
tf
.
expand_dims
(
array
,
axis
=
axis
)
elif
is_jax_tensor
(
array
):
elif
is_jax_tensor
(
array
):
import
jax.numpy
as
jnp
return
jnp
.
expand_dims
(
array
,
axis
=
axis
)
return
jnp
.
expand_dims
(
array
,
axis
=
axis
)
else
:
else
:
raise
ValueError
(
f
"Type not supported for expand_dims:
{
type
(
array
)
}
."
)
raise
ValueError
(
f
"Type not supported for expand_dims:
{
type
(
array
)
}
."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment