"vscode:/vscode.git/clone" did not exist on "80b38a076cdea1a87382372cc449bc23c4c34153"
Commit 5d7e8457 authored by thomwolf's avatar thomwolf
Browse files

fix model on cuda

parent eccb2f01
......@@ -135,6 +135,7 @@ def main():
tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name, special_tokens=special_tokens)
special_tokens_ids = list(tokenizer.convert_tokens_to_ids(token) for token in special_tokens)
model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.model_name, num_special_tokens=len(special_tokens))
model.to(device)
# Load and encode the datasets
def tokenize_and_encode(obj):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment