chenpangpang / transformers · Commits · 072750f4

Unverified commit 072750f4, authored Dec 23, 2019 by Aymeric Augustin, committed by GitHub on Dec 23, 2019.

Merge pull request #2288 from aaugustin/better-handle-optional-imports

Improve handling of optional imports

Parents: 23dad844, 4621ad6f
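The theme of the change: optional backends (PyTorch, TensorFlow 2.0, scikit-learn, SentencePiece, KyTea, PyThaiNLP, Jieba) should not break things or emit noise at import time. Missing dependencies are instead reported when the feature is actually used: user-facing "please install X" failures now raise RuntimeError rather than ImportError, re-raises use a bare `raise` instead of `raise e`, and module-level try/except imports give way to `is_torch_available()` / `is_tf_available()` guards. Below is a minimal sketch of the resulting pattern, assuming only that `is_torch_available()` is importable from `transformers.file_utils` (confirmed by the diffs); the helper `fancy_torch_feature` is hypothetical, not part of this commit.

```python
# Sketch of the optional-import pattern adopted in this PR (illustrative only).
# is_torch_available() comes from transformers.file_utils; fancy_torch_feature() is made up.
from transformers.file_utils import is_torch_available

if is_torch_available():
    import torch  # only imported when the optional backend is present


def fancy_torch_feature(values):
    """Hypothetical feature that needs the optional PyTorch backend."""
    if not is_torch_available():
        # RuntimeError, not ImportError: the user asked for a feature whose
        # optional dependency is missing, so tell them how to install it.
        raise RuntimeError("PyTorch must be installed to use fancy_torch_feature: https://pytorch.org/")
    return torch.tensor(values)
```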
Showing 13 changed files with 31 additions and 33 deletions (+31 −33):
src/transformers/commands/serving.py (+1 −1)
src/transformers/commands/train.py (+1 −1)
src/transformers/data/metrics/__init__.py (+0 −6)
src/transformers/data/processors/squad.py (+2 −2)
src/transformers/data/processors/utils.py (+2 −2)
src/transformers/modeling_tf_pytorch_utils.py (+8 −8)
src/transformers/modeling_utils.py (+2 −2)
src/transformers/pipelines.py (+1 −1)
src/transformers/tokenization_albert.py (+2 −0)
src/transformers/tokenization_t5.py (+2 −0)
src/transformers/tokenization_transfo_xl.py (+2 −4)
src/transformers/tokenization_xlm.py (+6 −6)
src/transformers/tokenization_xlnet.py (+2 −0)
src/transformers/commands/serving.py

@@ -107,7 +107,7 @@ class ServeCommand(BaseTransformersCLICommand):
         self._host = host
         self._port = port
         if not _serve_dependancies_installed:
-            raise ImportError(
+            raise RuntimeError(
                 "Using serve command requires FastAPI and unicorn. "
                 "Please install transformers with [serving]: pip install transformers[serving]."
                 "Or install FastAPI and unicorn separatly."
src/transformers/commands/train.py

@@ -8,7 +8,7 @@ from transformers.commands import BaseTransformersCLICommand
 if not is_tf_available() and not is_torch_available():
-    raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
+    raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")

 # TF training parameters
 USE_XLA = False
src/transformers/data/metrics/__init__.py

@@ -14,18 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
-
-logger = logging.getLogger(__name__)
-
-
 try:
     from scipy.stats import pearsonr, spearmanr
     from sklearn.metrics import matthews_corrcoef, f1_score

     _has_sklearn = True
 except (AttributeError, ImportError) as e:
-    logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
     _has_sklearn = False
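data/metrics keeps the quieter module-level pattern: attempt the import once, record the outcome in `_has_sklearn`, and say nothing at import time. Below is a hedged sketch of how such a flag is typically consumed downstream; `compute_f1` is illustrative, not a function from this diff.

```python
# Illustrative sketch of the module-level flag pattern used in data/metrics/__init__.py.
# compute_f1 is a made-up consumer, not part of the transformers API.
try:
    from sklearn.metrics import f1_score

    _has_sklearn = True
except (AttributeError, ImportError):
    _has_sklearn = False


def compute_f1(preds, labels):
    """Raise only when the metric is actually requested without scikit-learn installed."""
    if not _has_sklearn:
        raise RuntimeError("To use data.metrics please install scikit-learn: https://scikit-learn.org/stable/index.html")
    return f1_score(y_true=labels, y_pred=preds)
```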
src/transformers/data/processors/squad.py

@@ -324,7 +324,7 @@ def squad_convert_examples_to_features(
     del new_features
     if return_dataset == "pt":
         if not is_torch_available():
-            raise ImportError("Pytorch must be installed to return a pytorch dataset.")
+            raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.")

         # Convert to Tensors and build dataset
         all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)

@@ -354,7 +354,7 @@ def squad_convert_examples_to_features(
         return features, dataset
     elif return_dataset == "tf":
         if not is_tf_available():
-            raise ImportError("TensorFlow must be installed to return a TensorFlow dataset.")
+            raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")

         def gen():
             for ex in features:
src/transformers/data/processors/utils.py

@@ -294,7 +294,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
             return features
         elif return_tensors == "tf":
             if not is_tf_available():
-                raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
+                raise RuntimeError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
             import tensorflow as tf

             def gen():

@@ -309,7 +309,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
             return dataset
         elif return_tensors == "pt":
             if not is_torch_available():
-                raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported")
+                raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported")
             import torch
             from torch.utils.data import TensorDataset
src/transformers/modeling_tf_pytorch_utils.py

@@ -76,12 +76,12 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
     try:
         import tensorflow as tf  # noqa: F401
         import torch  # noqa: F401
-    except ImportError as e:
+    except ImportError:
         logger.error(
             "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
             "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
         )
-        raise e
+        raise

     pt_path = os.path.abspath(pytorch_checkpoint_path)
     logger.info("Loading PyTorch weights from {}".format(pt_path))

@@ -111,12 +111,12 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
         import torch  # noqa: F401
         import tensorflow as tf  # noqa: F401
         from tensorflow.python.keras import backend as K
-    except ImportError as e:
+    except ImportError:
         logger.error(
             "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
             "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
         )
-        raise e
+        raise

     if tf_inputs is None:
         tf_inputs = tf_model.dummy_inputs

@@ -209,12 +209,12 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
     try:
         import tensorflow as tf  # noqa: F401
         import torch  # noqa: F401
-    except ImportError as e:
+    except ImportError:
         logger.error(
             "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
             "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
         )
-        raise e
+        raise

     import transformers

@@ -251,12 +251,12 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
     try:
         import tensorflow as tf  # noqa: F401
         import torch  # noqa: F401
-    except ImportError as e:
+    except ImportError:
         logger.error(
             "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
             "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
         )
-        raise e
+        raise

     new_pt_params_dict = {}
     current_pt_params_dict = dict(pt_model.named_parameters())
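All four loaders in this file switch from `raise e` to a bare `raise`. Both re-raise the caught ImportError, but the bare form makes the `as e` binding unnecessary (it would otherwise sit unused) and is the idiomatic way to re-raise after logging. A small self-contained sketch of the pattern follows; the module name `some_optional_backend` is made up.

```python
import logging

logger = logging.getLogger(__name__)


def load_optional_backend():
    """Log a hint, then re-raise the original ImportError (illustrative; module name is made up)."""
    try:
        import some_optional_backend  # noqa: F401
    except ImportError:
        logger.error("This feature requires some_optional_backend; please install it first.")
        raise  # re-raise the active ImportError without needing an 'as e' binding
```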
src/transformers/modeling_utils.py

@@ -454,12 +454,12 @@ class PreTrainedModel(nn.Module):
                from transformers import load_tf2_checkpoint_in_pytorch_model

                model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
-            except ImportError as e:
+            except ImportError:
                logger.error(
                    "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
                    "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
                )
-                raise e
+                raise
         else:
             # Convert old format to new format if needed from a PyTorch state_dict
             old_keys = []
src/transformers/pipelines.py

@@ -68,7 +68,7 @@ def get_framework(model=None):
         # Try to guess which framework to use from the model classname
         framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
     elif not is_tf_available() and not is_torch_available():
-        raise ImportError(
+        raise RuntimeError(
             "At least one of TensorFlow 2.0 or PyTorch should be installed. "
             "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
             "To install PyTorch, read the instructions at https://pytorch.org/."
src/transformers/tokenization_albert.py

@@ -100,6 +100,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
                 "You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
                 "pip install sentencepiece"
             )
+            raise

         self.do_lower_case = do_lower_case
         self.remove_space = remove_space

@@ -127,6 +128,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
                 "You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
                 "pip install sentencepiece"
             )
+            raise

         self.sp_model = spm.SentencePieceProcessor()
         self.sp_model.Load(self.vocab_file)
src/transformers/tokenization_t5.py

@@ -107,6 +107,7 @@ class T5Tokenizer(PreTrainedTokenizer):
                 "https://github.com/google/sentencepiece"
                 "pip install sentencepiece"
             )
+            raise

         self.vocab_file = vocab_file
         self._extra_ids = extra_ids

@@ -132,6 +133,7 @@ class T5Tokenizer(PreTrainedTokenizer):
                 "You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
                 "pip install sentencepiece"
             )
+            raise

         self.sp_model = spm.SentencePieceProcessor()
         self.sp_model.Load(self.vocab_file)
src/transformers/tokenization_transfo_xl.py

@@ -26,14 +26,12 @@ from collections import Counter, OrderedDict
 import numpy as np

-from .file_utils import cached_path
+from .file_utils import cached_path, is_torch_available
 from .tokenization_utils import PreTrainedTokenizer

-try:
+if is_torch_available():
     import torch
-except ImportError:
-    pass

 logger = logging.getLogger(__name__)
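Here the module-level `try: import torch / except ImportError: pass` becomes an explicit `is_torch_available()` guard, which states the optional dependency directly and matches how the rest of the library checks for backends. A short sketch of the guarded import in isolation, assuming only `is_torch_available()` from `transformers.file_utils`; `ids_to_tensor` is hypothetical.

```python
# Guarded optional import, mirroring the new top of tokenization_transfo_xl.py.
from transformers.file_utils import is_torch_available

if is_torch_available():
    import torch  # noqa: F401


def ids_to_tensor(ids):
    """Hypothetical helper that fails loudly if the optional backend is missing."""
    if not is_torch_available():
        raise RuntimeError("PyTorch must be installed to build tensors: https://pytorch.org/")
    return torch.tensor(ids)
```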
src/transformers/tokenization_xlm.py

@@ -646,7 +646,7 @@ class XLMTokenizer(PreTrainedTokenizer):
                self.ja_word_tokenizer = Mykytea.Mykytea(
                    "-model %s/local/share/kytea/model.bin" % os.path.expanduser("~")
                )
-            except (AttributeError, ImportError) as e:
+            except (AttributeError, ImportError):
                logger.error(
                    "Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps"
                )

@@ -655,7 +655,7 @@ class XLMTokenizer(PreTrainedTokenizer):
                logger.error("3. ./configure --prefix=$HOME/local")
                logger.error("4. make && make install")
                logger.error("5. pip install kytea")
-                raise e
+                raise
        return list(self.ja_word_tokenizer.getWS(text))

    @property

@@ -760,12 +760,12 @@ class XLMTokenizer(PreTrainedTokenizer):
                    from pythainlp.tokenize import word_tokenize as th_word_tokenize
                else:
                    th_word_tokenize = sys.modules["pythainlp"].word_tokenize
-            except (AttributeError, ImportError) as e:
+            except (AttributeError, ImportError):
                logger.error(
                    "Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps"
                )
                logger.error("1. pip install pythainlp")
-                raise e
+                raise
            text = th_word_tokenize(text)
        elif lang == "zh":
            try:

@@ -773,10 +773,10 @@ class XLMTokenizer(PreTrainedTokenizer):
                    import jieba
                else:
                    jieba = sys.modules["jieba"]
-            except (AttributeError, ImportError) as e:
+            except (AttributeError, ImportError):
                logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
                logger.error("1. pip install jieba")
-                raise e
+                raise
            text = " ".join(jieba.cut(text))
            text = self.moses_pipeline(text, lang=lang)
            text = text.split()
src/transformers/tokenization_xlnet.py

@@ -100,6 +100,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
                 "You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
                 "pip install sentencepiece"
             )
+            raise

         self.do_lower_case = do_lower_case
         self.remove_space = remove_space

@@ -127,6 +128,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
                 "You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
                 "pip install sentencepiece"
             )
+            raise

         self.sp_model = spm.SentencePieceProcessor()
         self.sp_model.Load(self.vocab_file)