"test/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "723e9a52ebde0afd542b1cc8588598ad2c893c87"
Commit 5f432480 authored by VictorSanh

Create DataParallel model if several GPUs are available

parent 5889765a
@@ -249,6 +249,9 @@ def main():
     if args.init_checkpoint is not None:
         model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
     model.to(device)
+    if n_gpu > 1:
+        model = nn.DataParallel(model)
     all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
     all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
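All three hunks apply the same idea in each training script: after the model is moved to the primary device, it is wrapped in DataParallel whenever more than one GPU is visible. Below is a minimal, self-contained sketch of that pattern; the TinyModel class is a hypothetical stand-in for the BERT models the scripts build, so the snippet runs anywhere, including on CPU-only machines where the wrapping step is simply skipped.

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Hypothetical stand-in for the BERT models built in the scripts.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 2)

    def forward(self, x):
        return self.linear(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

model = TinyModel()
model.to(device)
if n_gpu > 1:
    # DataParallel splits each input batch along dim 0 across the visible
    # GPUs and gathers the outputs back on the primary device.
    model = nn.DataParallel(model)

batch = torch.randn(8, 16, device=device)
logits = model(batch)
print(logits.shape)  # torch.Size([8, 2])
```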
@@ -482,6 +482,9 @@ def main():
     if args.init_checkpoint is not None:
         model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
     model.to(device)
+    if n_gpu > 1:
+        model = torch.nn.DataParallel(model)
     optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
                           {'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
@@ -795,6 +795,9 @@ def main():
     if args.init_checkpoint is not None:
         model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
     model.to(device)
+    if n_gpu > 1:
+        model = torch.nn.DataParallel(model)
     optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
                           {'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
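One detail worth keeping in mind when this change sits next to the name-based parameter grouping passed to BERTAdam in the last two hunks: wrapping a model in DataParallel re-registers it under an attribute called module, so every name yielded by named_parameters() gains a "module." prefix. The sketch below uses only stock PyTorch (no repository code) to show the effect; any filter that matches on parameter names should be written with that prefix in mind.

```python
import torch
import torch.nn as nn

# A plain module: parameter names are 'weight' and 'bias'.
model = nn.Linear(4, 2)
print([name for name, _ in model.named_parameters()])
# ['weight', 'bias']

# After wrapping, the same parameters are reached through the inner
# 'module' attribute, so their names become 'module.weight' / 'module.bias'.
wrapped = nn.DataParallel(model)
print([name for name, _ in wrapped.named_parameters()])
# ['module.weight', 'module.bias']
```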