Commit 68f7064a authored by Lysandre's avatar Lysandre

Add `model.train()` line to ReadMe training example


Co-Authored-By: Santosh-Gupta <San.Gupta.ML@gmail.com>
parent c8f27121
@@ -538,6 +538,7 @@ optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False)  # To reproduce
 scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps)  # PyTorch scheduler
 ### and used like this:
 for batch in train_data:
+    model.train()
     loss = model(batch)
     loss.backward()
     torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
......
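The commit's point is that `model.train()` belongs inside (or before) the training loop so that training-only layers such as dropout are active. Below is a minimal, self-contained sketch of the corrected loop. The model and data are stand-ins (a tiny `nn.Linear` regression, not the library's transformer), and `WarmupLinearSchedule` from the original snippet is approximated here with a hand-written `LambdaLR` warmup schedule; the `step()`/`zero_grad()` calls are the usual PyTorch boilerplate that the truncated diff hunk does not show.

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)  # stand-in for the transformer model
train_data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(5)]

max_grad_norm = 1.0
num_warmup_steps, num_total_steps = 2, 5
optimizer = AdamW(model.parameters(), lr=1e-3)

def warmup_linear(step):
    # Linear warmup for num_warmup_steps, then linear decay to zero —
    # an approximation of the library's WarmupLinearSchedule.
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_total_steps - step) / max(1, num_total_steps - num_warmup_steps))

scheduler = LambdaLR(optimizer, lr_lambda=warmup_linear)

for inputs, targets in train_data:
    model.train()  # the line this commit adds: enables dropout etc. each step
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
```

Calling `model.train()` every iteration is harmless and guards against an evaluation pass elsewhere having left the model in `eval()` mode.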