Unverified Commit 9e4c6449 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[checkpoint] add ColoOptimizer checkpointing (#1316)

parent 7c2634f4
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
......
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.tensor import ColoTensor, DistSpecManager from colossalai.tensor import ColoTensor, DistSpecManager
from colossalai.nn.optimizer import ColossalaiOptimizer
from copy import copy
from typing import Optional
def save_checkpoint(dire: str, def save_checkpoint(dire: str,
epoch: int, epoch: int,
model: torch.nn.Module, model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None, optimizer: Optional[ColossalaiOptimizer] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
*args, *args,
**kwargs): **kwargs):
...@@ -16,7 +19,7 @@ def save_checkpoint(dire: str, ...@@ -16,7 +19,7 @@ def save_checkpoint(dire: str,
dire (str): directory to save the checkpoint files. dire (str): directory to save the checkpoint files.
epoch (int): the number of epoch epoch (int): the number of epoch
model (torch.nn.Module): a torch module initialized by ColoInitContext model (torch.nn.Module): a torch module initialized by ColoInitContext
optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None. optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None. lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
""" """
...@@ -41,11 +44,21 @@ def save_checkpoint(dire: str, ...@@ -41,11 +44,21 @@ def save_checkpoint(dire: str,
# delete the new dict # delete the new dict
del new_dict del new_dict
optim_state_copy = copy(optimizer.state_dict())
for k, v in optim_state_copy['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
t.to_replicate_()
if dist.get_rank() == 0:
model_state = {'epoch': epoch, 'optim': optim_state_copy}
torch.save(model_state, dire + '/epoch_{}_optim.pth'.format(epoch))
del optim_state_copy
def load_checkpoint(dire, def load_checkpoint(dire,
epoch: int, epoch: int,
model: torch.nn.Module, model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None, optimizer: Optional[ColossalaiOptimizer] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
*args, *args,
**kwargs): **kwargs):
...@@ -56,7 +69,7 @@ def load_checkpoint(dire, ...@@ -56,7 +69,7 @@ def load_checkpoint(dire,
epoch (int): _description_ epoch (int): _description_
rank (int): _description_ rank (int): _description_
model (torch.nn.Module): _description_ model (torch.nn.Module): _description_
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None. optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None. lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
""" """
...@@ -74,3 +87,24 @@ def load_checkpoint(dire, ...@@ -74,3 +87,24 @@ def load_checkpoint(dire,
for k, v in model.state_dict().items(): for k, v in model.state_dict().items():
if isinstance(v, ColoTensor): if isinstance(v, ColoTensor):
v.set_tensor_spec(*mapping[k]) v.set_tensor_spec(*mapping[k])
del mapping
mapping = dict()
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
mapping[(k, n)] = (t.dist_spec, t.compute_spec)
t.to_replicate_()
colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
optimizer.load_state_dict(colo_checkpoint['optim'])
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
# skip key not in mapping.
# For Adam, if it does not execute step() once, there will be no exp_avg and exp_avg_sq in optimizer # For Adam, if it does not execute step() once, there will be no exp_avg and exp_avg_sq in optimizer
if (k, n) not in mapping:
continue
t.set_tensor_spec(*mapping[(k, n)])
...@@ -77,6 +77,18 @@ def remove(path): ...@@ -77,6 +77,18 @@ def remove(path):
raise ValueError("file {} is not a file or dir.".format(path)) raise ValueError("file {} is not a file or dir.".format(path))
def compare_optims(optim1, optim2):
    """Assert that the ColoTensor optimizer states of two optimizers agree.

    Args:
        optim1: the reference optimizer (e.g. the one that was checkpointed).
        optim2: the optimizer reloaded from the checkpoint.

    Raises:
        AssertionError: if a shared state tensor differs beyond tolerance, or
            if a ColoTensor state in ``optim1`` is not a ColoTensor in ``optim2``.
    """
    state1 = optim1.state_dict()['state']
    state2 = optim2.state_dict()['state']
    for k, p1 in state1.items():
        # A param may be absent on one side (e.g. step() not yet run for it).
        if k not in state2:
            continue
        p2 = state2[k]
        # Each entry of 'state' is a dict of per-parameter state tensors
        # (e.g. 'exp_avg', 'exp_avg_sq' for Adam), not a tensor itself —
        # compare the tensors one level down.
        for n, t1 in p1.items():
            if n not in p2:
                continue
            t2 = p2[n]
            if isinstance(t1, ColoTensor):
                assert isinstance(t2, ColoTensor)
                # Gather shards to replicated form before the elementwise check.
                assert torch.allclose(t1.to_replicate_(), t2.to_replicate_(), rtol=1e-3, atol=1e-1)
def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg): def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
...@@ -117,7 +129,10 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch ...@@ -117,7 +129,10 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
model_reload = model_reload.cuda() model_reload = model_reload.cuda()
model_reload.train() model_reload.train()
colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.named_parameters(), r=0.1)) opt_class = torch.optim.Adam
colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
run_reload = False
for i, (data, label) in enumerate(train_dataloader): for i, (data, label) in enumerate(train_dataloader):
...@@ -130,22 +145,35 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch ...@@ -130,22 +145,35 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
# Bcast rank0 data to all processes # Bcast rank0 data to all processes
if criterion: if criterion:
output = model(data) output = model(data)
output_reload = model_reload(data)
loss = criterion(output, label) loss = criterion(output, label)
loss_reload = criterion(output_reload, label)
else: else:
output = model(data, label) loss = model(data, label)
loss = output loss_reload = model_reload(data, label)
loss.backward() loss.backward()
colo_optimizer.step() loss_reload.backward()
if run_reload:
colo_optimizer_reload.zero_grad()
if criterion:
output_reload = model_reload(data)
loss_reload = criterion(output_reload, label)
else:
loss_reload = model_reload(data, label)
loss_reload.backward()
colo_optimizer_reload.step()
if i > 2: if i > 2:
break break
if not os.path.isdir('./checkpoint') and rank == 0: if not os.path.isdir('./checkpoint') and rank == 0:
os.mkdir('./checkpoint') os.mkdir('./checkpoint')
save_checkpoint('./checkpoint', 0, model, None, None) save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
dist.barrier()
load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
dist.barrier() dist.barrier()
load_checkpoint('./checkpoint', 0, model_reload, None, None)
# Since model is sharded, we merge them before param checking. # Since model is sharded, we merge them before param checking.
for p in model.parameters(): for p in model.parameters():
...@@ -155,7 +183,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch ...@@ -155,7 +183,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
p.to_replicate_() p.to_replicate_()
check_param_equal(model, model_reload) check_param_equal(model, model_reload)
compare_optims(colo_optimizer, colo_optimizer_reload)
if rank == 0: if rank == 0:
remove('./checkpoint') remove('./checkpoint')
...@@ -163,7 +191,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch ...@@ -163,7 +191,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler): def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=world_size) pg = ProcessGroup(tp_degree=world_size)
for model_name in ['bert', 'simple_net']: for model_name in ['simple_net', 'bert']:
_run_checkpoint(model_name, _run_checkpoint(model_name,
init_1d_row_for_linear_weight_spec, init_1d_row_for_linear_weight_spec,
use_ddp, use_ddp,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment