Unverified Commit 4d04120c authored by Piero Molino, committed by GitHub

Replaced torch.load for loading the pretrained vocab of TransformerXL tokenizer to pickle.load (#6935)

* Replaced torch.load for loading the pretrained vocab of TransformerXL to pickle.load

* Replaced torch.save with pickle.dump when saving the vocabulary

* updating transformer-xl

* uploaded on S3 - compatibility

* fix tests

* style

* Address review comments
Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent aba4e229
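Because the hosted transfo-xl-wt103 vocabulary on S3 was originally written with torch.save, the "uploaded on S3 - compatibility" step re-serialized it in the new pickle format so it can be read without PyTorch. A rough sketch of such a one-off conversion (local file names are illustrative; the conversion script itself is not part of this commit):

```python
# One-off conversion sketch: read the legacy torch-saved vocab dict and re-write it
# with pickle so it can be loaded without PyTorch. File names are hypothetical.
import pickle

import torch  # needed only to read the legacy vocab.bin format

vocab_dict = torch.load("transfo-xl-wt103-vocab.bin")

with open("transfo-xl-wt103-vocab.pkl", "wb") as f:
    pickle.dump(vocab_dict, f)
```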
@@ -203,6 +203,19 @@ def is_faiss_available():
     return _faiss_available
 
 
+def torch_only_method(fn):
+    def wrapper(*args, **kwargs):
+        if not _torch_available:
+            raise ImportError(
+                "You need to install pytorch to use this method or class, "
+                "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
+            )
+        else:
+            return fn(*args, **kwargs)
+
+    return wrapper
+
+
 def is_sklearn_available():
     return _has_sklearn
...
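The torch_only_method decorator added above defers the PyTorch requirement from import time to call time: a module can expose torch-dependent helpers, and a clear ImportError is raised only if one of them is actually invoked without PyTorch installed. A self-contained sketch of the same pattern (the _torch_available check below is a simplified stand-in for the library's own flag):

```python
import importlib.util

# Simplified stand-in for the library's _torch_available flag.
_torch_available = importlib.util.find_spec("torch") is not None


def torch_only_method(fn):
    def wrapper(*args, **kwargs):
        if not _torch_available:
            raise ImportError(
                "You need to install pytorch to use this method or class, "
                "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
            )
        return fn(*args, **kwargs)

    return wrapper


@torch_only_method
def to_tensor(ids):
    import torch  # imported lazily, only reached when torch is installed

    return torch.LongTensor(ids)
```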
@@ -36,7 +36,7 @@ from tokenizers.normalizers import Lowercase, Sequence, Strip, unicode_normalize
 from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit
 from tokenizers.processors import BertProcessing
 
-from .file_utils import cached_path, is_torch_available
+from .file_utils import cached_path, is_torch_available, torch_only_method
 from .tokenization_utils import PreTrainedTokenizer
 from .tokenization_utils_fast import PreTrainedTokenizerFast
 from .utils import logging
@@ -48,12 +48,16 @@ if is_torch_available():
 
 logger = logging.get_logger(__name__)
 
-VOCAB_FILES_NAMES = {"pretrained_vocab_file": "vocab.bin", "vocab_file": "vocab.txt"}
+VOCAB_FILES_NAMES = {
+    "pretrained_vocab_file": "vocab.pkl",
+    "pretrained_vocab_file_torch": "vocab.bin",
+    "vocab_file": "vocab.txt",
+}
 
 VOCAB_FILES_NAMES_FAST = {"pretrained_vocab_file": "vocab.json", "vocab_file": "vocab.json"}
 
 PRETRAINED_VOCAB_FILES_MAP = {
     "pretrained_vocab_file": {
-        "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
+        "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.pkl",
     }
 }
@@ -139,8 +143,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
             File containing the vocabulary (from the original implementation).
         pretrained_vocab_file (:obj:`str`, `optional`):
             File containing the vocabulary as saved with the :obj:`save_pretrained()` method.
-        never_split (xxx, `optional`):
-            Fill me with intesting stuff.
+        never_split (:obj:`List[str]`, `optional`):
+            List of tokens that should never be split. If no list is specified, will simply use the existing
+            special tokens.
         unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
             token instead.
@@ -165,7 +170,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         lower_case=False,
         delimiter=None,
         vocab_file=None,
-        pretrained_vocab_file=None,
+        pretrained_vocab_file: str = None,
         never_split=None,
         unk_token="<unk>",
         eos_token="<eos>",
@@ -197,23 +202,40 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
             self.moses_tokenizer = sm.MosesTokenizer(language)
             self.moses_detokenizer = sm.MosesDetokenizer(language)
 
+        # This try... catch... is not beautiful but honestly this tokenizer was not made to be used
+        # in a library like ours, at all.
         try:
+            vocab_dict = None
             if pretrained_vocab_file is not None:
-                # Hack because, honestly this tokenizer was not made to be used
-                # in a library like ours, at all.
-                vocab_dict = torch.load(pretrained_vocab_file)
+                # Priority on pickle files (support PyTorch and TF)
+                with open(pretrained_vocab_file, "rb") as f:
+                    vocab_dict = pickle.load(f)
+
+                # Loading a torch-saved transfo-xl vocab dict with pickle results in an integer
+                # Entering this if statement means that we tried to load a torch-saved file with pickle, and we failed.
+                # We therefore load it with torch, if it's available.
+                if type(vocab_dict) == int:
+                    if not is_torch_available():
+                        raise ImportError(
+                            "Not trying to load dict with PyTorch as you need to install pytorch to load "
+                            "from a PyTorch pretrained vocabulary, "
+                            "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
+                        )
+                    vocab_dict = torch.load(pretrained_vocab_file)
+
+            if vocab_dict is not None:
                 for key, value in vocab_dict.items():
                     if key not in self.__dict__:
                         self.__dict__[key] = value
-
-            if vocab_file is not None:
+            elif vocab_file is not None:
                 self.build_vocab()
-        except Exception:
+
+        except Exception as e:
             raise ValueError(
                 "Unable to parse file {}. Unknown format. "
                 "If you tried to load a model saved through TransfoXLTokenizerFast,"
                 "please note they are not compatible.".format(pretrained_vocab_file)
-            )
+            ) from e
 
         if vocab_file is not None:
             self.build_vocab()
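The rewritten constructor loads the vocabulary with pickle first and only falls back to torch.load for legacy files: a torch-saved file begins with a pickled magic integer, so pickle.load returning an int is the signal that the old format is being read. A standalone sketch of that load order (the helper name and the torch_available argument are illustrative, not library API):

```python
import pickle


def load_vocab_dict(path, torch_available):
    """Illustrative helper: try pickle first, use legacy torch.load as a fallback."""
    with open(path, "rb") as f:
        vocab_dict = pickle.load(f)

    # A torch-saved file starts with a pickled magic number, so pickle.load
    # returns an int instead of the expected dict.
    if isinstance(vocab_dict, int):
        if not torch_available:
            raise ImportError("PyTorch is required to read a torch-saved vocabulary.")
        import torch

        vocab_dict = torch.load(path)

    return vocab_dict
```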
@@ -286,7 +308,8 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
             vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"])
         else:
             vocab_file = vocab_path
-        torch.save(self.__dict__, vocab_file)
+        with open(vocab_file, "wb") as f:
+            pickle.dump(self.__dict__, f)
         return (vocab_file,)
 
     def build_vocab(self):
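With save_vocabulary now writing the tokenizer state through pickle.dump, a save/load round trip no longer depends on torch serialization at all. A usage sketch (the local directory name is arbitrary and assumes a transformers version that includes this change):

```python
import os

from transformers import TransfoXLTokenizer

tok = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")

os.makedirs("./transfo-xl-local", exist_ok=True)
tok.save_pretrained("./transfo-xl-local")  # writes vocab.pkl via pickle.dump

reloaded = TransfoXLTokenizer.from_pretrained("./transfo-xl-local")
assert reloaded.convert_tokens_to_ids(["the"]) == tok.convert_tokens_to_ids(["the"])
```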
@@ -309,6 +332,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
 
         logger.info("final vocab size {} from {} unique tokens".format(len(self), len(self.counter)))
 
+    @torch_only_method
     def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
         if verbose:
             logger.info("encoding file {} ...".format(path))
@@ -326,6 +350,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
 
         return encoded
 
+    @torch_only_method
     def encode_sents(self, sents, ordered=False, verbose=False):
         if verbose:
             logger.info("encoding {} sents ...".format(len(sents)))
@@ -436,6 +461,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         out_string = self.moses_detokenizer.detokenize(tokens)
         return detokenize_numbers(out_string).strip()
 
+    @torch_only_method
     def convert_to_tensor(self, symbols):
         return torch.LongTensor(self.convert_tokens_to_ids(symbols))
@@ -706,6 +732,7 @@ class LMShuffledIterator(object):
         for idx in epoch_indices:
             yield self.data[idx]
 
+    @torch_only_method
     def stream_iterator(self, sent_stream):
         # streams for each data in the batch
         streams = [None] * self.bsz
@@ -795,6 +822,7 @@ class LMMultiFileIterator(LMShuffledIterator):
 
 class TransfoXLCorpus(object):
     @classmethod
+    @torch_only_method
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
         Instantiate a pre-processed corpus.
@@ -892,10 +920,14 @@ class TransfoXLCorpus(object):
             data_iter = LMOrderedIterator(data, *args, **kwargs)
         elif self.dataset == "lm1b":
             data_iter = LMShuffledIterator(data, *args, **kwargs)
+        else:
+            data_iter = None
+            raise ValueError(f"Split not recognized: {split}")
 
         return data_iter
 
 
+@torch_only_method
 def get_lm_corpus(datadir, dataset):
     fn = os.path.join(datadir, "cache.pt")
     fn_pickle = os.path.join(datadir, "cache.pkl")
...