Commit 5f42f976 authored by Leo Gao
Browse files

Overhaul command flags a bit

model_args should only contain things that affect the output of the model;
therefore, settings like batch size, device, etc. shouldn't be in there.
parent ba18bb4f
......@@ -10,7 +10,7 @@ from tqdm import tqdm
class GPT2LM(LM):
MAX_GEN_TOKS = 256
def __init__(self, device=None, pretrained='gpt2'):
def __init__(self, device='cuda', pretrained='gpt2', batch_size=1):
super().__init__()
if device:
self.device = torch.device(device)
......@@ -32,17 +32,19 @@ class GPT2LM(LM):
# multithreading and batching
gpus = torch.cuda.device_count()
batch_size_per_gpu = 2 # todo: adaptive batch size
batch_size_per_gpu = batch_size # todo: adaptive batch size
self.batch_size = batch_size_per_gpu * gpus
if gpus > 1:
self.gpt2 = nn.DataParallel(self.gpt2)
# TODO: fix multi-gpu
# if gpus > 1:
# self.gpt2 = nn.DataParallel(self.gpt2)
@classmethod
def create_from_arg_string(cls, arg_string):
def create_from_arg_string(cls, arg_string, **kwargs):
args = utils.simple_parse_args_string(arg_string)
return cls(device=args.get("device", None), pretrained=args.get("pretrained", "gpt2"))
kwargs = {k: v for k, v in kwargs.items() if v is not None}
return cls(pretrained=args.get("pretrained", "gpt2"), **kwargs)
def loglikelihood(self, requests):
new_reqs = []
......
......@@ -63,9 +63,10 @@ class GPT3LM(LM):
openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
@classmethod
def create_from_arg_string(cls, arg_string):
def create_from_arg_string(cls, arg_string, **kwargs):
args = utils.simple_parse_args_string(arg_string)
return cls(engine=args.get("engine", "davinci"))
kwargs = {k: v for k, v in kwargs.items() if v is not None}
return cls(engine=args.get("engine", "davinci"), **kwargs)
def loglikelihood(self, requests):
new_reqs = []
......
......@@ -15,6 +15,8 @@ def parse_args():
parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--num_fewshot', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=None)
parser.add_argument('--device', type=int, default=None)
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--output_path', default=None)
parser.add_argument('--limit', type=int, default=None)
......@@ -27,7 +29,7 @@ def main():
random.seed(args.seed)
np.random.seed(args.seed)
lm = models.get_model(args.model).create_from_arg_string(args.model_args)
lm = models.get_model(args.model).create_from_arg_string(args.model_args, batch_size=args.batch_size, device=args.device)
if args.limit:
print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment