Commit 4c06a4fd authored by Guolin Ke
Browse files

refine rngstate in trainer

parent 70dcdf17
...@@ -717,7 +717,7 @@ class Trainer(object): ...@@ -717,7 +717,7 @@ class Trainer(object):
with torch.autograd.profiler.record_function("optimizer"): with torch.autograd.profiler.record_function("optimizer"):
# fixed the seed in case for the stochastic rounding in different ranks # fixed the seed in case for the stochastic rounding in different ranks
with utils.torch_seed(self.args.seed, self.get_num_updates(), -1): with utils.torch_seed(self.args.seed, self.get_num_updates()):
# take an optimization step # take an optimization step
self.task.optimizer_step( self.task.optimizer_step(
self.optimizer, self.optimizer,
...@@ -733,8 +733,14 @@ class Trainer(object): ...@@ -733,8 +733,14 @@ class Trainer(object):
# out where it fails # out where it fails
self.zero_grad() self.zero_grad()
with NanDetector(self.get_model()): with NanDetector(self.get_model()):
for _, sample in enumerate(samples): for i, sample in enumerate(samples):
sample, _ = self._prepare_sample(sample) sample, _ = self._prepare_sample(sample)
with utils.torch_seed(
self.args.seed,
self.get_num_updates(),
i,
self.data_parallel_rank,
):
self.task.train_step( self.task.train_step(
sample, sample,
self.model, self.model,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment