Unverified commit 3ef3791a, authored by Jiarui Fang, committed by GitHub
Browse files

[checkpoint] add test for bert and hotfix save bugs (#1297)

parent bd71e2a8
...@@ -28,7 +28,8 @@ def save_checkpoint(dire: str, ...@@ -28,7 +28,8 @@ def save_checkpoint(dire: str,
if isinstance(v, ColoTensor): if isinstance(v, ColoTensor):
mapping[k] = (v.dist_spec, v.compute_spec) mapping[k] = (v.dist_spec, v.compute_spec)
new_dict[k] = v.to_replicate().detach() new_dict[k] = v.to_replicate().detach()
else:
new_dict[k] = v
if dist.get_rank() == 0: if dist.get_rank() == 0:
for k, v in new_dict.items(): for k, v in new_dict.items():
if isinstance(v, ColoTensor): if isinstance(v, ColoTensor):
...@@ -60,7 +61,7 @@ def load_checkpoint(dire, ...@@ -60,7 +61,7 @@ def load_checkpoint(dire,
""" """
mapping = dict() mapping = dict()
for k, v in model.named_parameters(): for k, v in model.state_dict().items():
if isinstance(v, ColoTensor): if isinstance(v, ColoTensor):
mapping[k] = (v.dist_spec, v.compute_spec) mapping[k] = (v.dist_spec, v.compute_spec)
v.to_replicate_() v.to_replicate_()
...@@ -70,6 +71,6 @@ def load_checkpoint(dire, ...@@ -70,6 +71,6 @@ def load_checkpoint(dire,
# reset tensors to original dist spec. # reset tensors to original dist spec.
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
for k, v in model.named_parameters(): for k, v in model.state_dict().items():
if isinstance(v, ColoTensor): if isinstance(v, ColoTensor):
v.set_tensor_spec(*mapping[k]) v.set_tensor_spec(*mapping[k])
from abc import ABC, abstractmethod
import os, shutil import os, shutil
import torch import torch
import torch.nn as nn
import pytest import pytest
from functools import partial from functools import partial
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.distributed as dist import torch.distributed as dist
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import MultiplicativeLR from torch.optim.lr_scheduler import MultiplicativeLR
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
import colossalai import colossalai
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import ColoOptimizer
class DummyDataGenerator(ABC):
def __init__(self, length=10):
self.length = length
@abstractmethod
def generate(self):
pass
def __iter__(self): from tests.components_to_test.registry import non_distributed_component_funcs
self.step = 0
return self
def __next__(self):
if self.step < self.length:
self.step += 1
return self.generate()
else:
raise StopIteration
def __len__(self):
return self.length
def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
    """Attach ``weight`` to ``pg`` and shard it along its last dimension.

    The distribution spec splits dimension ``-1`` into ``pg.tp_world_size()``
    partitions and is paired with a ``ComputePattern.TP1D`` compute spec.
    """
    # Build both specs first, so the world size is read before the tensor is
    # re-bound to the process group.
    dist_spec = ShardSpec([-1], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    weight.set_process_group(pg)
    weight.set_tensor_spec(dist_spec, compute_spec)
class DummyDataLoader(DummyDataGenerator):
def __init__(self, batch_size, category, feature_size, length=10):
super().__init__(length)
self.batch_size = batch_size
self.category = category
self.feature_size = feature_size
def generate(self): def init_1d_col_linear(weight, pg):
image_dict = {} spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
image_dict['pixel_values'] = torch.rand(self.batch_size, self.feature_size, device=get_current_device()) * 2 - 1 weight.set_process_group(pg)
image_dict['label'] = torch.randint(self.category, (self.batch_size,), weight.set_tensor_spec(*spec)
dtype=torch.int64,
device=get_current_device())
return image_dict
class MLP(nn.Module): def init_1d_row_embedding(weight, pg):
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
weight.set_process_group(pg)
weight.set_tensor_spec(*spec)
def __init__(self, in_features, out_features, hidden_features=None):
super().__init__()
if hidden_features is None:
hidden_features = out_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.fc2 = nn.Linear(hidden_features, out_features)
self.activation = nn.ReLU()
def forward(self, x): def init_1d_col_embedding(weight, pg):
x = self.fc1(x) spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
x = self.activation(x) weight.set_process_group(pg)
x = self.fc2(x) weight.set_tensor_spec(*spec)
return x
def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup): def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): for name, p in model.named_parameters():
for n, p in model.named_parameters(): if not isinstance(p, ColoTensor):
if 'weight' in n: continue
p.set_process_group(pg) if 'embed' in name and 'weight' in name:
p.set_tensor_spec(*spec) init_1d_col_embedding(p, pg)
if 'proj1' in name and ('weight' in name or 'bias' in name):
init_1d_col_linear(p, pg)
if 'proj2' in name and 'weight' in name:
init_1d_row_linear(p, pg)
if 'classifier' in name and ('weight' in name or 'bias' in name):
init_1d_col_linear(p, pg)
def check_param_equal(model, torch_model): def check_param_equal(model, torch_model):
...@@ -103,56 +77,75 @@ def remove(path): ...@@ -103,56 +77,75 @@ def remove(path):
raise ValueError("file {} is not a file or dir.".format(path)) raise ValueError("file {} is not a file or dir.".format(path))
def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg): def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
num_epoch = 5 get_components_func = non_distributed_component_funcs.get_callable(model_name)
warmup_epoch = 2 model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
batch = 3 rank = torch.distributed.get_rank()
feature = 32 world_size = torch.distributed.get_world_size()
category = 16
# set_seed(1)
with ColoInitContext(device=get_current_device()): with ColoInitContext(device=get_current_device()):
model = MLP(feature, category) model = model_builder(checkpoint=True)
model_reload = model_builder(checkpoint=True)
with ColoInitContext(device=get_current_device()): if use_mp_reload:
model_reload = MLP(feature, category) if 'bert' == model_name:
for name, p in model.named_parameters():
if not isinstance(p, ColoTensor):
continue
# num_class = type_vocab_size = 2 | (8, 2)
if 'classifier' in name and 'weight' in name:
init_1d_row_linear(p, pg)
# num_class = vocab_size = 30524 | (30524, 8)
elif 'word_embeddings' in name and 'weight' in name:
init_1d_row_embedding(p, pg)
# num_class = seq_len = 512 | (512, 8)
elif 'position_embeddings' in name and 'weight' in name:
init_1d_row_embedding(p, pg)
# num_class = type_vocab_size = 2 | (2, 8)
elif 'token_type_embeddings' in name and 'weight' in name:
init_1d_col_embedding(p, pg)
elif p.process_group.tp_world_size() == 1:
p.redistribute(ReplicaSpec(), pg)
elif "simple_net" == model_name:
init_spec_func(model, pg)
model = model.cuda() model = model.cuda()
model.train()
model_reload = model_reload.cuda() model_reload = model_reload.cuda()
if use_ddp: model_reload.train()
model = ColoDDP(model, pg)
model_reload = ColoDDP(model_reload, pg)
init_spec_func(model, pg) colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
if use_mp_reload:
init_spec_func(model_reload, pg) for i, (data, label) in enumerate(train_dataloader):
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
optimizer_reload = torch.optim.Adam(model_reload.parameters(),
lr=0.001,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=0)
lr_scheduler = None
if test_scheduler == 'colossalai_cosine_warmup':
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch)
lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload,
total_steps=num_epoch,
warmup_steps=warmup_epoch)
elif test_scheduler == 'torch_cosine':
lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch)
lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch)
elif test_scheduler == 'torch_lambda':
lr_lambda = lambda epoch: 0.95
lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda)
lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda)
else:
raise TypeError(f"{test_scheduler} is invalid")
save_checkpoint('./checkpoint', 0, model, optimizer, lr_scheduler) # Zero grad
colo_optimizer.zero_grad()
data = data.to(get_current_device())
label = label.to(get_current_device())
# Bcast rank0 data to all processes
if criterion:
output = model(data)
loss = criterion(output, label)
else:
output = model(data, label)
loss = output
loss.backward()
colo_optimizer.step()
if i > 2:
break
if not os.path.isdir('./checkpoint') and rank == 0:
os.mkdir('./checkpoint')
save_checkpoint('./checkpoint', 0, model, None, None)
dist.barrier() dist.barrier()
load_checkpoint('./checkpoint', 0, model_reload, optimizer_reload, lr_scheduler_reload) load_checkpoint('./checkpoint', 0, model_reload, None, None)
# Since model is sharded, we merge them before param checking. # Since model is sharded, we merge them before param checking.
for p in model.parameters(): for p in model.parameters():
...@@ -163,26 +156,29 @@ def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg): ...@@ -163,26 +156,29 @@ def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
check_param_equal(model, model_reload) check_param_equal(model, model_reload)
if rank == 0:
remove('./checkpoint')
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler): def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
if use_ddp and world_size == 1: colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
return
tp_world_size = world_size // 2 if use_ddp else world_size
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=world_size) pg = ProcessGroup(tp_degree=world_size)
run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, use_mp_reload, test_scheduler=test_scheduler, pg=pg) for model_name in ['bert', 'simple_net']:
_run_checkpoint(model_name,
init_1d_row_for_linear_weight_spec,
use_ddp,
use_mp_reload,
test_scheduler=test_scheduler,
pg=pg)
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2]) @pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('use_ddp', [True, False]) @pytest.mark.parametrize('use_ddp', [False])
@pytest.mark.parametrize('use_mp_reload', [True, False]) @pytest.mark.parametrize('use_mp_reload', [True, False])
@pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda']) # @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler): def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
if not os.path.isdir('./checkpoint'):
os.mkdir('./checkpoint')
run_func = partial(run_dist, run_func = partial(run_dist,
world_size=world_size, world_size=world_size,
port=free_port(), port=free_port(),
...@@ -190,8 +186,7 @@ def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler): ...@@ -190,8 +186,7 @@ def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler):
use_mp_reload=use_mp_reload, use_mp_reload=use_mp_reload,
test_scheduler=test_scheduler) test_scheduler=test_scheduler)
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)
remove('./checkpoint')
if __name__ == '__main__': if __name__ == '__main__':
test_checkpoint(2, True, False, "torch_cosine") test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment