Commit 7ffe47c8 authored by piero's avatar piero Committed by Julien Chaumond
Browse files

Improved device specification

parent 4f2164e4
...@@ -242,10 +242,9 @@ def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False): ...@@ -242,10 +242,9 @@ def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
def train_discriminator( def train_discriminator(
dataset, dataset_fp=None, pretrained_model='gpt2-medium', dataset, dataset_fp=None, pretrained_model='gpt2-medium',
epochs=10, batch_size=64, log_interval=10, epochs=10, batch_size=64, log_interval=10,
save_model=False, cached=False, use_cuda=False): save_model=False, cached=False, no_cuda=False):
if use_cuda: global device
global device device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
device = 'cuda'
print('Preprocessing {} dataset...'.format(dataset)) print('Preprocessing {} dataset...'.format(dataset))
start = time.time() start = time.time()
...@@ -577,8 +576,8 @@ if __name__ == '__main__': ...@@ -577,8 +576,8 @@ if __name__ == '__main__':
help='whether to save the model') help='whether to save the model')
parser.add_argument('--cached', action='store_true', parser.add_argument('--cached', action='store_true',
help='whether to cache the input representations') help='whether to cache the input representations')
parser.add_argument('--use_cuda', action='store_true', parser.add_argument('--no_cuda', action='store_true',
help='use to turn on cuda') help='use to turn off cuda')
args = parser.parse_args() args = parser.parse_args()
train_discriminator(**(vars(args))) train_discriminator(**(vars(args)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment