"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a541d97477a8901e37e5f850f2cd707ffc82445b"
Unverified Commit 1c801d65 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Enforce same behavior as PyTorch 2.0 for older versions (#22136)

parent e16cbe88
...@@ -1811,7 +1811,7 @@ class Trainer: ...@@ -1811,7 +1811,7 @@ class Trainer:
# _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0 self._total_loss_scalar = 0.0
self._globalstep_last_logged = self.state.global_step self._globalstep_last_logged = self.state.global_step
model.zero_grad() model.zero_grad(set_to_none=True)
self.control = self.callback_handler.on_train_begin(args, self.state, self.control) self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
...@@ -1967,7 +1967,7 @@ class Trainer: ...@@ -1967,7 +1967,7 @@ class Trainer:
if optimizer_was_run and not self.deepspeed: if optimizer_was_run and not self.deepspeed:
self.lr_scheduler.step() self.lr_scheduler.step()
model.zero_grad() model.zero_grad(set_to_none=True)
self.state.global_step += 1 self.state.global_step += 1
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control) self.control = self.callback_handler.on_step_end(args, self.state, self.control)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment