Commit e768f232 authored by thomwolf
Browse files

update run_openai_gpt to fix #1264

parent 83349939
...@@ -153,9 +153,11 @@ def main(): ...@@ -153,9 +153,11 @@ def main():
# This loading functions also add new tokens and embeddings called `special tokens` # This loading functions also add new tokens and embeddings called `special tokens`
# These new embeddings will be fine-tuned on the RocStories dataset # These new embeddings will be fine-tuned on the RocStories dataset
special_tokens = ['_start_', '_delimiter_', '_classify_'] special_tokens = ['_start_', '_delimiter_', '_classify_']
tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name, special_tokens=special_tokens) tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name)
special_tokens_ids = list(tokenizer.convert_tokens_to_ids(token) for token in special_tokens) tokenizer.add_tokens(special_tokens)
model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.model_name, num_special_tokens=len(special_tokens)) special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)
model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.model_name)
model.resize_token_embeddings(len(tokenizer))
model.to(device) model.to(device)
# Load and encode the datasets # Load and encode the datasets
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment