"git@developer.sourcefind.cn:change/sglang.git" did not exist on "ee704e62654c225bc63a4361ca1756fdeea0c264"
Commit 5d7e8457 authored by thomwolf's avatar thomwolf
Browse files

fix model on cuda

parent eccb2f01
...@@ -135,6 +135,7 @@ def main(): ...@@ -135,6 +135,7 @@ def main():
tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name, special_tokens=special_tokens) tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name, special_tokens=special_tokens)
special_tokens_ids = list(tokenizer.convert_tokens_to_ids(token) for token in special_tokens) special_tokens_ids = list(tokenizer.convert_tokens_to_ids(token) for token in special_tokens)
model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.model_name, num_special_tokens=len(special_tokens)) model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.model_name, num_special_tokens=len(special_tokens))
model.to(device)
# Load and encode the datasets # Load and encode the datasets
def tokenize_and_encode(obj): def tokenize_and_encode(obj):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment