Unverified Commit bb4e9a31 authored by HELSON's avatar HELSON Committed by GitHub
Browse files

[zero] add inference mode and its unit test (#2418)

parent 63be79d5
...@@ -50,6 +50,17 @@ class GeminiManager: ...@@ -50,6 +50,17 @@ class GeminiManager:
self._warmup = True self._warmup = True
self._comp_cuda_demand_time = 0 self._comp_cuda_demand_time = 0
def reset_attributes(self):
self._compute_idx = -1
self._h2d_volume = 0
self._d2h_volume = 0
self._layout_time = 0
self._evict_time = 0
self._comp_cuda_demand_time = 0
def is_warmup(self):
return self._warmup
def memstats(self): def memstats(self):
"""memstats """memstats
...@@ -73,12 +84,7 @@ class GeminiManager: ...@@ -73,12 +84,7 @@ class GeminiManager:
if self._mem_stats_collector and self._warmup: if self._mem_stats_collector and self._warmup:
self._mem_stats_collector.finish_collection() self._mem_stats_collector.finish_collection()
self._warmup = False self._warmup = False
self._compute_idx = -1 self.reset_attributes()
self._h2d_volume = 0
self._d2h_volume = 0
self._layout_time = 0
self._evict_time = 0
self._comp_cuda_demand_time = 0
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None: def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
""" Adjust the layout of stateful tensors according to the information provided """ Adjust the layout of stateful tensors according to the information provided
......
...@@ -268,12 +268,35 @@ class ZeroDDP(ColoDDP): ...@@ -268,12 +268,35 @@ class ZeroDDP(ColoDDP):
self._logger = get_dist_logger() self._logger = get_dist_logger()
def _post_forward(self):
"""This function is only triggered for inference.
"""
access_list = list(self.chunk_manager.accessed_chunks)
# we need to scatter all accessed chunks and move them to their original places
for chunk in access_list:
assert chunk.can_release
self.chunk_manager.release_chunk(chunk)
first_param = next(iter(chunk.tensors_info))
self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
assert self.chunk_manager.accessed_mem == 0
# reset all recorded attributes
self.gemini_manager.reset_attributes()
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
# check whether we are in a inference mode
grad_flag = torch.is_grad_enabled()
if not grad_flag:
assert not self.gemini_manager.is_warmup(), "You should run a completed iteration as your warmup iter"
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.module.zero_grad(set_to_none=True) self.module.zero_grad(set_to_none=True)
self.gemini_manager.pre_iter(*args) self.gemini_manager.pre_iter(*args)
with ColoParamOpHookManager.use_hooks(self.param_op_hook): with ColoParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs) outputs = self.module(*args, **kwargs)
# scatter chunks in the inference mode
if not grad_flag:
self._post_forward()
if self.force_outputs_fp32: if self.force_outputs_fp32:
return _cast_float(outputs, torch.float) return _cast_float(outputs, torch.float)
return outputs return outputs
......
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict()
for key, value in torch_dict.items():
# key is 'module.model.PARAMETER', so we truncate it
key = key[7:]
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('model_name', ['gpt2'])
def exam_inference(placement_policy, model_name: str):
set_seed(19360226)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
torch_model = model_builder().cuda()
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
init_dev = get_current_device()
with ColoInitContext(device=init_dev):
model = model_builder()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
p.data.copy_(torch_p.data)
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
model.eval()
torch_model.eval()
set_seed(dist.get_rank() * 3 + 128)
train_dataloader = iter(train_dataloader)
def train_iter():
input_ids, label = next(train_dataloader)
input_ids, label = input_ids.cuda(), label.cuda()
zero_optim.zero_grad()
torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
assert_close(torch_loss, loss)
zero_optim.step()
torch_optim.step()
check_param(model, torch_model)
def inference_iter():
input_ids, label = next(train_dataloader)
input_ids, label = input_ids.cuda(), label.cuda()
with torch.no_grad():
torch_output = torch_model(input_ids)
torch_loss = criterion(torch_output.float(), label)
zero_output = model(input_ids)
zero_loss = criterion(zero_output.float(), label)
assert_close(torch_loss, zero_loss)
train_iter()
inference_iter()
train_iter()
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_inference()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_inference(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_inference(1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment