Unverified Commit a39a5c66 authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

Merge branch 'main' into feature/shardformer

parents e79b1e80 aaeb520c
...@@ -66,6 +66,7 @@ def run_dist(rank, world_size, port): ...@@ -66,6 +66,7 @@ def run_dist(rank, world_size, port):
run_grad_clip_norm(world_size=world_size) run_grad_clip_norm(world_size=world_size)
@pytest.mark.skip("this need to be updated")
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2]) @pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
......
import pytest import pytest
import torch import torch
from torch.distributed.distributed_c10d import _get_default_group
import colossalai import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup from colossalai.tensor import ColoTensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero.gemini.chunk import ChunkManager from colossalai.zero.gemini.chunk import ChunkManager
from tests.test_tensor.common_utils import debug_print from tests.test_tensor.common_utils import debug_print
...@@ -15,19 +16,18 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}} ...@@ -15,19 +16,18 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}
@parameterize('keep_gathered', [True, False]) @parameterize('keep_gathered', [True, False])
@parameterize('pin_memory', [True, False]) @parameterize('pin_memory', [True, False])
def exam_chunk_memory(keep_gathered, pin_memory): def exam_chunk_memory(keep_gathered, pin_memory):
pg = ProcessGroup()
debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory)) debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))
params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)] params = [ColoTensor(torch.rand(8, 8)) for _ in range(3)]
config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)} config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}
chunk_manager = ChunkManager(config) chunk_manager = ChunkManager(config)
assert chunk_manager.total_mem['cpu'] == 0 assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == 0 assert chunk_manager.total_mem['cuda'] == 0
process_group = _get_default_group()
for p in params: for p in params:
chunk_manager.register_tensor(p, 'param', 2, pin_memory=pin_memory) chunk_manager.register_tensor(p, 'param', 2, process_group, pin_memory=pin_memory)
chunk_manager.close_all_groups() chunk_manager.close_all_groups()
assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]
......
import pytest import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
import colossalai import colossalai
from colossalai.tensor import ColoParameter from colossalai.tensor import ColoParameter
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero.gemini import TensorState from colossalai.zero.gemini import TensorState
...@@ -36,7 +36,7 @@ def check_equal(param, param_cp): ...@@ -36,7 +36,7 @@ def check_equal(param, param_cp):
@parameterize('pin_memory', [True, False]) @parameterize('pin_memory', [True, False])
def exam_chunk_basic(init_device, keep_gathered, pin_memory): def exam_chunk_basic(init_device, keep_gathered, pin_memory):
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
pg = ColoProcessGroup() pg = _get_default_group()
my_chunk = Chunk(chunk_size=1024, my_chunk = Chunk(chunk_size=1024,
process_group=pg, process_group=pg,
dtype=torch.float32, dtype=torch.float32,
......
import pytest import pytest
import torch import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close from torch.testing import assert_close
import colossalai import colossalai
from colossalai.amp import convert_to_apex_amp from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.components_to_test import run_fwd_bwd
from tests.components_to_test import run_fwd, run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed from tests.test_tensor.common_utils import set_seed
PLACEMENT_CONFIGS = [
def check_grad(model: ZeroDDP, torch_model: torch.nn.Module): {
'placement_policy': 'static',
'shard_param_frac': 0.0
}, # zero2
{
'placement_policy': 'static',
'shard_param_frac': 1.0
}, # zero3
{
'placement_policy': 'static',
'shard_param_frac': 0.5
}, # zero3-half
{
'placement_policy': 'auto'
}
]
def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
chunk_manager = model.chunk_manager chunk_manager = model.chunk_manager
param_list = [p for p in model.parameters()] param_list = [p for p in model.parameters()]
chunk_list = chunk_manager.get_chunks(param_list) chunk_list = chunk_manager.get_chunks(param_list)
...@@ -28,12 +45,12 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module): ...@@ -28,12 +45,12 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5) assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) @parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('keep_gather', [False, True]) @parameterize('keep_gather', [False, True])
@parameterize('model_name', ['gpt2', 'bert', 'albert']) @parameterize('model_name', ['gpt2', 'bert', 'albert'])
@parameterize('use_grad_checkpoint', [False, True]) @parameterize('use_grad_checkpoint', [False, True])
def exam_gpt_fwd_bwd( def exam_gpt_fwd_bwd(
placement_policy, placement_config,
keep_gather, keep_gather,
model_name: str, model_name: str,
use_grad_checkpoint: bool = False, use_grad_checkpoint: bool = False,
...@@ -43,8 +60,7 @@ def exam_gpt_fwd_bwd( ...@@ -43,8 +60,7 @@ def exam_gpt_fwd_bwd(
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
set_seed(42) set_seed(42)
with ColoInitContext(device=init_device): model = model_builder(use_grad_checkpoint)
model = model_builder(use_grad_checkpoint)
set_seed(42) set_seed(42)
torch_model = model_builder(use_grad_checkpoint).cuda() torch_model = model_builder(use_grad_checkpoint).cuda()
...@@ -55,19 +71,17 @@ def exam_gpt_fwd_bwd( ...@@ -55,19 +71,17 @@ def exam_gpt_fwd_bwd(
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gather config_dict[world_size]['keep_gathered'] = keep_gather
chunk_manager = ChunkManager(config_dict) model = GeminiDDP(model, config_dict, init_device, pin_memory=True, **placement_config)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters(), lr=1e-3) optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)
pg = ProcessGroup() rank = dist.get_rank()
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) torch_model = DDP(torch_model, device_ids=[rank])
set_seed(pg.dp_local_rank()) set_seed(rank)
for i, (input_ids, label) in enumerate(train_dataloader): for i, (input_ids, label) in enumerate(train_dataloader):
# you can only test a single fwd + bwd. # you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization. # after bwd param is grad for Gemini, due to the chunk reuse optimization.
...@@ -89,65 +103,10 @@ def exam_gpt_fwd_bwd( ...@@ -89,65 +103,10 @@ def exam_gpt_fwd_bwd(
check_grad(model, torch_model) check_grad(model, torch_model)
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('keep_gather', [False, True])
@parameterize('model_name', ['gpt2', 'bert', 'albert'])
@parameterize('scatter_after_inference', [False, True])
def exam_gpt_inference(
placement_policy,
keep_gather,
model_name: str,
scatter_after_inference: bool = False,
):
init_device = get_current_device()
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
set_seed(42)
with ColoInitContext(device=init_device):
model = model_builder()
set_seed(42)
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gather
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True, scatter_after_inference=scatter_after_inference)
pg = ProcessGroup()
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
set_seed(pg.dp_local_rank())
model.eval()
torch_model.eval()
for i, (input_ids, label) in enumerate(train_dataloader):
# you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
if i > 0:
break
with torch.no_grad():
input_ids, label = input_ids.cuda(), label.cuda()
torch_loss = run_fwd(torch_model, input_ids, label, criterion)
loss = run_fwd(model, input_ids, label, criterion)
assert torch.equal(torch_loss, loss)
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
config = {} config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd() exam_gpt_fwd_bwd()
exam_gpt_inference()
@pytest.mark.dist @pytest.mark.dist
......
import pytest import pytest
import torch import torch
import torch.distributed as dist
import colossalai import colossalai
from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero import ColoInitContext, ZeroDDP from colossalai.zero import GeminiDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
from tests.components_to_test import run_fwd_bwd from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
...@@ -24,8 +23,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ ...@@ -24,8 +23,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device='cpu'): model = model_builder(use_grad_checkpoint).cuda()
model = model_builder(use_grad_checkpoint)
print(f'model_name {model_name}') print(f'model_name {model_name}')
runtime_mem_tracer = RuntimeMemTracer(model) runtime_mem_tracer = RuntimeMemTracer(model)
...@@ -59,12 +57,13 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ ...@@ -59,12 +57,13 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gather config_dict[world_size]['keep_gathered'] = keep_gather
chunk_manager = ChunkManager(config_dict) model = GeminiDDP(model,
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) chunk_config_dict=config_dict,
model = ZeroDDP(model, gemini_manager, pin_memory=True) placement_policy=placement_policy,
pin_memory=True,
memstats=memstats)
pg = ProcessGroup() set_seed(dist.get_rank())
set_seed(pg.dp_local_rank())
for i, (input_ids, label) in enumerate(train_dataloader): for i, (input_ids, label) in enumerate(train_dataloader):
# you can only test a single fwd + bwd. # you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization. # after bwd param is grad for Gemini, due to the chunk reuse optimization.
...@@ -76,7 +75,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ ...@@ -76,7 +75,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
set_seed(42) set_seed(42)
loss = run_fwd_bwd(model, input_ids, label, criterion, model) loss = run_fwd_bwd(model, input_ids, label, criterion, model)
gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda') gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda')
# print('gemini non model data:', gemini_non_model_data) # print('gemini non model data:', gemini_non_model_data)
...@@ -90,6 +89,7 @@ def run_dist(rank, world_size, port): ...@@ -90,6 +89,7 @@ def run_dist(rank, world_size, port):
run_gemini_use_rmt() run_gemini_use_rmt()
@pytest.mark.skip("this is not used")
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
......
import pytest
import torch
import colossalai
from colossalai.tensor import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, GeminiDDP
from colossalai.zero.gemini.utils import get_static_torch_model
from tests.components_to_test.registry import non_distributed_component_funcs
@parameterize('model_name', ['hanging_param_model', 'resnet18', 'gpt2'])
def run_convert_torch_module(model_name: str):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, _, _, _, _ = get_components_func()
with ColoInitContext(device=torch.device("cpu")):
model = model_builder(checkpoint=False)
model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)
pytorch_model = get_static_torch_model(model, only_rank_0=False)
for n, p in pytorch_model.named_parameters():
assert type(p) == torch.nn.Parameter, f"type error: {n} is a {type(p)}"
# get the static model should not change the original model
for n, p in model.named_parameters():
assert isinstance(p, ColoParameter)
for (pn, pm), (cn, cm) in zip(pytorch_model.named_modules(), model.named_modules()):
assert pn == cn
assert id(pm) != id(cm)
for pp, cp in zip(pm.parameters(recurse=False), cm.parameters(recurse=False)):
assert id(pp) != id(cp)
assert pp.shape == cp.shape
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_convert_torch_module()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_convert_torch_module(world_size):
spawn(run_dist, world_size)
if __name__ == '__main__':
test_convert_torch_module(2)
...@@ -8,16 +8,38 @@ import colossalai ...@@ -8,16 +8,38 @@ import colossalai
from colossalai.amp import convert_to_apex_amp from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test import run_fwd_bwd from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed from tests.test_tensor.common_utils import set_seed
PLACEMENT_CONFIGS = [
def check_param(model: ZeroDDP, torch_model: torch.nn.Module): {
'placement_policy': 'static',
'shard_param_frac': 0.0,
'offload_optim_frac': 0.0,
'offload_param_frac': 0.0
}, # zero2
{
'placement_policy': 'static',
'shard_param_frac': 0.0,
'offload_optim_frac': 1.0,
'offload_param_frac': 0.0
}, # zero2-offload
{
'placement_policy': 'static',
'shard_param_frac': 0.0,
'offload_optim_frac': 0.5,
'offload_param_frac': 0.0
}, # zero2-offload-half
{
'placement_policy': 'auto'
}
]
def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
zero_dict = model.state_dict(only_rank_0=False) zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict() torch_dict = torch_model.state_dict()
...@@ -30,9 +52,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module): ...@@ -30,9 +52,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) @parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', ['gpt2']) @parameterize('model_name', ['gpt2'])
def exam_grad_clipping(placement_policy, model_name: str): def exam_grad_clipping(placement_config, model_name: str):
set_seed(1912) set_seed(1912)
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
...@@ -43,9 +65,7 @@ def exam_grad_clipping(placement_policy, model_name: str): ...@@ -43,9 +65,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
init_dev = get_current_device() model = model_builder()
with ColoInitContext(device=init_dev):
model = model_builder()
for torch_p, p in zip(torch_model.parameters(), model.parameters()): for torch_p, p in zip(torch_model.parameters(), model.parameters()):
p.data.copy_(torch_p.data) p.data.copy_(torch_p.data)
...@@ -54,16 +74,19 @@ def exam_grad_clipping(placement_policy, model_name: str): ...@@ -54,16 +74,19 @@ def exam_grad_clipping(placement_policy, model_name: str):
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda': if placement_config['placement_policy'] != 'cuda':
init_device = torch.device('cpu') init_device = torch.device('cpu')
else: else:
init_device = None init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager) model = GeminiDDP(model,
model = ZeroDDP(model, gemini_manager, pin_memory=True) chunk_config_dict=config_dict,
chunk_init_device=init_device,
pin_memory=True,
**placement_config)
optimizer = HybridAdam(model.parameters(), lr=1e-3) optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)
model.train() model.train()
torch_model.train() torch_model.train()
......
...@@ -11,15 +11,32 @@ from colossalai.amp import convert_to_apex_amp ...@@ -11,15 +11,32 @@ from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx, zero_model_wrapper from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test import run_fwd_bwd from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed from tests.test_tensor.common_utils import set_seed
PLACEMENT_CONFIGS = [
def check_param(model: ZeroDDP, torch_model: torch.nn.Module): {
'placement_policy': 'static',
'shard_param_frac': 0.0
}, # zero2
{
'placement_policy': 'static',
'shard_param_frac': 1.0
}, # zero3
{
'placement_policy': 'static',
'shard_param_frac': 0.5
}, # zero3-half
{
'placement_policy': 'auto'
}
]
def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
zero_dict = model.state_dict(only_rank_0=False) zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict() torch_dict = torch_model.state_dict()
...@@ -32,35 +49,24 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module): ...@@ -32,35 +49,24 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
def multi_chunk_init(model: torch.nn.Module, placement_policy: str): def multi_chunk_init(model: torch.nn.Module, placement_config: dict):
world_size = dist.get_world_size() world_size = dist.get_world_size()
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda': model = GeminiDDP(model, config_dict, pin_memory=True, **placement_config)
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
return model return model
def single_chunk_init(model: torch.nn.Module, placement_policy: str): def single_chunk_init(model: torch.nn.Module, placement_config: dict):
gemini_config = dict( model = GeminiDDP(model, chunk_init_device=get_current_device(), pin_memory=True, **placement_config)
device=get_current_device(),
placement_policy=placement_policy,
pin_memory=True,
)
model = zero_model_wrapper(model=model, zero_stage=3, gemini_config=gemini_config)
return model return model
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) @parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', ['gpt2']) @parameterize('model_name', ['gpt2'])
@parameterize('model_init_func', [single_chunk_init, multi_chunk_init]) @parameterize('model_init_func', [single_chunk_init, multi_chunk_init])
def exam_inference(placement_policy: str, model_name: str, model_init_func: Callable): def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable):
set_seed(19360226) set_seed(19360226)
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
...@@ -70,17 +76,15 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call ...@@ -70,17 +76,15 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
init_dev = get_current_device() init_dev = get_current_device()
with ColoInitContext(device=init_dev): model = model_builder().to(init_dev)
model = model_builder()
for torch_p, p in zip(torch_model.parameters(), model.parameters()): for torch_p, p in zip(torch_model.parameters(), model.parameters()):
p.data.copy_(torch_p.data) p.data.copy_(torch_p.data)
model = model_init_func(model, placement_policy) model = model_init_func(model, placement_config)
optimizer = HybridAdam(model.parameters(), lr=1e-3) optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)
model.eval() model.eval()
torch_model.eval() torch_model.eval()
...@@ -95,7 +99,7 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call ...@@ -95,7 +99,7 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call
torch_optim.zero_grad() torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
assert_close(torch_loss, loss) assert_close(torch_loss, loss, rtol=1e-5, atol=1e-5)
zero_optim.step() zero_optim.step()
torch_optim.step() torch_optim.step()
check_param(model, torch_model) check_param(model, torch_model)
......
...@@ -9,12 +9,46 @@ from colossalai.amp import convert_to_apex_amp ...@@ -9,12 +9,46 @@ from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test import run_fwd_bwd from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed from tests.test_tensor.common_utils import set_seed
PLACEMENT_CONFIGS = [
{
'placement_policy': 'static',
'shard_param_frac': 0.0,
'offload_optim_frac': 0.0
}, # zero2
{
'placement_policy': 'static',
'shard_param_frac': 0.0,
'offload_optim_frac': 1.0
}, # zero2-offload
{
'placement_policy': 'static',
'shard_param_frac': 0.0,
'offload_optim_frac': 0.5
}, # zero2-offload-half
{
'placement_policy': 'static',
'shard_param_frac': 1.0
}, # zero3
{
'placement_policy': 'static',
'shard_param_frac': 0.5
}, # zero3-half
{
'placement_policy': 'static',
'shard_param_frac': 1.0,
'offload_optim_frac': 1.0,
'offload_param_frac': 1.0
}, # zero3-offload-all
{
'placement_policy': 'auto'
}
]
# this model is large enough to slice to chunks # this model is large enough to slice to chunks
TEST_MODELS = ['gpt2'] TEST_MODELS = ['gpt2']
...@@ -29,7 +63,7 @@ BF16_IGNORED_KEYS = [ ...@@ -29,7 +63,7 @@ BF16_IGNORED_KEYS = [
] ]
def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype): def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
zero_dict = model.state_dict(only_rank_0=False, dtype=dtype) zero_dict = model.state_dict(only_rank_0=False, dtype=dtype)
torch_dict = torch_model.state_dict() torch_dict = torch_model.state_dict()
...@@ -51,10 +85,10 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype ...@@ -51,10 +85,10 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype
msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}') msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}')
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) @parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', TEST_MODELS) @parameterize('model_name', TEST_MODELS)
@parameterize('mixed_precision', [torch.half, torch.bfloat16]) @parameterize('mixed_precision', [torch.half, torch.bfloat16])
def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dtype): def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
set_seed(42) set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
...@@ -65,9 +99,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt ...@@ -65,9 +99,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
init_dev = get_current_device() model = model_builder().cuda()
with ColoInitContext(device=init_dev):
model = model_builder()
for torch_p, p in zip(torch_model.parameters(), model.parameters()): for torch_p, p in zip(torch_model.parameters(), model.parameters()):
p.data.copy_(torch_p.data) p.data.copy_(torch_p.data)
...@@ -76,16 +108,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt ...@@ -76,16 +108,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda': model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision)
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
optimizer = HybridAdam(model.parameters(), lr=1e-3) optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)
model.eval() model.eval()
torch_model.eval() torch_model.eval()
...@@ -109,10 +135,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt ...@@ -109,10 +135,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
check_param(model, torch_model, mixed_precision) check_param(model, torch_model, mixed_precision)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) @parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', EXAMPLE_MODELS) @parameterize('model_name', EXAMPLE_MODELS)
@parameterize('mixed_precision', [torch.half, torch.bfloat16]) @parameterize('mixed_precision', [torch.half, torch.bfloat16])
def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.dtype): def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype):
set_seed(2008) set_seed(2008)
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
...@@ -123,18 +149,19 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch. ...@@ -123,18 +149,19 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
init_dev = get_current_device() model = model_builder().cuda()
with ColoInitContext(device=init_dev):
model = model_builder()
for torch_p, p in zip(torch_model.parameters(), model.parameters()): for torch_p, p in zip(torch_model.parameters(), model.parameters()):
p.data.copy_(torch_p.data) p.data.copy_(torch_p.data)
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_m=1) model = GeminiDDP(model,
gemini_manager = GeminiManager(placement_policy, chunk_manager) chunk_init_device=get_current_device(),
model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision) search_range_m=1,
pin_memory=True,
mixed_precision=mixed_precision,
**placement_config)
optimizer = HybridAdam(model.parameters(), lr=1e-3) optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=2)
model.eval() model.eval()
torch_model.eval() torch_model.eval()
......
from copy import deepcopy from copy import deepcopy
import numpy as np import numpy as np
import pytest
import torch import torch
from colossalai.testing import clear_cache_before_run from colossalai.testing import clear_cache_before_run
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
from tests.components_to_test import run_fwd_bwd from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
@pytest.mark.skip("this is not used")
@clear_cache_before_run() @clear_cache_before_run()
def test_runtime_mem_tracer(): def test_runtime_mem_tracer():
test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model', 'albert'] test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model', 'albert']
...@@ -18,8 +19,7 @@ def test_runtime_mem_tracer(): ...@@ -18,8 +19,7 @@ def test_runtime_mem_tracer():
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, _, criterion = get_components_func() model_builder, train_dataloader, _, _, criterion = get_components_func()
with ColoInitContext(device='cpu'): model = model_builder(checkpoint=False).cuda()
model = model_builder(checkpoint=False)
model_bk = deepcopy(model) model_bk = deepcopy(model)
runtime_mem_tracer = RuntimeMemTracer(model) runtime_mem_tracer = RuntimeMemTracer(model)
......
...@@ -2,33 +2,20 @@ import pytest ...@@ -2,33 +2,20 @@ import pytest
import torch import torch
import colossalai import colossalai
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
def init_1d_row_spec(model, pg: ProcessGroup):
tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n:
p.set_process_group(pg)
p.set_tensor_spec(*tensor_spec)
def exam_search_chunk_size(): def exam_search_chunk_size():
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
pg_tp = ProcessGroup(tp_degree=world_size)
get_components_func = non_distributed_component_funcs.get_callable('gpt2') get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
# make sure torch_model and model has the same parameter values # make sure torch_model and model has the same parameter values
with ColoInitContext(device=get_current_device()): model = model_builder()
model = model_builder()
init_1d_row_spec(model, pg_tp)
config_dict, *_ = search_chunk_configuration(model, config_dict, *_ = search_chunk_configuration(model,
search_range_m=1, search_range_m=1,
search_interval=16, search_interval=16,
...@@ -37,57 +24,19 @@ def exam_search_chunk_size(): ...@@ -37,57 +24,19 @@ def exam_search_chunk_size():
for key in config_dict: for key in config_dict:
chunk_size = config_dict[key]['chunk_size'] chunk_size = config_dict[key]['chunk_size']
if world_size == 1: if world_size == 1 or True:
assert chunk_size == 31616 assert chunk_size == 31616
else: else:
assert chunk_size == 1024 assert chunk_size == 1024
def exam_search_strict_ddp():
world_size = torch.distributed.get_world_size()
default_shard_pg = ProcessGroup(tp_degree=world_size)
default_shard_spec = ShardSpec([-1], [world_size])
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
# get the chunk configuration over replicated models
with ColoInitContext(device=get_current_device()):
ddp_model = model_builder()
re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model,
search_range_m=1,
search_interval=16,
min_chunk_size_m=0,
filter_exlarge_params=True,
strict_ddp_flag=False)
# get the chunk configuration over sharded ddp models
with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg,
default_dist_spec=default_shard_spec):
sharded_ddp_model = model_builder()
sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model,
search_range_m=1,
search_interval=16,
min_chunk_size_m=0,
filter_exlarge_params=True,
strict_ddp_flag=True)
assert re_dict == sh_dict
for key in re_dict:
assert re_dict[key] == sh_dict[key]
assert re_total == sh_total
assert re_wasted == sh_wasted
def exam_chunk_manager(): def exam_chunk_manager():
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
default_shard_pg = ProcessGroup(tp_degree=world_size)
default_shard_spec = ShardSpec([-1], [world_size])
get_components_func = non_distributed_component_funcs.get_callable('gpt2') get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg, sharded_ddp_model = model_builder()
default_dist_spec=default_shard_spec):
sharded_ddp_model = model_builder()
chunk_manager = init_chunk_manager(sharded_ddp_model, chunk_manager = init_chunk_manager(sharded_ddp_model,
get_current_device(), get_current_device(),
hidden_dim=16, hidden_dim=16,
...@@ -103,7 +52,6 @@ def exam_chunk_manager(): ...@@ -103,7 +52,6 @@ def exam_chunk_manager():
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_search_chunk_size() exam_search_chunk_size()
exam_search_strict_ddp()
exam_chunk_manager() exam_chunk_manager()
......
...@@ -4,31 +4,46 @@ from torch.testing import assert_close ...@@ -4,31 +4,46 @@ from torch.testing import assert_close
import colossalai import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device from colossalai.zero import GeminiDDP
from colossalai.zero import ColoInitContext, ZeroDDP from colossalai.zero.gemini.chunk import search_chunk_configuration
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed from tests.test_tensor.common_utils import set_seed
PLACEMENT_CONFIGS = [
{
'placement_policy': 'static',
'shard_param_frac': 0.0
}, # zero2
{
'placement_policy': 'static',
'shard_param_frac': 1.0
}, # zero3
{
'placement_policy': 'static',
'shard_param_frac': 0.5
}, # zero3-half
{
'placement_policy': 'auto'
}
]
def ignore_the_first_parameter(model: torch.nn.Module): def ignore_the_first_parameter(model: torch.nn.Module):
for name, param in model.named_parameters(): for name, param in model.named_parameters():
print(f"parameter `{name}` is set ignored") print(f"parameter `{name}` is set ignored")
ZeroDDP.set_params_to_ignore([param]) GeminiDDP.set_params_to_ignore([param])
return return
@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) @parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('keep_gathered', [True, False]) @parameterize('keep_gathered', [True, False])
@parameterize('model_name', ['gpt2', 'bert']) @parameterize('model_name', ['gpt2', 'bert'])
def exam_state_dict(placement_policy, keep_gathered, model_name: str): def exam_state_dict(placement_config, keep_gathered, model_name: str):
set_seed(431) set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()): model = model_builder()
model = model_builder()
torch_model = model_builder() torch_model = model_builder()
for torch_p, p in zip(torch_model.parameters(), model.parameters()): for torch_p, p in zip(torch_model.parameters(), model.parameters()):
...@@ -38,9 +53,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str): ...@@ -38,9 +53,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered config_dict[world_size]['keep_gathered'] = keep_gathered
chunk_manager = ChunkManager(config_dict) model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
model.train() model.train()
zero_dict = model.state_dict(only_rank_0=False) zero_dict = model.state_dict(only_rank_0=False)
...@@ -52,16 +65,15 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str): ...@@ -52,16 +65,15 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) @parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('keep_gathered', [True, False]) @parameterize('keep_gathered', [True, False])
@parameterize('model_name', ['gpt2', 'bert']) @parameterize('model_name', ['gpt2', 'bert'])
def exam_load_state_dict(placement_policy, keep_gathered, model_name: str): def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
set_seed(431) set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()): model = model_builder()
model = model_builder()
set_seed(451) set_seed(451)
torch_model = model_builder() # get a different model torch_model = model_builder() # get a different model
...@@ -71,13 +83,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str): ...@@ -71,13 +83,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered config_dict[world_size]['keep_gathered'] = keep_gathered
if placement_policy != 'cuda': model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
torch_dict = torch_model.state_dict() torch_dict = torch_model.state_dict()
model.load_state_dict(torch_dict, strict=False) model.load_state_dict(torch_dict, strict=False)
...@@ -89,11 +95,37 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str): ...@@ -89,11 +95,37 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', ['gpt2', 'bert'])
def exam_state_dict_shard(placement_config, model_name: str):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model = model_builder()
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
model = GeminiDDP(model, config_dict, **placement_config)
model.train()
zero_dict = model.state_dict(only_rank_0=False)
accumulated_keys = set()
# ensure number of shards > 1
for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
for key, value in shard.items():
assert key not in accumulated_keys, f"key `{key}` is duplicated."
accumulated_keys.add(key)
assert key in zero_dict, f"{key} not in ZeRO dictionary."
assert torch.equal(value, zero_dict[key]), f"{key} not equal."
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
config = {} config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_state_dict() exam_state_dict()
exam_load_state_dict() exam_load_state_dict()
exam_state_dict_shard()
@pytest.mark.dist @pytest.mark.dist
......
import pytest
import torch
from torch.testing import assert_close
import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test.registry import non_distributed_component_funcs
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['gpt2', 'bert'])
def exam_state_dict(placement_policy, model_name: str):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)
model.train()
zero_dict = model.state_dict(only_rank_0=False)
accumulated_keys = set()
# ensure number of shards > 1
for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
for key, value in shard.items():
assert key not in accumulated_keys, f"key `{key}` is duplicated."
accumulated_keys.add(key)
assert key in zero_dict, f"{key} not in ZeRO dictionary."
assert torch.equal(value, zero_dict[key]), f"{key} not equal."
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_ddp_state_dict_shard(world_size):
spawn(run_dist, world_size)
if __name__ == '__main__':
test_zero_ddp_state_dict_shard(1)
...@@ -5,42 +5,53 @@ import torch.distributed as dist ...@@ -5,42 +5,53 @@ import torch.distributed as dist
import colossalai import colossalai
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed from tests.test_tensor.common_utils import set_seed
PLACEMENT_CONFIGS = [
@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) {
'placement_policy': 'static',
'shard_param_frac': 0.0,
'offload_optim_frac': 0.0
}, # zero2
{
'placement_policy': 'static',
'shard_param_frac': 0.0,
'offload_optim_frac': 1.0
}, # zero2-offload
{
'placement_policy': 'static',
'shard_param_frac': 0.0,
'offload_optim_frac': 0.5
}, # zero2-offload-half
{
'placement_policy': 'auto'
}
]
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('keep_gathered', [True, False]) @parameterize('keep_gathered', [True, False])
def exam_zero_optim_state_dict(placement_policy, keep_gathered): def exam_zero_optim_state_dict(placement_config, keep_gathered):
set_seed(431) set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2') get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()): model = model_builder()
model = model_builder()
set_seed(451) set_seed(451)
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered config_dict[world_size]['keep_gathered'] = keep_gathered
if placement_policy != 'cuda': model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters()) optimizer = HybridAdam(model.parameters())
optim = ZeroOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32 optim = GeminiOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32
set_seed(dist.get_rank() * 3 + 128) set_seed(dist.get_rank() * 3 + 128)
model.train() model.train()
......
...@@ -58,17 +58,8 @@ def exam_zero_1_2_grad_acc(): ...@@ -58,17 +58,8 @@ def exam_zero_1_2_grad_acc():
assert torch.equal(zero1_output, zero2_output) assert torch.equal(zero1_output, zero2_output)
# zero-dp backward # zero-dp backward
no_sync = number == 0 zero1_optimizer.backward(zero1_output.sum().float())
with conditional_context(zero1_optimizer.no_sync(), no_sync): zero2_optimizer.backward(zero2_output.sum().float())
zero1_optimizer.backward(zero1_output.sum().float())
with conditional_context(zero2_optimizer.no_sync(), no_sync):
zero2_optimizer.backward(zero2_output.sum().float())
if check_flag:
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
if z2p.grad is not None:
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
assert torch.equal(z1p.grad, z2p.grad)
fwd_bwd_func(0, input_data1, True) fwd_bwd_func(0, input_data1, True)
fwd_bwd_func(1, input_data2, False) fwd_bwd_func(1, input_data2, False)
...@@ -82,7 +73,7 @@ def exam_zero_1_2_grad_acc(): ...@@ -82,7 +73,7 @@ def exam_zero_1_2_grad_acc():
assert torch.equal(z1p.data, z2p.data) assert torch.equal(z1p.data, z2p.data)
def exam_zero_1_grad_acc(): def exam_zero_1_grad_acc(sync):
local_rank = torch.distributed.get_rank() local_rank = torch.distributed.get_rank()
seed_all(2008) seed_all(2008)
...@@ -112,9 +103,8 @@ def exam_zero_1_grad_acc(): ...@@ -112,9 +103,8 @@ def exam_zero_1_grad_acc():
input_data1 = torch.randn(32, 128).cuda() input_data1 = torch.randn(32, 128).cuda()
input_data2 = torch.randn(32, 128).cuda() input_data2 = torch.randn(32, 128).cuda()
def fwd_bwd_func(number, cur_data, check_flag): def fwd_bwd_func(no_sync, cur_data, check_flag):
no_sync = number == 0
# zero1 fwd and bwd # zero1 fwd and bwd
with conditional_context(zero_optimizer.no_sync(), no_sync): with conditional_context(zero_optimizer.no_sync(), no_sync):
zero_output = zero_model(cur_data) zero_output = zero_model(cur_data)
...@@ -131,8 +121,8 @@ def exam_zero_1_grad_acc(): ...@@ -131,8 +121,8 @@ def exam_zero_1_grad_acc():
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
assert torch.equal(p.grad, z1p.grad) assert torch.equal(p.grad, z1p.grad)
fwd_bwd_func(0, input_data1, True) fwd_bwd_func(sync, input_data1, sync)
fwd_bwd_func(1, input_data2, False) fwd_bwd_func(False, input_data2, False)
zero_optimizer.step() zero_optimizer.step()
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
...@@ -147,9 +137,9 @@ def exam_zero_1_grad_acc(): ...@@ -147,9 +137,9 @@ def exam_zero_1_grad_acc():
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_1_grad_acc() exam_zero_1_grad_acc(sync=True)
# gradient accumulation is not compatible with ZeRO-2 exam_zero_1_grad_acc(sync=False)
# exam_zero_1_2_grad_acc() exam_zero_1_2_grad_acc()
@pytest.mark.dist @pytest.mark.dist
......
...@@ -37,7 +37,7 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32): ...@@ -37,7 +37,7 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
atol = 4e-3 atol = 4e-3
a = a.detach().to(dtype) a = a.detach().to(dtype)
b = b.detach().to(dtype) b = b.detach().to(dtype).to(a.device)
assert_close(a, b, rtol=rtol, atol=atol) assert_close(a, b, rtol=rtol, atol=atol)
......
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
import colossalai
from colossalai.tensor import ProcessGroup
from colossalai.testing import spawn
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
def exam_zero_init():
dp_2_tp_2_pg = ProcessGroup(dp_degree=2, tp_degree=2)
model1 = MlpModel().cuda()
with ColoInitContext(device=get_current_device(), default_pg=dp_2_tp_2_pg):
model2 = MlpModel()
optimizer1 = LowLevelZeroOptimizer(torch.optim.Adam(model1.parameters(), lr=1))
optimizer2 = LowLevelZeroOptimizer(torch.optim.Adam(model2.parameters(), lr=1))
assert optimizer1._local_rank == optimizer2._local_rank
assert optimizer1._world_size == optimizer2._world_size
mp_group1 = optimizer1.tp_pg
mp_group2 = optimizer2.tp_pg
assert dist.get_world_size(mp_group1) == dist.get_world_size(mp_group2)
assert dist.get_rank(mp_group1) == dist.get_rank(mp_group2)
def run_dist(rank, world_size, port):
config_dict = dict(parallel=dict(data=2, tensor=dict(size=2, mode='1d')))
colossalai.launch(config=config_dict, rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_init()
@pytest.mark.dist
def test_zero_init():
spawn(run_dist, 4)
if __name__ == '__main__':
test_zero_init()
...@@ -85,6 +85,7 @@ def run_dist(rank, world_size, port): ...@@ -85,6 +85,7 @@ def run_dist(rank, world_size, port):
exam_zero_with_tp() exam_zero_with_tp()
@pytest.mark.skip('this will be rewritten by shardformer')
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_zero_with_tp(): def test_zero_with_tp():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment