Unverified Commit b28991dd authored by HELSON's avatar HELSON Committed by GitHub
Browse files

[feature] A new ZeRO implementation (#1644)

parent b1be5b88
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor
from tests.test_tensor.common_utils import debug_print
from time import time
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
chunk_manager = model.chunk_manager
param_list = [p for p in model.parameters()]
chunk_list = chunk_manager.get_chunks(param_list)
for chunk in chunk_list:
chunk_manager.access_chunk(chunk)
for (p0, p1) in zip(model.parameters(), torch_model.parameters()):
assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item())
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
optimizer.zero_grad()
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
optimizer.backward(loss)
return logits
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
def exam_gpt_fwd_bwd(placement_policy):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
pg = ProcessGroup()
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
model.eval()
torch_model.eval()
set_seed(pg.dp_local_rank())
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 0:
break
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
model.backward(loss)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format(
torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits)
check_grad(model, torch_model)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_gpt(1)
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print
from time import time
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict()
for key, value in torch_dict.items():
# key is 'module.model.PARAMETER', so we truncate it
key = key[7:]
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key)
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
optimizer.zero_grad()
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
optimizer.backward(loss)
return logits
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
def exam_gpt_fwd_bwd(placement_policy):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
model.eval()
torch_model.eval()
set_seed(dist.get_rank() * 3 + 128)
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 2:
break
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
# debug_print([0], zero_logits, torch_logits)
zero_optim.step()
torch_optim.step()
check_param(model, torch_model)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_gpt(1)
...@@ -8,7 +8,7 @@ import torch.distributed as dist ...@@ -8,7 +8,7 @@ import torch.distributed as dist
import colossalai import colossalai
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.gemini.update import search_chunk_configuration from colossalai.gemini.chunk import search_chunk_configuration
from colossalai.utils import free_port, get_current_device from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup
...@@ -35,8 +35,7 @@ def exam_search_chunk_size(): ...@@ -35,8 +35,7 @@ def exam_search_chunk_size():
with ColoInitContext(device=get_current_device()): with ColoInitContext(device=get_current_device()):
model = model_builder() model = model_builder()
init_1d_row_spec(model, pg_tp) init_1d_row_spec(model, pg_tp)
config_dict = search_chunk_configuration( config_dict = search_chunk_configuration(model,
model,
search_range_mb=1, search_range_mb=1,
search_interval_byte=16, search_interval_byte=16,
min_chunk_size_mb=0, min_chunk_size_mb=0,
......
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_state_dict(placement_policy, keep_gathered):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
model.train()
zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict()
for key, value in torch_dict.items():
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_load_state_dict(placement_policy, keep_gathered):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
set_seed(451)
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
torch_dict = torch_model.state_dict()
model.load_state_dict(torch_dict, strict=False)
zero_dict = model.state_dict(only_rank_0=False)
for key, value in torch_dict.items():
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_state_dict()
exam_load_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_ddp(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_ddp(1)
...@@ -2,99 +2,96 @@ import pytest ...@@ -2,99 +2,96 @@ import pytest
import colossalai import colossalai
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini import ChunkManager
from functools import partial from functools import partial
from tests.test_tensor.common_utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer from colossalai.zero import ZeroOptimizer
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import ProcessGroup from tests.test_tensor.common_utils import debug_print
def check_state(s1, s2):
for v1, v2 in zip(s1.values(), s2.values()):
if isinstance(v1, torch.Tensor):
v1 = v1.to(v2.device)
assert torch.equal(v1, v2), f'{torch.sum((v1-v2).abs())}'
else:
assert v1 == v2
def check_load_state_dict(optim, torch_optim):
for group, torch_group in zip(optim.optim.param_groups, torch_optim.param_groups):
for p, torch_p in zip(group['params'], torch_group['params']):
state = optim.optim.state[p]
torch_state = torch_optim.state[torch_p]
if p.storage().size() == 0:
assert len(state) == 0
check_state(state, torch_state)
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
def check_state_dict(state_dict, torch_state_dict):
for (k1, s1), (k2, s2) in zip(state_dict['state'].items(), torch_state_dict['state'].items()):
assert k1 == k2
check_state(s1, s2)
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) @parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('only_rank_0', [False, True]) @parameterize('keep_gathered', [True, False])
def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy, only_rank_0): def exam_zero_optim_state_dict(placement_policy, keep_gathered):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2') get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()): with ColoInitContext(device=get_current_device()):
model = model_builder() model = model_builder()
model = model.cuda()
torch_model = model_builder().cuda()
pg = ProcessGroup()
chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=use_zero,
init_device=GeminiManager.get_default_device(placement_policy))
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)
optim = HybridAdam(model.parameters(), lr=1e-3)
optim = ZeroOptimizer(optim, model, initial_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) set_seed(451)
torch_model = model_builder() # get a different model
for p in torch_model.parameters(): world_size = torch.distributed.get_world_size()
p.grad = torch.rand_like(p) config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
torch_optim.step() if placement_policy != 'cuda':
torch_state_dict = torch_optim.state_dict() init_device = torch.device('cpu')
optim.load_state_dict(torch_state_dict) else:
check_load_state_dict(optim, torch_optim) init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
state_dict = optim.state_dict(only_rank_0) gemini_manager = GeminiManager(placement_policy, chunk_manager)
if not only_rank_0 or pg.rank() == 0: model = ZeroDDP(model, gemini_manager, pin_memory=True)
check_state_dict(state_dict, torch_state_dict)
optimizer = HybridAdam(model.parameters())
optim = ZeroOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32
set_seed(dist.get_rank() * 3 + 128)
model.train()
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 0:
break
optim.zero_grad()
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
optim.backward(loss)
optim.step()
optim_state_dict = optim.state_dict()
optim.load_state_dict(optim_state_dict)
new_state = optim.state_dict()['state']
org_state = optim_state_dict['state']
for k, v in org_state.items():
w = new_state[k]
for n, m in v.items():
if isinstance(m, torch.Tensor):
o = w[n]
if m.device != o.device:
o = o.to(m.device)
assert torch.equal(m, o)
else:
assert m == w[n]
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
config = {} config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_zero_optim_state_dict() exam_zero_optim_state_dict()
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2]) @pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_zero_optim_state_dict(world_size): def test_zero_optim(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port()) run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
test_zero_optim_state_dict(2) test_zero_optim(1)
import torch
import colossalai
import pytest
import torch.multiprocessing as mp
from typing import List
from functools import partial
from colossalai.gemini import ChunkManager
from colossalai.testing import rerun_if_address_is_in_use, parameterize
from colossalai.utils import free_port
from colossalai.tensor import ProcessGroup as ColoProcessGroup
def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]):
for p, has_tensor in zip(params, has_tensors):
if has_tensor:
assert p.storage().size() > 0
assert p.device.type == 'cuda'
else:
assert p.storage().size() == 0
# HAS_TENSORS[use_chunk][use_zero]
HAS_TENSORS = {
True: {
True: [[True, True, False], [False, False, True]],
False: [[True, True, True], [True, True, True]]
},
False: {
True: [[True, False, True], [False, True, False]],
False: [[True, True, True], [True, True, True]]
}
}
TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
def run_chunk_zero(use_chunk, use_zero):
pg = ColoProcessGroup()
rank = pg.rank()
if rank == 0:
print(f'use_chunk={use_chunk}, use_zero={use_zero}')
params = [torch.rand(8, 8) for _ in range(3)]
chunk_size = 128 if use_chunk else None
chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
chunk_manager.create_group('param')
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == 0
for p in params:
chunk_manager.append_tensor(p, 'param')
check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
chunks = chunk_manager.get_chunks(params)
for chunk in chunks:
chunk_manager.access_chunk(chunk)
check_has_params(params, [True, True, True])
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
for chunk in chunks:
chunk_manager.release_chunk(chunk)
check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
for chunk in chunks:
chunk_manager.move_chunk(chunk, torch.device('cpu'))
assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
assert chunk_manager.total_mem['cuda'] == 0
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_chunk_zero()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_chunk_mapping(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_chunk_mapping(2)
...@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use ...@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini import ChunkManager from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from functools import partial from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
...@@ -21,20 +21,20 @@ from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, Compute ...@@ -21,20 +21,20 @@ from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, Compute
from tests.test_tensor.model.test_gpt2 import init_megatron_spec from tests.test_tensor.model.test_gpt2 import init_megatron_spec
def check_param_equal(model, torch_model, pg: ProcessGroup): def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()): zero_dict = model.state_dict(only_rank_0=False)
if p.storage().size() > 0: torch_dict = torch_model.state_dict()
assert p.dtype == torch.float16
assert tensor_shard_equal(tp.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
pg.tp_world_size()), f'{tp} vs {p}\n{n}:\n\t{tp.shape} vs {p.shape}'
for key, value in torch_dict.items():
def check_grad_equal(model, torch_model, pg: ProcessGroup): # key is 'module.model.PARAMETER', so we truncate it
for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()): key = key[7:]
if p.grad is not None: if key == 'model.lm_head.weight':
assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad, continue
pg.tp_local_rank(), pg.tp_world_size()), \ assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}' temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
assert tensor_shard_equal(value, temp_zero_value, pg.tp_local_rank(), pg.tp_world_size()), \
"parameter '{}' has problem.".format(key)
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
...@@ -62,10 +62,8 @@ def init_1d_col_spec(model, pg: ProcessGroup): ...@@ -62,10 +62,8 @@ def init_1d_col_spec(model, pg: ProcessGroup):
p.set_tensor_spec(*spec) p.set_tensor_spec(*spec)
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu']) @parameterize('placement_policy', ['cuda', 'cpu'])
def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None): def run_gpt(placement_policy, tp_init_spec_func=None):
set_seed(42) set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2') get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
...@@ -89,15 +87,20 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None): ...@@ -89,15 +87,20 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
if tp_init_spec_func: if tp_init_spec_func:
tp_init_spec_func(model, pg) tp_init_spec_func(model, pg)
chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None dp_world_size = pg.dp_world_size()
chunk_manager = ChunkManager(chunk_size, config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
pg, config_dict[dp_world_size]['chunk_size'] = 5000
enable_distributed_storage=use_zero, config_dict[dp_world_size]['keep_gathered'] = False
init_device=GeminiManager.get_default_device(placement_policy)) if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager) gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager) model = ZeroDDP(model, gemini_manager, pin_memory=True)
optim = HybridAdam(model.parameters(), lr=1e-3)
optim = ZeroOptimizer(optim, model, initial_scale=1) optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
...@@ -105,7 +108,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None): ...@@ -105,7 +108,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
print(chunk_manager) print(chunk_manager)
check_param_equal(model, torch_model, pg) check_param(model, torch_model, pg)
model.eval() model.eval()
torch_model.eval() torch_model.eval()
...@@ -115,13 +118,13 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None): ...@@ -115,13 +118,13 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
if i > 2: if i > 2:
break break
input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg)) input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
logits = run_fwd_bwd(model, criterion, optim, input_ids_colo, attn_mask) zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo, attn_mask)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert tensor_equal(logits, torch_logits) assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
check_grad_equal(model, torch_model, pg)
optim.step() zero_optim.step()
torch_optim.step() torch_optim.step()
check_param_equal(model, torch_model, pg) check_param(model, torch_model, pg)
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment