Unverified Commit f40b7189 authored by flybird1111's avatar flybird1111 Committed by GitHub
Browse files

[doc] Fix gradient accumulation doc. (#4349)

* [doc] fix gradient accumulation doc

* [doc] fix gradient accumulation doc
parent 38b792aa
......@@ -103,10 +103,12 @@ for idx, (img, label) in enumerate(train_dataloader):
with sync_context:
output = model(img)
train_loss = criterion(output, label)
train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
else:
output = model(img)
train_loss = criterion(output, label)
train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
optimizer.step()
optimizer.zero_grad()
......
......@@ -106,10 +106,12 @@ for idx, (img, label) in enumerate(train_dataloader):
with sync_context:
output = model(img)
train_loss = criterion(output, label)
train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
else:
output = model(img)
train_loss = criterion(output, label)
train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
optimizer.step()
optimizer.zero_grad()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment