Commit ec16142e authored by Patrick von Platen

add special tokens to pretrain configs of respective lm head models

parent e645dcbb
@@ -134,6 +134,8 @@ class GPT2Config(PretrainedConfig):
         summary_activation=None,
         summary_proj_to_labels=True,
         summary_first_dropout=0.1,
+        bos_token_id=50256,
+        eos_token_id=50256,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -155,6 +157,9 @@ class GPT2Config(PretrainedConfig):
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels
+        self.bos_token_id = bos_token_id
+        self.eos_token_ids = [eos_token_id]
 
     @property
     def max_position_embeddings(self):
         return self.n_positions
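
With these defaults in place, a plain GPT2Config should already expose GPT-2's <|endoftext|> id without any overrides. A minimal sketch (assuming a transformers build that includes this commit):

    from transformers import GPT2Config

    config = GPT2Config()           # no arguments: rely on the new defaults
    print(config.bos_token_id)      # 50256, GPT-2's <|endoftext|> id
    print(config.eos_token_ids)     # [50256] -- note the eos id is stored as a list
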
@@ -149,6 +149,7 @@ class TransfoXLConfig(PretrainedConfig):
         proj_init_std=0.01,
         init_std=0.02,
         layer_norm_epsilon=1e-5,
+        eos_token_id=0,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -186,6 +187,8 @@ class TransfoXLConfig(PretrainedConfig):
         self.init_std = init_std
         self.layer_norm_epsilon = layer_norm_epsilon
+        self.eos_token_ids = [eos_token_id]
 
     @property
     def max_position_embeddings(self):
         return self.tgt_len + self.ext_len + self.mem_len
@@ -193,6 +193,8 @@ class XLMConfig(PretrainedConfig):
         end_n_top=5,
         mask_token_id=0,
         lang_id=0,
+        bos_token_id=0,
+        pad_token_id=2,
         **kwargs
     ):
         """Constructs XLMConfig.
@@ -233,6 +235,9 @@ class XLMConfig(PretrainedConfig):
         if "n_words" in kwargs:
             self.n_words = kwargs["n_words"]
+        self.bos_token_id = bos_token_id
+        self.pad_token_id = pad_token_id
 
     @property
     def n_words(self):  # For backward compatibility
         return self.vocab_size
@@ -155,6 +155,9 @@ class XLNetConfig(PretrainedConfig):
         summary_last_dropout=0.1,
         start_n_top=5,
         end_n_top=5,
+        bos_token_id=1,
+        pad_token_id=5,
+        eos_token_id=2,
         **kwargs
     ):
         """Constructs XLNetConfig.
@@ -188,6 +191,10 @@ class XLNetConfig(PretrainedConfig):
         self.start_n_top = start_n_top
         self.end_n_top = end_n_top
+        self.bos_token_id = bos_token_id
+        self.pad_token_id = pad_token_id
+        self.eos_token_ids = [eos_token_id]
 
     @property
     def max_position_embeddings(self):
         return -1
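
XLNet gets three defaults, which as far as I can tell match its SentencePiece special tokens (bos <s> = 1, eos </s> = 2, pad <pad> = 5). A quick sketch of what the config reports after this change, again assuming a build that includes it:

    from transformers import XLNetConfig

    config = XLNetConfig()
    print(config.bos_token_id)      # 1
    print(config.pad_token_id)      # 5
    print(config.eos_token_ids)     # [2], wrapped in a list like the other configs
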
@@ -657,7 +657,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             tokenizer = AutoTokenizer.from_pretrained('distilgpt2')    # Initialize tokenizer
             model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
-            outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, do_sample=False)  # do greedy decoding
+            outputs = model.generate(max_length=40, do_sample=False)  # do greedy decoding
             print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
 
             tokenizer = AutoTokenizer.from_pretrained('openai-gpt')    # Initialize tokenizer
@@ -672,7 +672,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
             input_context = 'The dog'
             input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)  # encode input context
-            outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.pad_token_id, eos_token_ids=tokenizer.eos_token_id, num_return_sequences=3)  # 3 generate sequences using by sampling
+            outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3)  # generate 3 sequences by sampling
             for i in range(3):  # 3 output sequences were generated
                 print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
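
Since the pretrained configs now carry these ids, generate() can fall back to them, which is why the docstring examples above no longer pass the tokenizer's ids explicitly. A hedged end-to-end sketch (it assumes the hosted distilgpt2 config does not override the new defaults):

    from transformers import AutoModelWithLMHead, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
    model = AutoModelWithLMHead.from_pretrained('distilgpt2')
    # the config defaults should now line up with the tokenizer's special tokens
    assert model.config.bos_token_id == tokenizer.bos_token_id       # 50256
    assert model.config.eos_token_ids == [tokenizer.eos_token_id]    # [50256]
    outputs = model.generate(max_length=40, do_sample=False)         # greedy decoding, ids taken from the config
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))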