Commit 7ffe47c8 authored by piero's avatar piero Committed by Julien Chaumond
Browse files

Improved device specification

parent 4f2164e4
......@@ -242,10 +242,9 @@ def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
def train_discriminator(
dataset, dataset_fp=None, pretrained_model='gpt2-medium',
epochs=10, batch_size=64, log_interval=10,
save_model=False, cached=False, use_cuda=False):
if use_cuda:
save_model=False, cached=False, no_cuda=False):
global device
device = 'cuda'
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
print('Preprocessing {} dataset...'.format(dataset))
start = time.time()
......@@ -577,8 +576,8 @@ if __name__ == '__main__':
help='whether to save the model')
parser.add_argument('--cached', action='store_true',
help='whether to cache the input representations')
parser.add_argument('--use_cuda', action='store_true',
help='use to turn on cuda')
parser.add_argument('--no_cuda', action='store_true',
help='use to turn off cuda')
args = parser.parse_args()
train_discriminator(**(vars(args)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment