Commit 9d6d291f authored by Leo Gao's avatar Leo Gao
Browse files

make create_from_arg_string fully general

For some reason, putting it in LM and having it be inherited breaks
everything. We should try to figure this out at some point.
parent 2bb67c98
...@@ -44,10 +44,10 @@ class GPT2LM(LM): ...@@ -44,10 +44,10 @@ class GPT2LM(LM):
# self.gpt2 = nn.DataParallel(self.gpt2) # self.gpt2 = nn.DataParallel(self.gpt2)
@classmethod @classmethod
def create_from_arg_string(cls, arg_string, **kwargs): def create_from_arg_string(cls, arg_string, additional_config={}):
args = utils.simple_parse_args_string(arg_string) args = utils.simple_parse_args_string(arg_string)
kwargs = {k: v for k, v in kwargs.items() if v is not None} args2 = {k: v for k, v in additional_config.items() if v is not None}
return cls(pretrained=args.get("pretrained", "gpt2"), **kwargs) return cls(**args, **args2)
def loglikelihood(self, requests): def loglikelihood(self, requests):
new_reqs = [] new_reqs = []
......
...@@ -65,10 +65,10 @@ class GPT3LM(LM): ...@@ -65,10 +65,10 @@ class GPT3LM(LM):
openai.api_key = os.environ["OPENAI_API_SECRET_KEY"] openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
@classmethod @classmethod
def create_from_arg_string(cls, arg_string, **kwargs): def create_from_arg_string(cls, arg_string, additional_config={}):
args = utils.simple_parse_args_string(arg_string) args = utils.simple_parse_args_string(arg_string)
kwargs = {k: v for k, v in kwargs.items() if v is not None} args2 = {k: v for k, v in additional_config.items() if v is not None}
return cls(engine=args.get("engine", "davinci"), **kwargs) return cls(**args, **args2)
def loglikelihood(self, requests): def loglikelihood(self, requests):
new_reqs = [] new_reqs = []
......
...@@ -29,7 +29,9 @@ def main(): ...@@ -29,7 +29,9 @@ def main():
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
lm = models.get_model(args.model).create_from_arg_string(args.model_args, batch_size=args.batch_size, device=args.device) lm = models.get_model(args.model).create_from_arg_string(args.model_args, {
'batch_size': args.batch_size, 'device': args.device
})
if args.limit: if args.limit:
print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.") print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment