Unverified Commit 4d04120c authored by Piero Molino, committed by GitHub

Replaced torch.load with pickle.load for loading the pretrained vocab of the TransformerXL tokenizer (#6935)

* Replaced torch.load with pickle.load for loading the pretrained vocab of TransformerXL

* Replaced torch.save with pickle.dump when saving the vocabulary

* updating transformer-xl

* uploaded on S3 - compatibility

* fix tests

* style

* Address review comments
Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent aba4e229
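In short, the change swaps torch serialization of the tokenizer's vocabulary attributes for plain pickle, so the vocab file can be loaded without a PyTorch install (e.g. in TensorFlow-only environments). A minimal sketch of the pattern, with illustrative helper names that are not part of the library:

import pickle

def save_vocab(attrs: dict, vocab_file: str) -> None:
    # Dump the tokenizer's attribute dict as a plain pickle file (vocab.pkl).
    with open(vocab_file, "wb") as f:
        pickle.dump(attrs, f)

def load_vocab(vocab_file: str) -> dict:
    # Read it back without requiring a PyTorch installation.
    with open(vocab_file, "rb") as f:
        return pickle.load(f)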
@@ -203,6 +203,19 @@ def is_faiss_available():
    return _faiss_available


def torch_only_method(fn):
    def wrapper(*args, **kwargs):
        if not _torch_available:
            raise ImportError(
                "You need to install pytorch to use this method or class, "
                "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
            )
        else:
            return fn(*args, **kwargs)

    return wrapper


def is_sklearn_available():
    return _has_sklearn
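The torch_only_method decorator added above wraps torch-dependent functions so they fail with a clear ImportError when PyTorch is missing. A hypothetical illustration (the decorated function below is made up, and torch / _torch_available are assumed to come from the surrounding module as in the hunk):

@torch_only_method
def make_tensor(ids):
    # Body only runs when PyTorch is installed; otherwise the wrapper raises first.
    return torch.LongTensor(ids)

make_tensor([1, 2, 3])  # raises ImportError when _torch_available is False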
@@ -36,7 +36,7 @@ from tokenizers.normalizers import Lowercase, Sequence, Strip, unicode_normalize
from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit
from tokenizers.processors import BertProcessing
from .file_utils import cached_path, is_torch_available
from .file_utils import cached_path, is_torch_available, torch_only_method
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_fast import PreTrainedTokenizerFast
from .utils import logging
@@ -48,12 +48,16 @@ if is_torch_available():
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"pretrained_vocab_file": "vocab.bin", "vocab_file": "vocab.txt"}
VOCAB_FILES_NAMES = {
    "pretrained_vocab_file": "vocab.pkl",
    "pretrained_vocab_file_torch": "vocab.bin",
    "vocab_file": "vocab.txt",
}
VOCAB_FILES_NAMES_FAST = {"pretrained_vocab_file": "vocab.json", "vocab_file": "vocab.json"}
PRETRAINED_VOCAB_FILES_MAP = {
    "pretrained_vocab_file": {
        "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
        "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.pkl",
    }
}
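The "pretrained_vocab_file" entry now resolves to the pickle vocabulary uploaded to S3, while the old torch file stays reachable under "pretrained_vocab_file_torch" for compatibility. A rough way to fetch the new file, assuming cached_path behaves as defined in file_utils:

from transformers.file_utils import cached_path

# Downloads (or reuses the cache for) the pickle vocabulary and returns its local path.
local_vocab = cached_path("https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.pkl")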
@@ -139,8 +143,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
            File containing the vocabulary (from the original implementation).
        pretrained_vocab_file (:obj:`str`, `optional`):
            File containing the vocabulary as saved with the :obj:`save_pretrained()` method.
        never_split (xxx, `optional`):
            Fill me with intesting stuff.
        never_split (:obj:`List[str]`, `optional`):
            List of tokens that should never be split. If no list is specified, will simply use the existing
            special tokens.
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
@@ -165,7 +170,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
        lower_case=False,
        delimiter=None,
        vocab_file=None,
        pretrained_vocab_file=None,
        pretrained_vocab_file: str = None,
        never_split=None,
        unk_token="<unk>",
        eos_token="<eos>",
@@ -197,23 +202,40 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
        self.moses_tokenizer = sm.MosesTokenizer(language)
        self.moses_detokenizer = sm.MosesDetokenizer(language)

        # This try... catch... is not beautiful but honestly this tokenizer was not made to be used
        # in a library like ours, at all.
        try:
            vocab_dict = None
            if pretrained_vocab_file is not None:
                # Hack because, honestly this tokenizer was not made to be used
                # in a library like ours, at all.
                # Priority on pickle files (support PyTorch and TF)
                with open(pretrained_vocab_file, "rb") as f:
                    vocab_dict = pickle.load(f)

                # Loading a torch-saved transfo-xl vocab dict with pickle results in an integer
                # Entering this if statement means that we tried to load a torch-saved file with pickle, and we failed.
                # We therefore load it with torch, if it's available.
                if type(vocab_dict) == int:
                    if not is_torch_available():
                        raise ImportError(
                            "Not trying to load dict with PyTorch as you need to install pytorch to load "
                            "from a PyTorch pretrained vocabulary, "
                            "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
                        )
                    vocab_dict = torch.load(pretrained_vocab_file)

            if vocab_dict is not None:
                for key, value in vocab_dict.items():
                    if key not in self.__dict__:
                        self.__dict__[key] = value
            if vocab_file is not None:
            elif vocab_file is not None:
                self.build_vocab()

        except Exception:
        except Exception as e:
            raise ValueError(
                "Unable to parse file {}. Unknown format. "
                "If you tried to load a model saved through TransfoXLTokenizerFast,"
                "please note they are not compatible.".format(pretrained_vocab_file)
            )
            ) from e

        if vocab_file is not None:
            self.build_vocab()
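For backward compatibility, the constructor above tries pickle.load first and only falls back to torch.load when the file turns out to be torch-saved: with the legacy torch.save layout, the first pickled record is an integer magic number rather than the vocabulary dict, which is what the type(vocab_dict) == int check detects. A condensed sketch of that fallback, assuming pickle, torch and pretrained_vocab_file are in scope as in the diff:

with open(pretrained_vocab_file, "rb") as f:
    vocab_dict = pickle.load(f)  # a new-style vocab.pkl yields the dict directly
if isinstance(vocab_dict, int):  # an old torch-saved vocab.bin yields an int first
    vocab_dict = torch.load(pretrained_vocab_file)  # requires PyTorch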
@@ -286,7 +308,8 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"])
        else:
            vocab_file = vocab_path
        torch.save(self.__dict__, vocab_file)
        with open(vocab_file, "wb") as f:
            pickle.dump(self.__dict__, f)
        return (vocab_file,)

    def build_vocab(self):
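With save_vocabulary writing a plain pickle file, the user-visible effect is that save_pretrained now produces vocab.pkl (the new "pretrained_vocab_file" name) instead of the torch-serialized vocab.bin. A rough usage sketch:

from transformers import TransfoXLTokenizer

tok = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
tok.save_pretrained("./my-transfo-xl")  # writes vocab.pkl alongside the tokenizer config files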
@@ -309,6 +332,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
            logger.info("final vocab size {} from {} unique tokens".format(len(self), len(self.counter)))

    @torch_only_method
    def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
        if verbose:
            logger.info("encoding file {} ...".format(path))
@@ -326,6 +350,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
        return encoded

    @torch_only_method
    def encode_sents(self, sents, ordered=False, verbose=False):
        if verbose:
            logger.info("encoding {} sents ...".format(len(sents)))
@@ -436,6 +461,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
        out_string = self.moses_detokenizer.detokenize(tokens)
        return detokenize_numbers(out_string).strip()

    @torch_only_method
    def convert_to_tensor(self, symbols):
        return torch.LongTensor(self.convert_tokens_to_ids(symbols))
@@ -706,6 +732,7 @@ class LMShuffledIterator(object):
        for idx in epoch_indices:
            yield self.data[idx]

    @torch_only_method
    def stream_iterator(self, sent_stream):
        # streams for each data in the batch
        streams = [None] * self.bsz
@@ -795,6 +822,7 @@ class LMMultiFileIterator(LMShuffledIterator):
class TransfoXLCorpus(object):
    @classmethod
    @torch_only_method
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a pre-processed corpus.
@@ -892,10 +920,14 @@ class TransfoXLCorpus(object):
                data_iter = LMOrderedIterator(data, *args, **kwargs)
            elif self.dataset == "lm1b":
                data_iter = LMShuffledIterator(data, *args, **kwargs)
        else:
            data_iter = None
            raise ValueError(f"Split not recognized: {split}")

        return data_iter


@torch_only_method
def get_lm_corpus(datadir, dataset):
    fn = os.path.join(datadir, "cache.pt")
    fn_pickle = os.path.join(datadir, "cache.pkl")