Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6d4f8bd0
Unverified
Commit
6d4f8bd0
authored
Oct 20, 2020
by
Sylvain Gugger
Committed by
GitHub
Oct 20, 2020
Browse files
Add Flax dummy objects (#7918)
parent
3e31e7f9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
102 additions
and
5 deletions
+102
-5
src/transformers/__init__.py
src/transformers/__init__.py
+5
-0
src/transformers/file_utils.py
src/transformers/file_utils.py
+12
-0
src/transformers/utils/dummy_flax_objects.py
src/transformers/utils/dummy_flax_objects.py
+20
-0
utils/check_dummies.py
utils/check_dummies.py
+65
-5
No files found.
src/transformers/__init__.py
View file @
6d4f8bd0
...
@@ -841,6 +841,11 @@ else:
...
@@ -841,6 +841,11 @@ else:
if
is_flax_available
():
if
is_flax_available
():
from
.modeling_flax_bert
import
FlaxBertModel
from
.modeling_flax_bert
import
FlaxBertModel
from
.modeling_flax_roberta
import
FlaxRobertaModel
from
.modeling_flax_roberta
import
FlaxRobertaModel
else
:
# Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them.
from
.utils.dummy_flax_objects
import
*
if
not
is_tf_available
()
and
not
is_torch_available
():
if
not
is_tf_available
()
and
not
is_torch_available
():
logger
.
warning
(
logger
.
warning
(
...
...
src/transformers/file_utils.py
View file @
6d4f8bd0
...
@@ -356,6 +356,12 @@ installation page: https://www.tensorflow.org/install and follow the ones that m
...
@@ -356,6 +356,12 @@ installation page: https://www.tensorflow.org/install and follow the ones that m
"""
"""
# Message template raised when a Flax-only object is used without Flax installed.
# `{0}` is filled with the name of the object that was requested.
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Check out the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
"""
def
requires_datasets
(
obj
):
def
requires_datasets
(
obj
):
name
=
obj
.
__name__
if
hasattr
(
obj
,
"__name__"
)
else
obj
.
__class__
.
__name__
name
=
obj
.
__name__
if
hasattr
(
obj
,
"__name__"
)
else
obj
.
__class__
.
__name__
if
not
is_datasets_available
():
if
not
is_datasets_available
():
...
@@ -386,6 +392,12 @@ def requires_tf(obj):
...
@@ -386,6 +392,12 @@ def requires_tf(obj):
raise
ImportError
(
TENSORFLOW_IMPORT_ERROR
.
format
(
name
))
raise
ImportError
(
TENSORFLOW_IMPORT_ERROR
.
format
(
name
))
def requires_flax(obj):
    """Raise an ImportError naming `obj` when Flax is not available.

    The name used in the message is `obj.__name__` when that attribute exists
    (e.g. for classes and functions), otherwise the name of `obj`'s class.
    Returns None without raising when `is_flax_available()` is true.
    """
    if hasattr(obj, "__name__"):
        name = obj.__name__
    else:
        name = obj.__class__.__name__

    if is_flax_available():
        return
    raise ImportError(FLAX_IMPORT_ERROR.format(name))
def
requires_tokenizers
(
obj
):
def
requires_tokenizers
(
obj
):
name
=
obj
.
__name__
if
hasattr
(
obj
,
"__name__"
)
else
obj
.
__class__
.
__name__
name
=
obj
.
__name__
if
hasattr
(
obj
,
"__name__"
)
else
obj
.
__class__
.
__name__
if
not
is_tokenizers_available
():
if
not
is_tokenizers_available
():
...
...
src/transformers/utils/dummy_flax_objects.py
0 → 100644
View file @
6d4f8bd0
# This file is autogenerated by the command `make fix-copies`, do not edit.
from
..file_utils
import
requires_flax
class
FlaxBertModel
:
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_flax
(
self
)
@
classmethod
def
from_pretrained
(
self
,
*
args
,
**
kwargs
):
requires_flax
(
self
)
class
FlaxRobertaModel
:
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_flax
(
self
)
@
classmethod
def
from_pretrained
(
self
,
*
args
,
**
kwargs
):
requires_flax
(
self
)
utils/check_dummies.py
View file @
6d4f8bd0
...
@@ -72,6 +72,28 @@ def {0}(*args, **kwargs):
...
@@ -72,6 +72,28 @@ def {0}(*args, **kwargs):
"""
"""
DUMMY_FLAX_PRETRAINED_CLASS
=
"""
class {0}:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
"""
DUMMY_FLAX_CLASS
=
"""
class {0}:
def __init__(self, *args, **kwargs):
requires_flax(self)
"""
DUMMY_FLAX_FUNCTION
=
"""
def {0}(*args, **kwargs):
requires_flax({0})
"""
DUMMY_SENTENCEPIECE_PRETRAINED_CLASS
=
"""
DUMMY_SENTENCEPIECE_PRETRAINED_CLASS
=
"""
class {0}:
class {0}:
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs):
...
@@ -120,6 +142,7 @@ def {0}(*args, **kwargs):
...
@@ -120,6 +142,7 @@ def {0}(*args, **kwargs):
DUMMY_PRETRAINED_CLASS
=
{
DUMMY_PRETRAINED_CLASS
=
{
"pt"
:
DUMMY_PT_PRETRAINED_CLASS
,
"pt"
:
DUMMY_PT_PRETRAINED_CLASS
,
"tf"
:
DUMMY_TF_PRETRAINED_CLASS
,
"tf"
:
DUMMY_TF_PRETRAINED_CLASS
,
"flax"
:
DUMMY_FLAX_PRETRAINED_CLASS
,
"sentencepiece"
:
DUMMY_SENTENCEPIECE_PRETRAINED_CLASS
,
"sentencepiece"
:
DUMMY_SENTENCEPIECE_PRETRAINED_CLASS
,
"tokenizers"
:
DUMMY_TOKENIZERS_PRETRAINED_CLASS
,
"tokenizers"
:
DUMMY_TOKENIZERS_PRETRAINED_CLASS
,
}
}
...
@@ -127,6 +150,7 @@ DUMMY_PRETRAINED_CLASS = {
...
@@ -127,6 +150,7 @@ DUMMY_PRETRAINED_CLASS = {
DUMMY_CLASS
=
{
DUMMY_CLASS
=
{
"pt"
:
DUMMY_PT_CLASS
,
"pt"
:
DUMMY_PT_CLASS
,
"tf"
:
DUMMY_TF_CLASS
,
"tf"
:
DUMMY_TF_CLASS
,
"flax"
:
DUMMY_FLAX_CLASS
,
"sentencepiece"
:
DUMMY_SENTENCEPIECE_CLASS
,
"sentencepiece"
:
DUMMY_SENTENCEPIECE_CLASS
,
"tokenizers"
:
DUMMY_TOKENIZERS_CLASS
,
"tokenizers"
:
DUMMY_TOKENIZERS_CLASS
,
}
}
...
@@ -134,6 +158,7 @@ DUMMY_CLASS = {
...
@@ -134,6 +158,7 @@ DUMMY_CLASS = {
DUMMY_FUNCTION
=
{
DUMMY_FUNCTION
=
{
"pt"
:
DUMMY_PT_FUNCTION
,
"pt"
:
DUMMY_PT_FUNCTION
,
"tf"
:
DUMMY_TF_FUNCTION
,
"tf"
:
DUMMY_TF_FUNCTION
,
"flax"
:
DUMMY_FLAX_FUNCTION
,
"sentencepiece"
:
DUMMY_SENTENCEPIECE_FUNCTION
,
"sentencepiece"
:
DUMMY_SENTENCEPIECE_FUNCTION
,
"tokenizers"
:
DUMMY_TOKENIZERS_FUNCTION
,
"tokenizers"
:
DUMMY_TOKENIZERS_FUNCTION
,
}
}
...
@@ -208,7 +233,24 @@ def read_init():
...
@@ -208,7 +233,24 @@ def read_init():
elif
line
.
startswith
(
" "
):
elif
line
.
startswith
(
" "
):
tf_objects
.
append
(
line
[
8
:
-
2
])
tf_objects
.
append
(
line
[
8
:
-
2
])
line_index
+=
1
line_index
+=
1
return
sentencepiece_objects
,
tokenizers_objects
,
pt_objects
,
tf_objects
# Find where the FLAX imports begin
flax_objects
=
[]
while
not
lines
[
line_index
].
startswith
(
"if is_flax_available():"
):
line_index
+=
1
line_index
+=
1
# Until we unindent, add Flax objects to the list
while
len
(
lines
[
line_index
])
<=
1
or
lines
[
line_index
].
startswith
(
" "
):
line
=
lines
[
line_index
]
search
=
_re_single_line_import
.
search
(
line
)
if
search
is
not
None
:
flax_objects
+=
search
.
groups
()[
0
].
split
(
", "
)
elif
line
.
startswith
(
" "
):
flax_objects
.
append
(
line
[
8
:
-
2
])
line_index
+=
1
return
sentencepiece_objects
,
tokenizers_objects
,
pt_objects
,
tf_objects
,
flax_objects
def
create_dummy_object
(
name
,
type
=
"pt"
):
def
create_dummy_object
(
name
,
type
=
"pt"
):
...
@@ -224,7 +266,7 @@ def create_dummy_object(name, type="pt"):
...
@@ -224,7 +266,7 @@ def create_dummy_object(name, type="pt"):
"Model"
,
"Model"
,
"Tokenizer"
,
"Tokenizer"
,
]
]
assert
type
in
[
"pt"
,
"tf"
,
"sentencepiece"
,
"tokenizers"
]
assert
type
in
[
"pt"
,
"tf"
,
"sentencepiece"
,
"tokenizers"
,
"flax"
]
if
name
.
isupper
():
if
name
.
isupper
():
return
DUMMY_CONSTANT
.
format
(
name
)
return
DUMMY_CONSTANT
.
format
(
name
)
elif
name
.
islower
():
elif
name
.
islower
():
...
@@ -244,7 +286,7 @@ def create_dummy_object(name, type="pt"):
...
@@ -244,7 +286,7 @@ def create_dummy_object(name, type="pt"):
def
create_dummy_files
():
def
create_dummy_files
():
""" Create the content of the dummy files. """
""" Create the content of the dummy files. """
sentencepiece_objects
,
tokenizers_objects
,
pt_objects
,
tf_objects
=
read_init
()
sentencepiece_objects
,
tokenizers_objects
,
pt_objects
,
tf_objects
,
flax_objects
=
read_init
()
sentencepiece_dummies
=
"# This file is autogenerated by the command `make fix-copies`, do not edit.
\n
"
sentencepiece_dummies
=
"# This file is autogenerated by the command `make fix-copies`, do not edit.
\n
"
sentencepiece_dummies
+=
"from ..file_utils import requires_sentencepiece
\n\n
"
sentencepiece_dummies
+=
"from ..file_utils import requires_sentencepiece
\n\n
"
...
@@ -262,17 +304,22 @@ def create_dummy_files():
...
@@ -262,17 +304,22 @@ def create_dummy_files():
tf_dummies
+=
"from ..file_utils import requires_tf
\n\n
"
tf_dummies
+=
"from ..file_utils import requires_tf
\n\n
"
tf_dummies
+=
"
\n
"
.
join
([
create_dummy_object
(
o
,
type
=
"tf"
)
for
o
in
tf_objects
])
tf_dummies
+=
"
\n
"
.
join
([
create_dummy_object
(
o
,
type
=
"tf"
)
for
o
in
tf_objects
])
return
sentencepiece_dummies
,
tokenizers_dummies
,
pt_dummies
,
tf_dummies
flax_dummies
=
"# This file is autogenerated by the command `make fix-copies`, do not edit.
\n
"
flax_dummies
+=
"from ..file_utils import requires_flax
\n\n
"
flax_dummies
+=
"
\n
"
.
join
([
create_dummy_object
(
o
,
type
=
"flax"
)
for
o
in
flax_objects
])
return
sentencepiece_dummies
,
tokenizers_dummies
,
pt_dummies
,
tf_dummies
,
flax_dummies
def
check_dummies
(
overwrite
=
False
):
def
check_dummies
(
overwrite
=
False
):
""" Check if the dummy files are up to date and maybe `overwrite` with the right content. """
""" Check if the dummy files are up to date and maybe `overwrite` with the right content. """
sentencepiece_dummies
,
tokenizers_dummies
,
pt_dummies
,
tf_dummies
=
create_dummy_files
()
sentencepiece_dummies
,
tokenizers_dummies
,
pt_dummies
,
tf_dummies
,
flax_dummies
=
create_dummy_files
()
path
=
os
.
path
.
join
(
PATH_TO_TRANSFORMERS
,
"utils"
)
path
=
os
.
path
.
join
(
PATH_TO_TRANSFORMERS
,
"utils"
)
sentencepiece_file
=
os
.
path
.
join
(
path
,
"dummy_sentencepiece_objects.py"
)
sentencepiece_file
=
os
.
path
.
join
(
path
,
"dummy_sentencepiece_objects.py"
)
tokenizers_file
=
os
.
path
.
join
(
path
,
"dummy_tokenizers_objects.py"
)
tokenizers_file
=
os
.
path
.
join
(
path
,
"dummy_tokenizers_objects.py"
)
pt_file
=
os
.
path
.
join
(
path
,
"dummy_pt_objects.py"
)
pt_file
=
os
.
path
.
join
(
path
,
"dummy_pt_objects.py"
)
tf_file
=
os
.
path
.
join
(
path
,
"dummy_tf_objects.py"
)
tf_file
=
os
.
path
.
join
(
path
,
"dummy_tf_objects.py"
)
flax_file
=
os
.
path
.
join
(
path
,
"dummy_flax_objects.py"
)
with
open
(
sentencepiece_file
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
with
open
(
sentencepiece_file
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
actual_sentencepiece_dummies
=
f
.
read
()
actual_sentencepiece_dummies
=
f
.
read
()
...
@@ -282,6 +329,8 @@ def check_dummies(overwrite=False):
...
@@ -282,6 +329,8 @@ def check_dummies(overwrite=False):
actual_pt_dummies
=
f
.
read
()
actual_pt_dummies
=
f
.
read
()
with
open
(
tf_file
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
with
open
(
tf_file
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
actual_tf_dummies
=
f
.
read
()
actual_tf_dummies
=
f
.
read
()
with
open
(
flax_file
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
actual_flax_dummies
=
f
.
read
()
if
sentencepiece_dummies
!=
actual_sentencepiece_dummies
:
if
sentencepiece_dummies
!=
actual_sentencepiece_dummies
:
if
overwrite
:
if
overwrite
:
...
@@ -327,6 +376,17 @@ def check_dummies(overwrite=False):
...
@@ -327,6 +376,17 @@ def check_dummies(overwrite=False):
"Run `make fix-copies` to fix this."
,
"Run `make fix-copies` to fix this."
,
)
)
if
flax_dummies
!=
actual_flax_dummies
:
if
overwrite
:
print
(
"Updating transformers.utils.dummy_flax_objects.py as the main __init__ has new objects."
)
with
open
(
flax_file
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
f
.
write
(
flax_dummies
)
else
:
raise
ValueError
(
"The main __init__ has objects that are not present in transformers.utils.dummy_flax_objects.py."
,
"Run `make fix-copies` to fix this."
,
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment