Commit 177a7212 authored by thomwolf

move back to simple space splitting

parent a5997dd8
@@ -194,7 +194,7 @@ def main():
     elif args.length < 0:
         args.length = MAX_LENGTH  # avoid infinite loop
 
-    print(args)
+    logger.info(args)
     if args.model_type in ["ctrl"]:
         if args.temperature > 0.7 :
             logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
...
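The first hunk routes the argument dump through the script's module logger instead of a bare print. A minimal sketch of the kind of logging setup that makes such logger.info calls visible is shown below; the format string and level are generic assumptions, not necessarily the exact configuration used in the example script.

import logging

# Generic setup (assumption): send INFO-level records to the console.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info("parsed args would be logged here instead of printed")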
@@ -22,9 +22,6 @@ import os
 import regex as re
 from io import open
 
-import sacremoses as sm
-
-from .tokenization_xlm import replace_unicode_punct, remove_non_printing_char
 from .tokenization_utils import PreTrainedTokenizer
 
 logger = logging.getLogger(__name__)
@@ -81,9 +78,6 @@ class CTRLTokenizer(PreTrainedTokenizer):
         self.max_len_single_sentence = self.max_len  # no default special tokens - you can update this value if you add special tokens
         self.max_len_sentences_pair = self.max_len  # no default special tokens - you can update this value if you add special tokens
 
-        self.punct_normalizer = sm.MosesPunctNormalizer(lang='en')
-        self.moses_tokenizer = sm.MosesTokenizer(lang='en')
-
         self.encoder = json.load(open(vocab_file, encoding="utf-8"))
         self.decoder = {v:k for k,v in self.encoder.items()}
         merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
@@ -138,22 +132,12 @@ class CTRLTokenizer(PreTrainedTokenizer):
         self.cache[token] = word
         return word
 
-    def moses_pipeline(self, text):
-        text = replace_unicode_punct(text)
-        text = self.punct_normalizer.normalize(text)
-        text = remove_non_printing_char(text)
-        return text
-
-    def _tokenize(self, text, bypass_tokenizer=False):
+    def _tokenize(self, text):
         """ Tokenize a string.
         """
         split_tokens = []
 
-        if bypass_tokenizer:
-            text = text.split()
-        else:
-            text = self.moses_pipeline(text)
-            text = self.moses_tokenizer.tokenize(text, return_str=False, escape=False)
+        text = text.split(' ')
 
         for token in text:
             split_tokens.extend([t for t in self.bpe(token).split(' ')])
...
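The net effect of the tokenizer change is that CTRL pre-tokenization is now a plain split on spaces followed by BPE, with the Moses punctuation normalization and tokenization removed. Below is a minimal, self-contained sketch of the new control flow; toy_bpe is a hypothetical stand-in for CTRLTokenizer.bpe, which applies the real merge ranks loaded from merges_file.

# Sketch of the simplified _tokenize flow introduced by this commit.
# toy_bpe is a hypothetical placeholder for CTRLTokenizer.bpe, which would
# greedily apply the learned merges and emit subwords separated by spaces.

def toy_bpe(token):
    # Pass-through stand-in; the real bpe() may split a word into several pieces.
    return token

def tokenize(text):
    split_tokens = []
    text = text.split(' ')  # simple space splitting replaces the Moses pipeline
    for token in text:
        split_tokens.extend(toy_bpe(token).split(' '))
    return split_tokens

print(tokenize("hello world !"))  # ['hello', 'world', '!'] with the pass-through stub

With the real merge table, each whole-word token would be replaced by its BPE subword pieces; only the pre-tokenization step changes in this commit.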