Unverified Commit c92f84fc authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[tensor] distributed checkpointing for parameters (#1240)

parent 49114d8d
...@@ -143,10 +143,10 @@ class ColoTensor(torch.Tensor): ...@@ -143,10 +143,10 @@ class ColoTensor(torch.Tensor):
self._redistribute(dist_spec) self._redistribute(dist_spec)
def set_tensor_spec(self, dist_spec, compute_spec): def set_tensor_spec(self, dist_spec, compute_spec):
if dist_spec: if dist_spec is not None:
assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}" assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}"
self.set_dist_spec(dist_spec) self.set_dist_spec(dist_spec)
if compute_spec: if compute_spec is not None:
self.compute_spec = compute_spec self.compute_spec = compute_spec
def has_compute_pattern(self, compute_pattern): def has_compute_pattern(self, compute_pattern):
......
from enum import Enum from enum import Enum
from typing import List from typing import List, Optional
__all__ = ['replicate', 'shard'] __all__ = ['replicate', 'shard']
......
import torch import torch
import torch.nn as nn
import torch.distributed as dist import torch.distributed as dist
import collections from colossalai.tensor import ColoTensor, DistSpecManager
import inspect
from colossalai.utils.model.colo_init_context import colo_state_dict
def filter_dict(dict_to_filter, thing_with_kwargs):
sig = inspect.signature(thing_with_kwargs)
filter_keys = [param.name for param in sig.parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD]
filter_dict = {}
for filter_key in filter_keys:
if filter_key in dict_to_filter:
filter_dict[filter_key] = dict_to_filter[filter_key]
return filter_dict
def save_checkpoint(dire: str, def save_checkpoint(dire: str,
...@@ -32,21 +19,30 @@ def save_checkpoint(dire: str, ...@@ -32,21 +19,30 @@ def save_checkpoint(dire: str,
optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None. optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None. lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
""" """
model_state = {'epoch': epoch, 'model': model.state_dict()}
mapping = dict()
new_dict = dict()
# save the dist context about the tensors in a new dict, while still maintain the original dict.
for k, v in model.state_dict().items():
if isinstance(v, ColoTensor):
mapping[k] = (v.dist_spec, v.compute_spec)
new_dict[k] = v.to_replicate().detach()
if dist.get_rank() == 0: if dist.get_rank() == 0:
torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch)) for k, v in new_dict.items():
if isinstance(v, ColoTensor):
assert v.is_replicate()
# TODO() If use tensor parallelism, optim_states contain SHARD ColoTensors. model_state = {'epoch': epoch, 'model': new_dict}
# 1. convert SHARD ColoTensor to REPLICATE torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
# only rank 0 saves the REPLICATE tensors.
optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict()}
torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank())) # delete the new dict
del new_dict
def load_checkpoint(dire, def load_checkpoint(dire,
epoch: int, epoch: int,
rank: int,
model: torch.nn.Module, model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None, optimizer: torch.optim.Optimizer = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
...@@ -62,19 +58,18 @@ def load_checkpoint(dire, ...@@ -62,19 +58,18 @@ def load_checkpoint(dire,
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None. optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None. lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
""" """
mapping = dict()
for k, v in model.named_parameters():
if isinstance(v, ColoTensor):
mapping[k] = (v.dist_spec, v.compute_spec)
v.to_replicate_()
model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch)) model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
model_state['model'] = collections.OrderedDict([(k.split('.', 1)[1], v) for k, v in model_state['model'].items()])
model.load_state_dict(model_state['model']) model.load_state_dict(model_state['model'])
optim_state = torch.load(dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, rank))
optimizer.load_state_dict(optim_state['optimizer']) # reset tensors to original dist spec.
lr_scheduler_dict = optim_state['lr_scheduler'] with DistSpecManager.no_grad():
if 'after_scheduler_type' in lr_scheduler_dict: for k, v in model.named_parameters():
after_scheduler_type = lr_scheduler_dict.pop('after_scheduler_type') if isinstance(v, ColoTensor):
after_scheduler_dict = lr_scheduler_dict.pop('after_scheduler_dict') v.set_tensor_spec(*mapping[k])
reload_scheduler = getattr(torch.optim.lr_scheduler, after_scheduler_type)
filtered_dict = filter_dict(after_scheduler_dict, reload_scheduler)
lr_scheduler_dict['after_scheduler'] = reload_scheduler(
optimizer,
**filtered_dict,
)
lr_scheduler.load_state_dict(lr_scheduler_dict)
from .utils import InsertPostInitMethodToModuleSubClasses from .utils import InsertPostInitMethodToModuleSubClasses
import torch import torch
from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup, ReplicaSpec from colossalai.tensor import ColoTensor, ColoParameter
from colossalai.nn.parallel.layers import register_colo_module, \ from colossalai.nn.parallel.layers import register_colo_module, \
ColoLinear, ColoEmbedding ColoLinear, ColoEmbedding
from copy import copy
from torch import nn from torch import nn
from typing import Iterator, Tuple, Union from typing import Iterator, Tuple, Union
from functools import partialmethod
# find named_params includes replica # find named_params includes replica
...@@ -34,47 +31,6 @@ def ColoModulize(module): ...@@ -34,47 +31,6 @@ def ColoModulize(module):
module._colo_visited = True module._colo_visited = True
def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_dict_func=None):
# build param to spec mapping
mapping1 = dict()
mapping2 = dict()
mapping3 = dict()
# gather all params
has_dist_parameter = False
with torch.no_grad():
for param in self.parameters():
if isinstance(param, ColoParameter):
has_dist_parameter = True
mapping1[id(param)] = copy(param.dist_spec)
mapping2[id(param)] = copy(param.compute_spec)
# TODO(jiaruifang) fixme, we should elegently handle the default PG in init context
if param.get_process_group() is None:
param.process_group = ProcessGroup()
param.set_dist_spec(distspec.replicate())
mapping3[id(param)] = param.get_process_group()
param.process_group = None
# TODO: fix when keep_vars = True
# when keep_vars = False, the state_dict_func will call detach to create
# new tensors, but when keep_vars = True, the recovery of spec will be reflected
# in the `ret`, such that the final state dict will still contain process group,
# raising exception as it is not serializable
assert not (keep_vars and has_dist_parameter), 'keep_vars cannot be True when there are distributed ColoParameters.'
ret = state_dict_func(self, destination, prefix, keep_vars)
# recover
with torch.no_grad():
for param in self.parameters():
param_id = id(param)
if param_id in mapping1:
dist_spec = mapping1[id(param)]
compute_spec = mapping2[id(param)]
param.process_group = mapping3[id(param)]
param.set_tensor_spec(dist_spec, compute_spec)
return ret
class ColoInitContext(InsertPostInitMethodToModuleSubClasses): class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')): def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
...@@ -94,8 +50,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): ...@@ -94,8 +50,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
register_colo_module(torch.nn.Embedding, ColoEmbedding()) register_colo_module(torch.nn.Embedding, ColoEmbedding())
def _pre_context_exec(self): def _pre_context_exec(self):
self.state_dict_func = nn.Module.state_dict pass
nn.Module.state_dict = partialmethod(colo_state_dict, state_dict_func=self.state_dict_func)
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
""" """
......
...@@ -122,6 +122,19 @@ def _run_redistributed(world_size): ...@@ -122,6 +122,19 @@ def _run_redistributed(world_size):
assert t1.is_replicate() assert t1.is_replicate()
def _run_set_tensor_spec(world_size):
if world_size != 4:
return
pg = ProcessGroup(tp_degree=2, dp_degree=2)
spec1 = ColoTensorSpec(pg)
t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
dist_spec2 = (ShardSpec([-1], [pg.tp_world_size()]), None)
assert t1.is_replicate()
t1.set_dist_spec(*dist_spec2)
assert t1.is_shard_1dcol()
def run_dist_tests(rank, world_size, port): def run_dist_tests(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
_run_tensor_shard_init(world_size) _run_tensor_shard_init(world_size)
...@@ -132,6 +145,7 @@ def run_dist_tests(rank, world_size, port): ...@@ -132,6 +145,7 @@ def run_dist_tests(rank, world_size, port):
_run_operand(world_size) _run_operand(world_size)
_run_wrapped_tensor_func() _run_wrapped_tensor_func()
_run_redistributed(world_size) _run_redistributed(world_size)
_run_set_tensor_spec(world_size)
@pytest.mark.dist @pytest.mark.dist
......
...@@ -3,7 +3,6 @@ import os, shutil ...@@ -3,7 +3,6 @@ import os, shutil
import torch import torch
import torch.nn as nn import torch.nn as nn
import pytest import pytest
import copy
from functools import partial from functools import partial
import torch.multiprocessing as mp import torch.multiprocessing as mp
...@@ -104,7 +103,7 @@ def remove(path): ...@@ -104,7 +103,7 @@ def remove(path):
raise ValueError("file {} is not a file or dir.".format(path)) raise ValueError("file {} is not a file or dir.".format(path))
def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg): def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
num_epoch = 5 num_epoch = 5
warmup_epoch = 2 warmup_epoch = 2
...@@ -112,31 +111,28 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg): ...@@ -112,31 +111,28 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
feature = 32 feature = 32
category = 16 category = 16
train_dataloader = DummyDataLoader(batch, category, feature, length=16)
with ColoInitContext(device=get_current_device()): with ColoInitContext(device=get_current_device()):
model = MLP(feature, category) model = MLP(feature, category)
with ColoInitContext(device=get_current_device()):
model_reload = MLP(feature, category) model_reload = MLP(feature, category)
model_ref = MLP(feature, category)
model = model.cuda() model = model.cuda()
model_reload = model_reload.cuda() model_reload = model_reload.cuda()
model_ref = model_ref.cuda()
if use_ddp: if use_ddp:
model = ColoDDP(model, pg) model = ColoDDP(model, pg)
model_reload = ColoDDP(model_reload, pg) model_reload = ColoDDP(model_reload, pg)
model_ref = ColoDDP(model_ref, pg)
init_spec_func(model, pg) init_spec_func(model, pg)
init_spec_func(model_ref, pg) if use_mp_reload:
init_spec_func(model_reload, pg)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
optimizer_reload = torch.optim.Adam(model_reload.parameters(), optimizer_reload = torch.optim.Adam(model_reload.parameters(),
lr=0.001, lr=0.001,
betas=(0.9, 0.999), betas=(0.9, 0.999),
eps=1e-08, eps=1e-08,
weight_decay=0) weight_decay=0)
optimizer_ref = torch.optim.Adam(model_ref.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
lr_scheduler = None lr_scheduler = None
if test_scheduler == 'colossalai_cosine_warmup': if test_scheduler == 'colossalai_cosine_warmup':
...@@ -154,91 +150,48 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg): ...@@ -154,91 +150,48 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
else: else:
raise TypeError(f"{test_scheduler} is invalid") raise TypeError(f"{test_scheduler} is invalid")
for epoch in range(0, num_epoch): save_checkpoint('./checkpoint', 0, model, optimizer, lr_scheduler)
if epoch <= test_epoch: dist.barrier()
for i, image_dict in enumerate(train_dataloader): load_checkpoint('./checkpoint', 0, model_reload, optimizer_reload, lr_scheduler_reload)
if use_ddp:
model.zero_grad() # Since model is sharded, we merge them before param checking.
else: for p in model.parameters():
optimizer.zero_grad() p.to_replicate_()
logits = model(image_dict['pixel_values'])
loss = criterion(logits, image_dict['label']) for p in model_reload.parameters():
if use_ddp: p.to_replicate_()
model.backward(loss)
else: check_param_equal(model, model_reload)
loss.backward()
optimizer.step()
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
if epoch == test_epoch:
for ref_p, p in zip(model_ref.parameters(), model.parameters()):
ref_p.data.copy_(p)
optimizer_ref = copy.deepcopy(optimizer)
check_param_equal(model, model_ref)
save_checkpoint('./checkpoint', epoch, model, optimizer, lr_scheduler)
dist.barrier()
else:
if epoch == test_epoch + 1:
load_checkpoint('./checkpoint', test_epoch, dist.get_rank(), model_reload, optimizer_reload,
lr_scheduler_reload)
init_spec_func(model_reload, pg)
for i, image_dict in enumerate(train_dataloader):
if use_ddp:
model_ref.zero_grad()
model_reload.zero_grad()
else:
optimizer_ref.zero_grad()
optimizer_reload.zero_grad()
logits_ref = model_ref(image_dict['pixel_values'])
logits_reload = model_reload(image_dict['pixel_values'])
loss_ref = criterion(logits_ref, image_dict['label'])
loss_reload = criterion(logits_reload, image_dict['label'])
if use_ddp:
model_ref.backward(loss_ref)
model_reload.backward(loss_reload)
else:
loss_ref.backward()
loss_reload.backward()
optimizer_ref.step()
optimizer_reload.step()
lr_scheduler.step()
check_param_equal(model_ref, model_reload)
def run_dist(rank, world_size, port, use_ddp, test_epoch, test_scheduler):
if use_ddp and world_size == 1: if use_ddp and world_size == 1:
return return
tp_world_size = world_size // 2 if use_ddp else world_size tp_world_size = world_size // 2 if use_ddp else world_size
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),)) config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=world_size) pg = ProcessGroup(tp_degree=world_size)
run_checkpoint(init_1d_row_for_linear_weight_spec, run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, use_mp_reload, test_scheduler=test_scheduler, pg=pg)
use_ddp,
test_epoch=test_epoch,
test_scheduler=test_scheduler,
pg=pg)
@pytest.mark.skip
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize('world_size', [4]) @pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('use_ddp', [True]) @pytest.mark.parametrize('use_ddp', [True, False])
@pytest.mark.parametrize('test_epoch', [1, 2, 3]) @pytest.mark.parametrize('use_mp_reload', [True, False])
@pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda']) @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, test_epoch, test_scheduler): def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler):
if not os.path.isdir('./checkpoint'): if not os.path.isdir('./checkpoint'):
os.mkdir('./checkpoint') os.mkdir('./checkpoint')
run_func = partial(run_dist, run_func = partial(run_dist,
world_size=world_size, world_size=world_size,
port=free_port(), port=free_port(),
use_ddp=use_ddp, use_ddp=use_ddp,
test_epoch=test_epoch, use_mp_reload=use_mp_reload,
test_scheduler=test_scheduler) test_scheduler=test_scheduler)
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)
remove('./checkpoint') remove('./checkpoint')
if __name__ == '__main__': if __name__ == '__main__':
test_checkpoint(4, True, 1, "colossalai_cosine_warmup") test_checkpoint(2, True, False, "torch_cosine")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment