Commit 366a3b02 authored by thomwolf

clean up in tokenization

parent 0efc4ab6
@@ -39,8 +39,10 @@ from .modeling import BertLayerNorm as LayerNorm
 
 logger = logging.getLogger(__name__)
 
-PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
-PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}
+PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
+                                "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
+PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
+                                 "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
 
 def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
     """ Load tf checkpoints in a pytorch model
...
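With the new archive-map entries, the medium-sized model can be loaded by its shortcut name alone. A minimal sketch, assuming the package's top-level exports at the time of this commit (`GPT2LMHeadModel` and `from_pretrained` are not part of this diff):

```python
from pytorch_pretrained_bert import GPT2LMHeadModel

# "gpt2-medium" now resolves through PRETRAINED_MODEL_ARCHIVE_MAP and
# PRETRAINED_CONFIG_ARCHIVE_MAP, downloading the weights and config from S3.
model = GPT2LMHeadModel.from_pretrained("gpt2-medium")
model.eval()
```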
@@ -37,9 +37,11 @@ logger = logging.getLogger(__name__)
 
 PRETRAINED_VOCAB_ARCHIVE_MAP = {
     'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
+    'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
 }
 PRETRAINED_MERGES_ARCHIVE_MAP = {
     'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
+    'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
 }
 PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
     'gpt2': 1024,
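The tokenizer follows the same pattern: `from_pretrained` resolves the vocab and merges files through the two maps above. A sketch, not part of the diff:

```python
from pytorch_pretrained_bert import GPT2Tokenizer

# Downloads gpt2-medium-vocab.json and gpt2-medium-merges.txt via the maps above.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
```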
@@ -263,9 +265,14 @@ class GPT2Tokenizer(object):
     def encode(self, text):
         return self.convert_tokens_to_ids(self.tokenize(text))
 
-    def decode(self, tokens, skip_special_tokens=False):
+    def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True):
         text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+        if clean_up_tokenization_spaces:
+            text = text.replace('<unk>', '')
+            text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
+                ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
+                ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
         return text
 
     def save_vocabulary(self, vocab_path):
...
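The effect of the new `clean_up_tokenization_spaces` flag: after byte-level decoding, stray spaces before punctuation and split contractions are stitched back together. A small illustration, assuming a tokenizer loaded as above:

```python
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Byte-level BPE round-trips its input exactly, so the cleanup is easiest
# to see on text that already contains detached punctuation.
ids = tokenizer.encode("hello , world ! i do not think so .")

print(tokenizer.decode(ids, clean_up_tokenization_spaces=False))
# -> hello , world ! i do not think so .

print(tokenizer.decode(ids))  # cleanup is on by default
# -> hello, world! i don't think so.
```

Note that the rule list goes beyond whitespace: it also rewrites " do not" to " don't", as the second output shows.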
@@ -272,7 +272,7 @@ class OpenAIGPTTokenizer(object):
         out_string = ''.join(tokens).replace('</w>', ' ').strip()
         if clean_up_tokenization_spaces:
             out_string = out_string.replace('<unk>', '')
-            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ','
+            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
             ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
             ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
         return out_string
...
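The only change here is dropping the duplicated `.replace(' ,', ',')` from the chain; after the first replacement no ' ,' sequences remain for the second to match, so behavior is unchanged.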