Commit cb34cd38 authored by Jiarui Fang's avatar Jiarui Fang Committed by Frank Lee
Browse files

[test] polish zero related unit tests (#351)

parent 534e0bb1
import torch
from colossalai.zero.sharded_model import ShardedModelV2
import copy
def col_model_deepcopy(sharded_model: "ShardedModelV2", other_model: torch.nn.Module):
    """
    Copy the parameters of a ShardedModelV2 into other_model.

    Parameters are paired positionally via ``zip`` over both models'
    ``parameters()``, so other_model has to have the same structure as the
    module wrapped by sharded_model.

    Args:
        sharded_model: source model; each of its parameters must carry a
            ``col_attr`` attribute whose ``data`` exposes ``is_sharded`` and
            ``payload``.
        other_model: destination model; each ``param.data`` is replaced by a
            deep copy of the matching payload.
    """
    for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
        assert hasattr(zero_param, 'col_attr')
        sharded_tensor = zero_param.col_attr.data
        # Read the flag before gathering so the original sharding state is the
        # one that gets restored afterwards.
        shard_flag = sharded_tensor.is_sharded
        if shard_flag:
            sharded_model.shard_strategy.gather([sharded_tensor])
        try:
            param.data = copy.deepcopy(sharded_tensor.payload)
        finally:
            # Re-shard even when the copy fails, so the source model is never
            # left with a gathered parameter.
            if shard_flag:
                sharded_model.shard_strategy.shard([sharded_tensor])
...@@ -3,8 +3,10 @@ from functools import partial ...@@ -3,8 +3,10 @@ from functools import partial
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import checkpoint from colossalai.utils import checkpoint
from colossalai.zero.sharded_model import ShardedModelV2
LOGGER = get_dist_logger() LOGGER = get_dist_logger()
...@@ -20,6 +22,21 @@ CONFIG = dict(fp16=dict(mode=None,), ...@@ -20,6 +22,21 @@ CONFIG = dict(fp16=dict(mode=None,),
parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None))) parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
    """
    Run one forward pass and one backward pass of ``model`` in train mode.

    Args:
        model: the module to run; a ShardedModelV2 is back-propagated through
            its own ``backward`` method, anything else through ``loss.backward()``.
        data: input passed to the model.
        label: target passed to ``criterion``, or to the model itself when no
            criterion is given.
        criterion: if truthy, the loss is ``criterion(model(data), label)``;
            otherwise the value returned by ``model(data, label)`` is the loss.
        enable_autocast: forwarded to ``torch.cuda.amp.autocast(enabled=...)``.
    """
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        if not criterion:
            # The model computes its own loss from (data, label).
            raw_loss = model(data, label)
        else:
            raw_loss = criterion(model(data), label)
        loss = raw_loss.float()

    if not isinstance(model, ShardedModelV2):
        loss.backward()
        return
    model.backward(loss)
def checkpoint_wrapper(module, enable=True): def checkpoint_wrapper(module, enable=True):
if enable: if enable:
module.forward = partial(checkpoint, module.forward) module.forward = partial(checkpoint, module.forward)
......
...@@ -3,81 +3,70 @@ ...@@ -3,81 +3,70 @@
import copy import copy
from functools import partial from functools import partial
import colossalai
import pytest import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.zero.shard_utils.tensor_shard_strategy import \ from colossalai.zero.shard_utils.tensor_shard_strategy import \
TensorShardStrategy TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from common import CONFIG, check_grads_padding
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False): from tests.components_to_test.registry import non_distributed_component_funcs
model.train() from common import CONFIG, check_grads_padding, run_fwd_bwd
with torch.cuda.amp.autocast(enabled=enable_autocast): from colossalai.zero.sharded_model.utils import col_model_deepcopy
y = model(data)
loss = criterion(y, label)
loss = loss.float()
if isinstance(model, ShardedModelV2):
model.backward(loss)
else:
loss.backward()
# with no criterion
def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
    """
    Run one forward/backward pass for a model that returns its own loss.

    The model is called as ``model(data, label)`` under
    ``torch.cuda.amp.autocast(enabled=enable_autocast)`` and the returned
    value is used as the loss. A ShardedModelV2 is back-propagated through
    ``model.backward(loss)``; any other model through ``loss.backward()``.
    """
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        loss = model(data, label)

    if not isinstance(model, ShardedModelV2):
        loss.backward()
        return
    model.backward(loss)
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
test_models = ['repeated_computed_layers', 'resnet18', 'bert'] test_models = ['repeated_computed_layers', 'resnet18', 'bert']
shard_strategy = TensorShardStrategy() shard_strategy = TensorShardStrategy()
for model_name in test_models: for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func() model_builder, train_dataloader, _, _, criterion = get_components_func()
model = model(checkpoint=True).half().cuda()
if use_zero_init_ctx:
with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
zero_model = model_builder(checkpoint=True)
zero_model = ShardedModelV2(zero_model, shard_strategy)
model = model_builder(checkpoint=True).half()
col_model_deepcopy(zero_model, model)
model = model.cuda()
else:
model = model_builder(checkpoint=True).half().cuda()
zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy) zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
if dist.get_world_size() > 1:
model = DDP(model) model = DDP(model)
for i, (data, label) in enumerate(train_dataloader): for i, (data, label) in enumerate(train_dataloader):
if i > 2: if i > 3:
break break
if criterion is None:
data, label = data.cuda(), label.cuda()
run_fwd_bwd_no_criterion(model, data, label, False)
run_fwd_bwd_no_criterion(zero_model, data, label, False)
else:
data, label = cast_tensor_to_fp16(data).cuda(), label.cuda() data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
run_fwd_bwd(model, data, label, criterion, False) run_fwd_bwd(model, data, label, criterion, enable_autocast)
run_fwd_bwd(zero_model, data, label, criterion, False) run_fwd_bwd(zero_model, data, label, criterion, enable_autocast)
check_grads_padding(model, zero_model, loose=True) check_grads_padding(model, zero_model, loose=True)
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2, 4]) @pytest.mark.parametrize("world_size", [1, 2])
def test_shard_model_v2(world_size): @pytest.mark.parametrize("enable_autocast", [True])
run_func = partial(run_dist, world_size=world_size, port=free_port()) @pytest.mark.parametrize("use_zero_init_ctx", [True])
def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast):
run_func = partial(run_dist,
world_size=world_size,
port=free_port(),
use_zero_init_ctx=use_zero_init_ctx,
enable_autocast=enable_autocast)
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
test_shard_model_v2(world_size=2) test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
from functools import partial
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils.tensor_shard_strategy import \
TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from common import CONFIG, check_grads, check_grads_padding
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
    """
    Run one forward pass and one backward pass of ``model`` in train mode.

    The loss is ``criterion(model(data), label)`` cast to float, computed
    under ``torch.cuda.amp.autocast(enabled=enable_autocast)``. A
    ShardedModelV2 is back-propagated through ``model.backward(loss)``; any
    other model through ``loss.backward()``.
    """
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        prediction = model(data)
        loss = criterion(prediction, label).float()

    if not isinstance(model, ShardedModelV2):
        loss.backward()
        return
    model.backward(loss)
# Per-process worker: launches colossalai for this rank, then for each test
# model runs forward/backward on a ShardedModelV2 and on a plain copy of the
# same model, and compares their gradients.
# rank, world_size and port are forwarded to colossalai.launch.
# NOTE(review): indentation is not preserved in this view, so the extent of the
# `with`, `for` and `if` bodies below is presumed from statement order — confirm.
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
test_models = ['repeated_computed_layers', 'resnet18']
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
shard_strategy = TensorShardStrategy()
# Build the model under ZeroInitContext, configured with convert_fp16,
# convert_cuda and shard_param all set to True.
with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
# The first returned component is a callable that builds the model.
zero_model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
zero_model = zero_model()
# Keep an independent copy as the reference model before wrapping the original.
model = copy.deepcopy(zero_model)
zero_model = ShardedModelV2(zero_model, shard_strategy)
# Point every reference parameter at the tensor of the same name taken from
# the sharded model's state_dict.
model_state_dict = zero_model.state_dict()
for n, p in model.named_parameters():
p.data = model_state_dict[n]
model = model.half().cuda()
# Wrap the reference model in DDP only when more than one process is running.
if dist.get_world_size() > 1:
model = DDP(model)
for i, (data, label) in enumerate(train_dataloader):
# Only batches 0, 1 and 2 are used.
if i > 2:
break
data, label = data.half().cuda(), label.cuda()
run_fwd_bwd(model, data, label, criterion, False)
run_fwd_bwd(zero_model, data, label, criterion, False)
# Multi-process runs use check_grads_padding; a single process uses check_grads.
if dist.get_world_size() > 1:
check_grads_padding(model, zero_model, loose=True)
else:
check_grads(model, zero_model, loose=True)
@pytest.mark.dist
def test_shard_model_v2():
    """Spawn two worker processes that each execute ``run_dist``."""
    nprocs = 2
    # mp.spawn supplies the process index as the first positional argument
    # (``rank``); the remaining arguments are bound here.
    worker = partial(run_dist, world_size=nprocs, port=free_port())
    mp.spawn(worker, nprocs=nprocs)
# Allow running this test directly as a script, without going through pytest.
if __name__ == '__main__':
test_shard_model_v2()
...@@ -78,7 +78,7 @@ def run_dist(rank, world_size, port): ...@@ -78,7 +78,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2, 4]) @pytest.mark.parametrize("world_size", [1, 2])
def test_sharded_optim_v2(world_size): def test_sharded_optim_v2(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port()) run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment