Commit 70dcdf17 authored by Guolin Ke
Browse files

Bug fix for RNG state handling when update_freq > 1.

parent 0a79672a
......@@ -399,13 +399,13 @@ class Trainer(object):
if errors.missing_keys:
logger.warning(
"Error in loading model state, missing_keys " +
str(errors.missing_keys)
"Error in loading model state, missing_keys "
+ str(errors.missing_keys)
)
if errors.unexpected_keys:
logger.warning(
"Error in loading model state, unexpected_keys " +
str(errors.unexpected_keys)
"Error in loading model state, unexpected_keys "
+ str(errors.unexpected_keys)
)
if utils.has_parameters(self.get_loss()):
self.get_loss().load_state_dict(state["loss"], strict=True)
......@@ -607,7 +607,10 @@ class Trainer(object):
with maybe_no_sync():
# use different seed for different rank in training, otherwise the dropout will be the same in different workers.
with utils.torch_seed(
self.args.seed, self.get_num_updates(), self.data_parallel_rank
self.args.seed,
self.get_num_updates(),
i,
self.data_parallel_rank,
):
# forward and backward
loss, sample_size_i, logging_output = self.task.train_step(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment