Commit 481f5c4f authored by Rick Ho's avatar Rick Ho
Browse files

add functions to support checkpointing in megatron ddp

parent 79ccb7b6
......@@ -49,3 +49,12 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
mp_group=mpu.get_model_parallel_group(),
dp_group=mpu.get_data_parallel_group()
)
def state_dict(self, *args, **kwargs):
    """Return the state dict of the wrapped module.

    All positional and keyword arguments are forwarded unchanged to
    ``self.module.state_dict``, and its result is returned as-is.
    """
    wrapped = self.module
    return wrapped.state_dict(*args, **kwargs)
def state_dict_for_save_checkpoint(self, *args, **kwargs):
    """Return the wrapped module's checkpoint-oriented state dict.

    All positional and keyword arguments are forwarded unchanged to
    ``self.module.state_dict_for_save_checkpoint``, and its result is
    returned as-is.
    """
    wrapped = self.module
    return wrapped.state_dict_for_save_checkpoint(*args, **kwargs)
def load_state_dict(self, *args, **kwargs):
    """Load a state dict into the wrapped module.

    All positional and keyword arguments are forwarded unchanged to
    ``self.module.load_state_dict``, and whatever that call returns is
    returned to the caller.
    """
    wrapped = self.module
    return wrapped.load_state_dict(*args, **kwargs)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment