Unverified Commit e4685832 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[engine] fixed bug in gradient accumulation dataloader to keep the last step (#1030)

parent 32291dd7
......@@ -145,6 +145,7 @@ class GradAccumDataloader:
def __next__(self) -> Union[Tensor, Tuple[Tensor]]:
if self._cur_step < self.steps_per_epoch:
self._cur_step += 1
data = next(self._dataiter)
if self._cur_step == self.steps_per_epoch and self.consume_remain_data:
# this is to handle non standard pytorch dataloader
......@@ -154,7 +155,7 @@ class GradAccumDataloader:
_ = next(self._dataiter)
except StopIteration:
break
return next(self._dataiter)
return data
else:
raise StopIteration
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment