"pipelines/vscode:/vscode.git/clone" did not exist on "59b8df0e15d986c234f9ed19f7ae173b6b769b59"
Unverified Commit ce374ba8 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix Trainer in DataParallel setting (#5685)



* Fix Trainer in DataParallel setting

* Fix typo
Co-authored-by: default avatarSam Shleifer <sshleifer@gmail.com>
parent 0a19a49d
......@@ -618,6 +618,9 @@ class Trainer:
if self.args.past_index >= 0 and self._past is not None:
inputs["mems"] = self._past
# Our model outputs do not work with DataParallel, so forcing return tuple.
if self.args.n_gpu > 1:
inputs["return_tuple"] = True
outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
......@@ -818,6 +821,9 @@ class Trainer:
inputs[k] = v.to(self.args.device)
if self.args.past_index >= 0:
inputs["mems"] = past
# Our model outputs do not work with DataParallel, so forcing return tuple.
if self.args.n_gpu > 1:
inputs["return_tuple"] = True
with torch.no_grad():
outputs = model(**inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment