Commit 4c09a960 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Simplify re-raising exceptions.

Most modules use the simpler `raise` version. Normalize those that don't.
parent 5565dcdd
...@@ -76,12 +76,12 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i ...@@ -76,12 +76,12 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
try: try:
import tensorflow as tf # noqa: F401 import tensorflow as tf # noqa: F401
import torch # noqa: F401 import torch # noqa: F401
except ImportError as e: except ImportError:
logger.error( logger.error(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see " "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
) )
raise e raise
pt_path = os.path.abspath(pytorch_checkpoint_path) pt_path = os.path.abspath(pytorch_checkpoint_path)
logger.info("Loading PyTorch weights from {}".format(pt_path)) logger.info("Loading PyTorch weights from {}".format(pt_path))
...@@ -111,12 +111,12 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a ...@@ -111,12 +111,12 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
import torch # noqa: F401 import torch # noqa: F401
import tensorflow as tf # noqa: F401 import tensorflow as tf # noqa: F401
from tensorflow.python.keras import backend as K from tensorflow.python.keras import backend as K
except ImportError as e: except ImportError:
logger.error( logger.error(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see " "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
) )
raise e raise
if tf_inputs is None: if tf_inputs is None:
tf_inputs = tf_model.dummy_inputs tf_inputs = tf_model.dummy_inputs
...@@ -209,12 +209,12 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs ...@@ -209,12 +209,12 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
try: try:
import tensorflow as tf # noqa: F401 import tensorflow as tf # noqa: F401
import torch # noqa: F401 import torch # noqa: F401
except ImportError as e: except ImportError:
logger.error( logger.error(
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
) )
raise e raise
import transformers import transformers
...@@ -251,12 +251,12 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F ...@@ -251,12 +251,12 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
try: try:
import tensorflow as tf # noqa: F401 import tensorflow as tf # noqa: F401
import torch # noqa: F401 import torch # noqa: F401
except ImportError as e: except ImportError:
logger.error( logger.error(
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
) )
raise e raise
new_pt_params_dict = {} new_pt_params_dict = {}
current_pt_params_dict = dict(pt_model.named_parameters()) current_pt_params_dict = dict(pt_model.named_parameters())
......
...@@ -454,12 +454,12 @@ class PreTrainedModel(nn.Module): ...@@ -454,12 +454,12 @@ class PreTrainedModel(nn.Module):
from transformers import load_tf2_checkpoint_in_pytorch_model from transformers import load_tf2_checkpoint_in_pytorch_model
model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True) model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
except ImportError as e: except ImportError:
logger.error( logger.error(
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
) )
raise e raise
else: else:
# Convert old format to new format if needed from a PyTorch state_dict # Convert old format to new format if needed from a PyTorch state_dict
old_keys = [] old_keys = []
......
...@@ -646,7 +646,7 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -646,7 +646,7 @@ class XLMTokenizer(PreTrainedTokenizer):
self.ja_word_tokenizer = Mykytea.Mykytea( self.ja_word_tokenizer = Mykytea.Mykytea(
"-model %s/local/share/kytea/model.bin" % os.path.expanduser("~") "-model %s/local/share/kytea/model.bin" % os.path.expanduser("~")
) )
except (AttributeError, ImportError) as e: except (AttributeError, ImportError):
logger.error( logger.error(
"Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps" "Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps"
) )
...@@ -655,7 +655,7 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -655,7 +655,7 @@ class XLMTokenizer(PreTrainedTokenizer):
logger.error("3. ./configure --prefix=$HOME/local") logger.error("3. ./configure --prefix=$HOME/local")
logger.error("4. make && make install") logger.error("4. make && make install")
logger.error("5. pip install kytea") logger.error("5. pip install kytea")
raise e raise
return list(self.ja_word_tokenizer.getWS(text)) return list(self.ja_word_tokenizer.getWS(text))
@property @property
...@@ -760,12 +760,12 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -760,12 +760,12 @@ class XLMTokenizer(PreTrainedTokenizer):
from pythainlp.tokenize import word_tokenize as th_word_tokenize from pythainlp.tokenize import word_tokenize as th_word_tokenize
else: else:
th_word_tokenize = sys.modules["pythainlp"].word_tokenize th_word_tokenize = sys.modules["pythainlp"].word_tokenize
except (AttributeError, ImportError) as e: except (AttributeError, ImportError):
logger.error( logger.error(
"Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps" "Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps"
) )
logger.error("1. pip install pythainlp") logger.error("1. pip install pythainlp")
raise e raise
text = th_word_tokenize(text) text = th_word_tokenize(text)
elif lang == "zh": elif lang == "zh":
try: try:
...@@ -773,10 +773,10 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -773,10 +773,10 @@ class XLMTokenizer(PreTrainedTokenizer):
import jieba import jieba
else: else:
jieba = sys.modules["jieba"] jieba = sys.modules["jieba"]
except (AttributeError, ImportError) as e: except (AttributeError, ImportError):
logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps") logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
logger.error("1. pip install jieba") logger.error("1. pip install jieba")
raise e raise
text = " ".join(jieba.cut(text)) text = " ".join(jieba.cut(text))
text = self.moses_pipeline(text, lang=lang) text = self.moses_pipeline(text, lang=lang)
text = text.split() text = text.split()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment