Commit 70dcdf17 authored by Guolin Ke's avatar Guolin Ke
Browse files

Bug fix for RNG state handling when update_freq > 1.

parent 0a79672a
...@@ -399,13 +399,13 @@ class Trainer(object): ...@@ -399,13 +399,13 @@ class Trainer(object):
if errors.missing_keys: if errors.missing_keys:
logger.warning( logger.warning(
"Error in loading model state, missing_keys " + "Error in loading model state, missing_keys "
str(errors.missing_keys) + str(errors.missing_keys)
) )
if errors.unexpected_keys: if errors.unexpected_keys:
logger.warning( logger.warning(
"Error in loading model state, unexpected_keys " + "Error in loading model state, unexpected_keys "
str(errors.unexpected_keys) + str(errors.unexpected_keys)
) )
if utils.has_parameters(self.get_loss()): if utils.has_parameters(self.get_loss()):
self.get_loss().load_state_dict(state["loss"], strict=True) self.get_loss().load_state_dict(state["loss"], strict=True)
...@@ -607,7 +607,10 @@ class Trainer(object): ...@@ -607,7 +607,10 @@ class Trainer(object):
with maybe_no_sync(): with maybe_no_sync():
# use different seed for different rank in training, otherwise the dropout will be the same in different workers. # use different seed for different rank in training, otherwise the dropout will be the same in different workers.
with utils.torch_seed( with utils.torch_seed(
self.args.seed, self.get_num_updates(), self.data_parallel_rank self.args.seed,
self.get_num_updates(),
i,
self.data_parallel_rank,
): ):
# forward and backward # forward and backward
loss, sample_size_i, logging_output = self.task.train_step( loss, sample_size_i, logging_output = self.task.train_step(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment