"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "8284ea208468cec3f97088bd5bdfc8376356eb34"
Commit ec16142e authored by Patrick von Platen

Add special tokens to the pretrained configs of the respective LM head models

parent e645dcbb
@@ -134,6 +134,8 @@ class GPT2Config(PretrainedConfig):
         summary_activation=None,
         summary_proj_to_labels=True,
         summary_first_dropout=0.1,
+        bos_token_id=50256,
+        eos_token_id=50256,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -155,6 +157,9 @@ class GPT2Config(PretrainedConfig):
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels

+        self.bos_token_id = bos_token_id
+        self.eos_token_ids = [eos_token_id]
+
     @property
     def max_position_embeddings(self):
         return self.n_positions
...
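The GPT-2 defaults above match the tokenizer's single <|endoftext|> token (id 50256), which GPT-2 uses for both BOS and EOS. A minimal check, assuming a transformers version that already includes this commit; note that `eos_token_ids` is stored as a list, per the assignment in the hunk:

from transformers import GPT2Config, GPT2Tokenizer

# Assumes a transformers release containing this change.
config = GPT2Config()
print(config.bos_token_id)   # 50256
print(config.eos_token_ids)  # [50256], stored as a list

# GPT-2 reuses <|endoftext|> (id 50256) for both roles, so the config
# defaults line up with the pretrained tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
assert tokenizer.bos_token_id == config.bos_token_id
assert tokenizer.eos_token_id == config.eos_token_ids[0]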
@@ -149,6 +149,7 @@ class TransfoXLConfig(PretrainedConfig):
         proj_init_std=0.01,
         init_std=0.02,
         layer_norm_epsilon=1e-5,
+        eos_token_id=0,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -186,6 +187,8 @@ class TransfoXLConfig(PretrainedConfig):
         self.init_std = init_std
         self.layer_norm_epsilon = layer_norm_epsilon

+        self.eos_token_ids = [eos_token_id]
+
     @property
     def max_position_embeddings(self):
         return self.tgt_len + self.ext_len + self.mem_len
...
@@ -193,6 +193,8 @@ class XLMConfig(PretrainedConfig):
         end_n_top=5,
         mask_token_id=0,
         lang_id=0,
+        bos_token_id=0,
+        pad_token_id=2,
         **kwargs
     ):
         """Constructs XLMConfig.
@@ -233,6 +235,9 @@ class XLMConfig(PretrainedConfig):
         if "n_words" in kwargs:
             self.n_words = kwargs["n_words"]

+        self.bos_token_id = bos_token_id
+        self.pad_token_id = pad_token_id
+
     @property
     def n_words(self):  # For backward compatibility
         return self.vocab_size
...
@@ -155,6 +155,9 @@ class XLNetConfig(PretrainedConfig):
         summary_last_dropout=0.1,
         start_n_top=5,
         end_n_top=5,
+        bos_token_id=1,
+        pad_token_id=5,
+        eos_token_id=2,
         **kwargs
     ):
         """Constructs XLNetConfig.
@@ -188,6 +191,10 @@ class XLNetConfig(PretrainedConfig):
         self.start_n_top = start_n_top
         self.end_n_top = end_n_top

+        self.bos_token_id = bos_token_id
+        self.pad_token_id = pad_token_id
+        self.eos_token_ids = [eos_token_id]
+
     @property
     def max_position_embeddings(self):
         return -1
...
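Taken together, the four hunks above give each affected config default values for `bos_token_id`, `pad_token_id`, and `eos_token_ids`. The point of storing them on the config is that `generate()` can fall back to them when the caller passes nothing. A rough sketch of that fallback pattern, purely illustrative and not the library's actual code (the real logic lives in `PreTrainedModel.generate()` and handles many more arguments; the helper name is hypothetical):

# Hypothetical, simplified fallback logic: explicit arguments win,
# otherwise the values now stored on the pretrained config are used.
def resolve_special_tokens(config, bos_token_id=None, pad_token_id=None, eos_token_ids=None):
    bos_token_id = bos_token_id if bos_token_id is not None else getattr(config, "bos_token_id", None)
    pad_token_id = pad_token_id if pad_token_id is not None else getattr(config, "pad_token_id", None)
    eos_token_ids = eos_token_ids if eos_token_ids is not None else getattr(config, "eos_token_ids", None)
    return bos_token_id, pad_token_id, eos_token_ids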
@@ -657,7 +657,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         tokenizer = AutoTokenizer.from_pretrained('distilgpt2')    # Initialize tokenizer
         model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
-        outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, do_sample=False)  # do greedy decoding
+        outputs = model.generate(max_length=40, do_sample=False)  # do greedy decoding
         print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

         tokenizer = AutoTokenizer.from_pretrained('openai-gpt')    # Initialize tokenizer
@@ -672,7 +672,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
         input_context = 'The dog'
         input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)  # encode input context
-        outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.pad_token_id, eos_token_ids=tokenizer.eos_token_id, num_return_sequences=3)  # generate 3 sequences by sampling
+        outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3)  # generate 3 sequences by sampling
         for i in range(3):  # 3 output sequences were generated
             print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
...
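With these defaults in place, models whose special tokens differ from GPT-2's can also be generated from without extra arguments. A hedged end-to-end sketch for XLNet (downloads the 'xlnet-base-cased' checkpoint; before this change, the bos/pad/eos ids had to be supplied by hand):

import torch
from transformers import XLNetLMHeadModel, XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-base-cased')

input_ids = torch.tensor(tokenizer.encode('The dog')).unsqueeze(0)  # encode input context
# bos (1), eos (2) and pad (5) ids now come from XLNetConfig, so they
# are no longer passed explicitly.
outputs = model.generate(input_ids=input_ids, max_length=40, do_sample=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))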