"docs/source/vscode:/vscode.git/clone" did not exist on "27e907386ad108b3fc6d5a9fc77aac8ba474592b"
Unverified Commit e54a1b49, authored by Atharva Ingle, committed by GitHub

`model.tie_weights()` should be applied after `accelerator.prepare()` (#18676)

* `model.tie_weights()` should be applied after `accelerator.prepare`

Weight tying should be done after the model has been moved to the XLA device, as noted in the PyTorch/XLA troubleshooting guide [here](https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks). A minimal sketch of the corrected ordering follows this list.

* format code
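
For context, the sketch below shows the corrected ordering in isolation: `accelerator.prepare()` first, then `tie_weights()`. The tiny GPT-2 checkpoint, dummy data, and hyperparameters are illustrative stand-ins, not the objects used in the example script itself.

```python
# Minimal sketch of the corrected ordering; model, data, and hyperparameters
# here are illustrative stand-ins, not the ones from the example script.
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator, DistributedType
from transformers import AutoModelForCausalLM

accelerator = Accelerator()
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Dummy batches so the sketch is self-contained.
input_ids = torch.randint(0, model.config.vocab_size, (8, 16))
train_dataloader = DataLoader(TensorDataset(input_ids), batch_size=4)

# prepare() moves the model to the target device (the XLA device when running on TPU).
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

# On TPU the device move disconnects the tied input/output embeddings,
# so the ties must be restored only *after* prepare().
if accelerator.distributed_type == DistributedType.TPU:
    model.tie_weights()
```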
parent bbbb453e
@@ -518,10 +518,6 @@ def main():
     ]
     optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
-    # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
-    if accelerator.distributed_type == DistributedType.TPU:
-        model.tie_weights()
     # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
     # shorter in multiprocess)
@@ -544,6 +540,10 @@ def main():
         model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
     )
+    # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
+    if accelerator.distributed_type == DistributedType.TPU:
+        model.tie_weights()
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps: