Commit d9e60f4f authored by thomwolf

Merge branch 'master' into pr/1383

parents 07d055f8 1c507995
@@ -69,7 +69,7 @@ def create_sinusoidal_embeddings(n_pos, dim, out):
 def gelu(x):
     """ Gaussian Error Linear Unit.
-    Original Implementation of the gelu activation function in Google Bert repo when initialy created.
+    Original Implementation of the gelu activation function in Google Bert repo when initially created.
     For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
     0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
     Also see https://arxiv.org/abs/1606.08415
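Not part of this diff, just for context: a minimal, self-contained sketch contrasting the erf-based gelu (assumed to be the function body under this docstring) with the OpenAI GPT tanh approximation quoted above. The function names and the erf-based body are assumptions, not taken from the commit.

    import math
    import torch

    def gelu_erf(x):
        # assumed erf-based form (the "Google Bert repo" gelu referenced in the docstring)
        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

    def gelu_gpt(x):
        # tanh approximation quoted verbatim in the docstring above
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

    x = torch.linspace(-3.0, 3.0, steps=7)
    print((gelu_erf(x) - gelu_gpt(x)).abs().max())  # small but non-zero difference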
@@ -501,7 +501,10 @@ class PoolerEndLogits(nn.Module):
         x = self.dense_1(x).squeeze(-1)
         if p_mask is not None:
-            x = x * (1 - p_mask) - 1e30 * p_mask
+            if next(self.parameters()).dtype == torch.float16:
+                x = x * (1 - p_mask) - 65500 * p_mask
+            else:
+                x = x * (1 - p_mask) - 1e30 * p_mask
         return x
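A note on the hunk above: float16 cannot represent 1e30 (its largest finite value is about 65504), so the old penalty overflows to -inf under fp16 and can produce NaNs downstream; the change caps the penalty at 65500 when the model runs in half precision. A hedged standalone sketch of the same masking logic (the function name is mine, and the dtype is checked here on the input rather than on the module parameters as in the diff):

    import torch

    def mask_logits(x, p_mask):
        # keep the masking penalty finite in the tensor's own dtype
        if x.dtype == torch.float16:
            return x * (1 - p_mask) - 65500 * p_mask
        return x * (1 - p_mask) - 1e30 * p_mask

    x = torch.randn(4)
    p_mask = torch.tensor([0.0, 1.0, 0.0, 1.0])
    print(mask_logits(x.half(), p_mask.half()))  # masked positions stay finite (about -65504)
    print(mask_logits(x, p_mask))                # masked positions at about -1e30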
......@@ -46,12 +46,14 @@ PRETRAINED_VOCAB_FILES_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
},
'merges_file':
{
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
},
}
@@ -59,6 +61,7 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'gpt2': 1024,
     'gpt2-medium': 1024,
     'gpt2-large': 1024,
+    'distilgpt2': 1024,
 }

 @lru_cache()
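The two hunks above register the distilgpt2 shortcut in the GPT-2 tokenizer tables (vocab, merges, and maximum position embeddings). A hedged usage sketch, assuming the library is importable as `transformers` as in the 2.x releases of this period:

    from transformers import GPT2Tokenizer

    # 'distilgpt2' now resolves to the vocab/merges URLs added above
    tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
    ids = tokenizer.encode("Hello world")
    print(ids)
    print(tokenizer.decode(ids))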
@@ -512,7 +512,8 @@ class PreTrainedTokenizer(object):
         for token in new_tokens:
             assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
             if token != self.unk_token and \
-                    self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token):
+                    self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
+                    token not in to_add_tokens:
                 to_add_tokens.append(token)
                 logger.info("Adding %s to the vocabulary", token)
@@ -911,6 +912,11 @@ class PreTrainedTokenizer(object):
         Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
         with options to remove special tokens and clean up tokenization spaces.
         Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
+
+        Args:
+            token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
+            skip_special_tokens: if set to True, will replace special tokens.
+            clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
         """
         filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
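The new Args block documents `decode`. A hedged round-trip sketch of those arguments (the model name and the `add_special_tokens` flag are assumptions; exact defaults vary across releases of this period):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    ids = tokenizer.encode("Hello, world!", add_special_tokens=True)

    # skip_special_tokens drops markers such as [CLS]/[SEP];
    # clean_up_tokenization_spaces removes spaces inserted before punctuation
    print(tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True))
    print(tokenizer.decode(ids, skip_special_tokens=False, clean_up_tokenization_spaces=False))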
@@ -933,20 +939,11 @@ class PreTrainedTokenizer(object):
                 sub_texts.append(self.convert_tokens_to_string(current_sub_text))
             text = ''.join(sub_texts)
-        if self._sep_token is not None and self._sep_token in text:
-            text = text.replace(self._cls_token, self._sep_token)
-            split_text = list(filter(lambda sentence: len(sentence) > 0, text.split(self._sep_token)))
-            if clean_up_tokenization_spaces:
-                clean_text = [self.clean_up_tokenization(text) for text in split_text]
-                return clean_text
-            else:
-                return split_text
-        else:
-            if clean_up_tokenization_spaces:
-                clean_text = self.clean_up_tokenization(text)
-                return clean_text
-            else:
-                return text
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text

     @property
     def special_tokens_map(self):
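The removed branch used to split the decoded text on the sep token and could return a list of sentences; after this hunk `decode` always returns a single string. A hedged before/after illustration (the tokenizer choice and sentences are made up):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    ids = tokenizer.encode("First sentence.", "Second sentence.", add_special_tokens=True)

    # previously this could come back as a list split on [SEP];
    # with the simplified branch it is one string
    out = tokenizer.decode(ids, skip_special_tokens=False)
    print(type(out))  # <class 'str'>
    print(out)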