Unverified Commit afe73aed authored by yhuang's avatar yhuang Committed by GitHub
Browse files

Fix the behavior of collecting 'num_input_tokens_seen' (#29099)

fix the behavior of collecting 'num_input_tokens_seen'

See https://github.com/huggingface/transformers/issues/28791 for more details.
parent 39114c03
...@@ -2097,7 +2097,12 @@ class Trainer: ...@@ -2097,7 +2097,12 @@ class Trainer:
"a `main_input_name` attribute to the model class you are using." "a `main_input_name` attribute to the model class you are using."
) )
else: else:
self.state.num_input_tokens_seen += self.accelerator.gather(inputs[main_input_name]).numel() input_device = inputs[main_input_name].device
self.state.num_input_tokens_seen += torch.sum(
self.accelerator.gather(
torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64)
)
).item()
if rng_to_sync: if rng_to_sync:
self._load_rng_state(resume_from_checkpoint) self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False rng_to_sync = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment