Commit 1b4242c1 authored by Leo Gao's avatar Leo Gao
Browse files

More changes to make neo work

parent 747b851d
...@@ -9,11 +9,16 @@ from tqdm import tqdm ...@@ -9,11 +9,16 @@ from tqdm import tqdm
class GPT2LM(LM):
    """LM wrapper around a HuggingFace causal language model (GPT-2 / GPT-Neo)."""

    # Cap on the number of tokens produced per generation request.
    MAX_GEN_TOKS = 256

    def __init__(self, device=None, pretrained='gpt2'):
        """Load the pretrained model and tokenizer.

        :param device: torch device string (e.g. "cuda", "cpu"). When None,
            auto-selects CUDA if available, otherwise CPU.
        :param pretrained: HuggingFace model identifier to load.
        """
        if device:
            self.device = torch.device(device)
        else:
            # Prefer the GPU whenever one is visible to torch.
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
        self.gpt2.eval()
        # pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2
        self.tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = "<|endoftext|>"
        # NOTE(review): `n_ctx` exists on GPT-2 configs; GPT-Neo configs name
        # this `max_position_embeddings` — confirm against neo checkpoints.
        self.max_length = self.gpt2.config.n_ctx
...@@ -22,7 +27,7 @@ class GPT2LM(LM): ...@@ -22,7 +27,7 @@ class GPT2LM(LM):
@classmethod
def create_from_arg_string(cls, arg_string):
    """Build a GPT2LM from a comma-separated ``key=value`` argument string.

    Recognized keys: ``device`` (default None, meaning auto-select) and
    ``pretrained`` (default "gpt2"). Unrecognized keys are ignored here;
    parsing is delegated to ``utils.simple_parse_args_string``.
    """
    args = utils.simple_parse_args_string(arg_string)
    return cls(device=args.get("device", None), pretrained=args.get("pretrained", "gpt2"))
def loglikelihood(self, requests): def loglikelihood(self, requests):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
......
...@@ -35,7 +35,7 @@ def main(): ...@@ -35,7 +35,7 @@ def main():
print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.") print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
if not args.no_cache: if not args.no_cache:
lm = base.CachingLM(lm, 'lm_cache/' + args.model + '_' + args.model_args.replace('=', '-').replace(',', '_') + '.db') lm = base.CachingLM(lm, 'lm_cache/' + args.model + '_' + args.model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db')
if args.tasks == "all_tasks": if args.tasks == "all_tasks":
task_names = tasks.ALL_TASKS task_names = tasks.ALL_TASKS
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment