Commit 7f7b717d authored by Leo Gao's avatar Leo Gao
Browse files

gpt2: fix bad multi gpu behavior and typecheck inputs

parent e037ef0a
...@@ -15,6 +15,11 @@ class GPT2LM(LM): ...@@ -15,6 +15,11 @@ class GPT2LM(LM):
def __init__(self, device='cuda', pretrained='gpt2', batch_size=1): def __init__(self, device='cuda', pretrained='gpt2', batch_size=1):
super().__init__() super().__init__()
assert isinstance(device, str)
assert isinstance(pretrained, str)
assert isinstance(batch_size, int)
if device: if device:
self.device = torch.device(device) self.device = torch.device(device)
else: else:
...@@ -37,7 +42,8 @@ class GPT2LM(LM): ...@@ -37,7 +42,8 @@ class GPT2LM(LM):
gpus = torch.cuda.device_count() gpus = torch.cuda.device_count()
batch_size_per_gpu = batch_size # todo: adaptive batch size batch_size_per_gpu = batch_size # todo: adaptive batch size
self.batch_size = batch_size_per_gpu * gpus # TODO: fix multi-gpu
self.batch_size = batch_size_per_gpu# * gpus
# TODO: fix multi-gpu # TODO: fix multi-gpu
# if gpus > 1: # if gpus > 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment