Unverified Commit fecb665b authored by anj-s's avatar anj-s Committed by GitHub
Browse files

Update offload_model.rst (#806)

parent 1b9be421
......@@ -75,7 +75,7 @@ of slices that the model should be sharded into. By default activation checkpoin
optimizer.zero_grad()
inputs = batch_inputs.reshape(-1, num_inputs * num_inputs)
with torch.cuda.amp.autocast():
output = model(inputs)
output = offload_model(inputs)
loss = criterion(output, target=batch_outputs)
loss.backward()
optimizer.step()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment