Commit 9f64f5f4 authored by Lawrence McAfee

working: checkpoint save/load.

parent 82491e4b
@@ -201,6 +201,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             for model_param in group_range["params"]:
                 assert model_param.requires_grad
                 model_index, dtype = param_gbuf_map[model_param]
+                gbuf_range = model_gbuf_ranges[model_index][dtype]
+                param_range = gbuf_range["param_map"][model_param]["param"]
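For orientation, here is a minimal runnable sketch of the lookup shapes the hunk above implies. The dict layout and the (start, end) tuple standing in for Megatron's range object are assumptions read off this code, not taken from the rest of the repo:

import torch

# Illustrative stand-ins; the real code maps actual model params to Range objects.
model_param = torch.nn.Parameter(torch.zeros(4))
param_gbuf_map = {model_param: (0, torch.float16)}  # param -> (model_index, dtype)
model_gbuf_ranges = [  # one entry per model chunk, keyed by grad-buffer dtype
    {torch.float16: {"param_map": {model_param: {"param": (128, 132)}}}},
]

model_index, dtype = param_gbuf_map[model_param]
gbuf_range = model_gbuf_ranges[model_index][dtype]
param_range = gbuf_range["param_map"][model_param]["param"]
assert param_range == (128, 132)  # this param's slice of its grad buffer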
@@ -310,50 +312,44 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         return None
-    # >>>
-    # def state_dict(self):
-    #     state_dict = {}
-    #     state_dict['optimizer'] = self.optimizer.state_dict()
-    #     if self.grad_scaler:
-    #         state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-    #     state_dict['groups'] = [g['params'] for g in self.optimizer.param_groups]
-    #     return state_dict
     def state_dict(self):
-        raise Exception("fix me.")
-    # <<<
-    # >>>
-    # def load_state_dict(self, state_dict):
-    #     # Optimizer.
-    #     optimizer_key = 'optimizer'
-    #     if optimizer_key not in state_dict:
-    #         optimizer_key = 'optimizer_state_dict'
-    #         print_rank_0('***WARNING*** loading optimizer from '
-    #                      'an old checkpoint ...')
-    #     self.optimizer.load_state_dict(state_dict[optimizer_key])
-    #     # Grad scaler.
-    #     if 'grad_scaler' not in state_dict:
-    #         print_rank_0('***WARNING*** found an old checkpoint, will not '
-    #                      'load grad scaler ...')
-    #     else:
-    #         if self.grad_scaler:
-    #             self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
-    #         else:
-    #             print_rank_0('***WARNING*** found the grad scaler in the '
-    #                          'checkpoint but it is None in the class. '
-    #                          'Skipping loading grad scaler ...')
-    #     # Copy data for the main params.
-    #     current_groups = [ g["params"] for g in self.optimizer.param_groups ]
-    #     assert "groups" in state_dict, "key 'groups' not in state_dict."
-    #     for current_group, saved_group in zip(current_groups, state_dict["groups"]):
-    #         for current_param, saved_param in zip(current_group, saved_group):
-    #             current_param.data.copy_(saved_param.data)
+        state_dict = {}
+        state_dict['optimizer'] = self.optimizer.state_dict()
+        if self.grad_scaler:
+            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        state_dict['shard_fp32_from_float16_groups'] = \
+            self.shard_fp32_from_float16_groups
+        return state_dict
     def load_state_dict(self, state_dict):
-        raise Exception("hi.")
-    # <<<
+        # Optimizer.
+        optimizer_key = 'optimizer'
+        if optimizer_key not in state_dict:
+            optimizer_key = 'optimizer_state_dict'
+            print_rank_0('***WARNING*** loading optimizer from '
+                         'an old checkpoint ...')
+        self.optimizer.load_state_dict(state_dict[optimizer_key])
+        # Grad scaler.
+        if 'grad_scaler' not in state_dict:
+            print_rank_0('***WARNING*** found an old checkpoint, will not '
+                         'load grad scaler ...')
+        else:
+            if self.grad_scaler:
+                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            else:
+                print_rank_0('***WARNING*** found the grad scaler in the '
+                             'checkpoint but it is None in the class. '
+                             'Skipping loading grad scaler ...')
+        # Copy data for the main params.
+        for current_group, saved_group in zip(
+                self.shard_fp32_from_float16_groups,
+                state_dict["shard_fp32_from_float16_groups"]):
+            for current_param, saved_param in zip(current_group, saved_group):
+                current_param.data.copy_(saved_param.data)
     def zero_grad(self, set_to_none=True):
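Taken together, the two methods above support a conventional per-rank save/load round trip. A sketch, assuming `optimizer` is an already-constructed DistributedOptimizer and `rank` is this process's rank (both names are placeholders, not from this commit):

import torch

# Save: inner optimizer state, grad scaler (if dynamic loss scaling is on),
# and this rank's fp32 copies of its float16 main-param shards.
torch.save(optimizer.state_dict(), 'optim_rank{}.pt'.format(rank))

# Load: restores the inner optimizer and grad scaler, then copies the saved
# fp32 shard data back into shard_fp32_from_float16_groups param-by-param.
optimizer.load_state_dict(torch.load('optim_rank{}.pt'.format(rank)))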
@@ -362,11 +358,19 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         fp32_from_float16_groups as a memory optimization to reduce
         fragmentation; in the case of set_to_none==True, the space
         used by this field can be safely deallocated at this point."""
+        # >>>
+        # params = [ p for g in self.shard_fp32_groups for p in g ]
+        # pax(0, {
+        #     "shard_fp32_groups" : self.shard_fp32_groups,
+        #     "params" : params,
+        #     "grads" : [ p.grad for p in params ],
+        # })
+        # <<<
         for groups in (
                 self.full_float16_groups,
                 self.full_fp32_groups,
                 self.shard_float16_groups,  # grad empty/unused here?
-                self.shard_fp32_groups,
+                self.shard_fp32_groups,  # throws grad-access warning
                 self.shard_fp32_from_float16_groups):
             for group in groups:
                 _zero_grad_group_helper(group, set_to_none)
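_zero_grad_group_helper is defined elsewhere in Megatron's optimizer code and is not part of this diff; the sketch below shows what such a helper plausibly does, written fresh here rather than copied from the repo:

def _zero_grad_group_helper(group, set_to_none):
    # Zero (or drop) the .grad of every param in a group of parameters.
    for param in group:
        if param.grad is not None:
            if set_to_none:
                # Drop the gradient tensor so its memory can be freed.
                param.grad = None
            else:
                # Detach from any autograd graph, then zero in place.
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()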
@@ -52,6 +52,10 @@ from megatron.schedules import get_forward_backward_func
 from megatron.utils import report_memory
 from megatron.model.vision.knn_monitor import compute_feature_bank
+# >>>
+from lutil import pax, tp, print_seq
+# <<<
 def print_datetime(string):
     """Note that this call will sync across all ranks."""
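The docstring above says print_datetime syncs across all ranks; the body is not shown in this diff. A minimal sketch of a barrier-synchronized variant, assuming torch.distributed is initialized (print_datetime_sketch is a hypothetical name):

import torch
from datetime import datetime

def print_datetime_sketch(string):
    # Barrier: no rank proceeds (or prints) until every rank reaches this line.
    torch.distributed.barrier()
    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    if torch.distributed.get_rank() == 0:
        print('[{}] datetime: {}'.format(string, time_str), flush=True)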