"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6070b55443d14ae480a0f359f3aff45308e7341d"
Commit cb76c1dd authored by thomwolf's avatar thomwolf
Browse files

add model.zero_grad()

parent a4086c5d
...@@ -531,6 +531,7 @@ def main(): ...@@ -531,6 +531,7 @@ def main():
loss, _ = model(input_ids, segment_ids, input_mask, label_ids) loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
total_tr_loss += loss.item() total_tr_loss += loss.item()
nb_tr_examples += input_ids.size(0) nb_tr_examples += input_ids.size(0)
model.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
global_step += 1 global_step += 1
......
...@@ -856,6 +856,7 @@ def main(): ...@@ -856,6 +856,7 @@ def main():
logger.info("HHHHH Forward") logger.info("HHHHH Forward")
loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions) loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
model.zero_grad()
logger.info("HHHHH Backward") logger.info("HHHHH Backward")
loss.backward() loss.backward()
logger.info("HHHHH Loading data") logger.info("HHHHH Loading data")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment