Unverified Commit c6a1a626 authored by HELSON's avatar HELSON Committed by GitHub
Browse files

[hotfix] fix zero's incompatibility with checkpoint in torch-1.12 (#1786)

* [hotfix] fix zero's incompatibility with checkpoint in torch-1.12

* [zero] add cpu shard init

* [zero] add tiny example test

* [colo_tensor] fix bugs for torch-1.11
parent 32c1b843
import torch
import torch.distributed as dist
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Optional, Dict, List from typing import Dict, List, Optional
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.utils import get_current_device
class TensorState(Enum): class TensorState(Enum):
...@@ -58,6 +59,7 @@ class Chunk: ...@@ -58,6 +59,7 @@ class Chunk:
process_group: ColoProcessGroup, process_group: ColoProcessGroup,
dtype: torch.dtype, dtype: torch.dtype,
init_device: Optional[torch.device] = None, init_device: Optional[torch.device] = None,
cpu_shard_init: bool = False,
keep_gathered: bool = False, keep_gathered: bool = False,
pin_memory: bool = False) -> None: pin_memory: bool = False) -> None:
""" """
...@@ -102,6 +104,11 @@ class Chunk: ...@@ -102,6 +104,11 @@ class Chunk:
self.cpu_shard = None self.cpu_shard = None
self.is_gathered = True self.is_gathered = True
# configure the init device of the shard
# no-offload default: fp16, fp32 -> CUDA
# offload default: fp16, fp32 -> CPU
self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()
self.chunk_mem = self.chunk_size * self.chunk_temp.element_size() self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
self.shard_mem = self.chunk_mem // self.pg_size self.shard_mem = self.chunk_mem // self.pg_size
...@@ -242,11 +249,8 @@ class Chunk: ...@@ -242,11 +249,8 @@ class Chunk:
self.tensors_state_monitor[tensor_state] += 1 self.tensors_state_monitor[tensor_state] += 1
self.utilized_size = new_utilized_size self.utilized_size = new_utilized_size
def close_chunk(self, shard_dev: Optional[torch.device] = None): def close_chunk(self):
"""Close the chunk. Any tensor can't be appended to a closed chunk later. """Close the chunk. Any tensor can't be appended to a closed chunk later.
Args:
shard_dev: the device where the shard locates
""" """
# sanity check # sanity check
assert self.chunk_temp is not None assert self.chunk_temp is not None
...@@ -265,21 +269,16 @@ class Chunk: ...@@ -265,21 +269,16 @@ class Chunk:
self.chunk_temp = None self.chunk_temp = None
self.__scatter() self.__scatter()
# always gathered chunk does not have shard
if self.keep_gathered: if self.keep_gathered:
if shard_dev is None: return
shard_dev = get_current_device()
else:
assert shard_dev.type == 'cuda'
elif shard_dev is None:
shard_dev = torch.device('cpu')
if self.pin_memory or shard_dev.type == 'cpu': if self.pin_memory or self.shard_device.type == 'cpu':
self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory) self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
self.cpu_shard.copy_(self.cuda_shard) self.cpu_shard.copy_(self.cuda_shard)
self.cpu_vis_flag = True # cpu_shard has been visited self.cpu_vis_flag = True # cpu_shard has been visited
if shard_dev.type == 'cpu': if self.shard_device.type == 'cpu':
self.cuda_shard = None self.cuda_shard = None
def shard_move(self, device: torch.device, force_copy: bool = False): def shard_move(self, device: torch.device, force_copy: bool = False):
......
import torch
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
from collections import deque from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
from colossalai.utils import get_current_device import torch
from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState
from colossalai.tensor import ColoTensor from colossalai.tensor import ColoTensor
from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk from colossalai.utils import get_current_device
class ChunkManager: class ChunkManager:
...@@ -31,13 +32,19 @@ class ChunkManager: ...@@ -31,13 +32,19 @@ class ChunkManager:
self.accessed_mem: int = 0 self.accessed_mem: int = 0
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None: def append_tensor(self,
tensor: ColoTensor,
group_type: str,
config_key: int,
cpu_offload: bool = False,
pin_memory: bool = False) -> None:
"""Append a tensor to a chunk. """Append a tensor to a chunk.
Args: Args:
tensor: the tensor appended to the chunk tensor: the tensor appended to the chunk
group_type: the data type of the group group_type: the data type of the group
config_key: the key of the group's name, usually the size of the dp world config_key: the key of the group's name, usually the size of the dp world
cpu_offload: if True, the chunk will be closed on CPU
pin_memory: whether the chunk is pinned in the cpu memory pin_memory: whether the chunk is pinned in the cpu memory
""" """
assert tensor not in self.tensor_chunk_map assert tensor not in self.tensor_chunk_map
...@@ -67,6 +74,7 @@ class ChunkManager: ...@@ -67,6 +74,7 @@ class ChunkManager:
chunk_size=chunk_size, chunk_size=chunk_size,
process_group=tensor.process_group, process_group=tensor.process_group,
dtype=tensor.dtype, dtype=tensor.dtype,
cpu_shard_init=cpu_offload,
pin_memory=pin_memory, pin_memory=pin_memory,
**chunk_kwargs, **chunk_kwargs,
) )
...@@ -206,9 +214,8 @@ class ChunkManager: ...@@ -206,9 +214,8 @@ class ChunkManager:
return self.chunk_groups[group_name] return self.chunk_groups[group_name]
def __close_one_chunk(self, chunk: Chunk): def __close_one_chunk(self, chunk: Chunk):
device = get_current_device() if chunk.keep_gathered else self.device # keep gathered chunk in cuda
self.__sub_memroy_usage(chunk.memory_usage) self.__sub_memroy_usage(chunk.memory_usage)
chunk.close_chunk(device) chunk.close_chunk()
self.__add_memory_usage(chunk.memory_usage) self.__add_memory_usage(chunk.memory_usage)
def __sub_memroy_usage(self, usage: Dict[str, int]): def __sub_memroy_usage(self, usage: Dict[str, int]):
......
import torch
import functools import functools
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from typing import List, Optional, Tuple
from time import time from time import time
from typing import List, Optional, Tuple
import torch
from colossalai.gemini.chunk import Chunk, ChunkManager from colossalai.gemini.chunk import Chunk, ChunkManager
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from .placement_policy import PlacementPolicyFactory from .placement_policy import PlacementPolicyFactory
...@@ -25,6 +28,7 @@ class GeminiManager: ...@@ -25,6 +28,7 @@ class GeminiManager:
def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None: def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
assert placement_policy in PlacementPolicyFactory.get_polocy_names() assert placement_policy in PlacementPolicyFactory.get_polocy_names()
self.policy_name = placement_policy
policy_cls = PlacementPolicyFactory.create(placement_policy) policy_cls = PlacementPolicyFactory.create(placement_policy)
self._chunk_manager = chunk_manager self._chunk_manager = chunk_manager
self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
......
import torch
import itertools import itertools
import torch.distributed as dist from collections import OrderedDict
from functools import partial from functools import partial
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Dict, Iterable, List, Optional, Set from typing import Dict, Iterable, List, Optional, Set
import torch
import torch.distributed as dist
from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from collections import OrderedDict from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.tensor import ProcessGroup as ColoProcessGroup
from .reducer import Reducer from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.utils import get_current_device
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager from .reducer import Reducer
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
try: try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
...@@ -221,6 +224,7 @@ class ZeroDDP(ColoDDP): ...@@ -221,6 +224,7 @@ class ZeroDDP(ColoDDP):
self.overflow_counter = 0 self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = {} self.grads_device: Dict[torch.Tensor, torch.device] = {}
cpu_offload = self.gemini_manager.policy_name != 'cuda'
# TODO: get param order and filter unused params # TODO: get param order and filter unused params
for p in module.parameters(): for p in module.parameters():
assert isinstance(p, ColoParameter) assert isinstance(p, ColoParameter)
...@@ -232,10 +236,17 @@ class ZeroDDP(ColoDDP): ...@@ -232,10 +236,17 @@ class ZeroDDP(ColoDDP):
fp32_data = p.data.float() fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
p.data = p.data.half() p.data = p.data.half()
dp_world_size = p.process_group.dp_world_size() dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory) self.chunk_manager.append_tensor(tensor=p,
self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory) group_type='fp16_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.append_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.fp32_params.append(fp32_p) self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups() self.chunk_manager.close_all_groups()
...@@ -247,6 +258,10 @@ class ZeroDDP(ColoDDP): ...@@ -247,6 +258,10 @@ class ZeroDDP(ColoDDP):
chunk_32 = self.chunk_manager.get_chunk(fp32_p) chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16) chunk_32.init_pair(chunk_16)
# keep-gathered chunks stay in CUDA, so their gradients live on the current CUDA device
if chunk_16.keep_gathered:
self.grads_device[p] = get_current_device()
self._logger = get_dist_logger() self._logger = get_dist_logger()
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
......
from .op_wrapper import _COLOSSAL_OPS
from .const import TensorType
from copy import copy from copy import copy
import torch
from functools import lru_cache from functools import lru_cache
from typing import Callable, Optional, Set
import torch
from colossalai.tensor import ColoTensorSpec from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
from colossalai.tensor import ProcessGroup, ReplicaSpec
from colossalai.tensor.dist_spec_mgr import DistSpecManager from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
from typing import Optional, Set, Callable
from .const import TensorType
from .op_wrapper import _COLOSSAL_OPS
@lru_cache(None) @lru_cache(None)
...@@ -67,6 +68,7 @@ class ColoTensor(torch.Tensor): ...@@ -67,6 +68,7 @@ class ColoTensor(torch.Tensor):
data (torch.Tensor): a torch tensor used as the payload the colotensor. data (torch.Tensor): a torch tensor used as the payload the colotensor.
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()). spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
""" """
torch_minor = int(torch.__version__.split('.')[1])
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor': def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
""" """
...@@ -166,6 +168,16 @@ class ColoTensor(torch.Tensor): ...@@ -166,6 +168,16 @@ class ColoTensor(torch.Tensor):
if func in _COLOSSAL_OPS: if func in _COLOSSAL_OPS:
func = _COLOSSAL_OPS[func] func = _COLOSSAL_OPS[func]
if cls.torch_minor >= 12:
# in order to trigger the pre-op hook in the forward of a checkpoint module,
# we have to capture the `backward` function
# and make sure that it does not run inside the `torch._C.DisableTorchFunction()` context
if func is torch.Tensor.backward:
assert len(args) == 1 # only has 1 parameter
backward_tensor = torch.Tensor(args[0])
tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
return backward_tensor.backward(**tensor_kwargs)
with torch._C.DisableTorchFunction(): with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs) ret = func(*args, **kwargs)
if func in _get_my_nowrap_functions(): if func in _get_my_nowrap_functions():
......
from enum import Enum
from typing import Dict, Set, Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from enum import Enum
from torch.optim import Optimizer
from torch.nn import Parameter from torch.nn import Parameter
from colossalai.nn.parallel.data_parallel import ZeroDDP from torch.optim import Optimizer
from typing import Dict, Tuple, Set
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, disposable from colossalai.nn.parallel.data_parallel import ZeroDDP
from colossalai.gemini.chunk import Chunk, ChunkManager from colossalai.utils import disposable, get_current_device
class OptimState(Enum): class OptimState(Enum):
...@@ -219,6 +221,8 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -219,6 +221,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
def get_range_pair(local_chunk: Chunk, local_param: Parameter): def get_range_pair(local_chunk: Chunk, local_param: Parameter):
param_info = local_chunk.tensors_info[local_param] param_info = local_chunk.tensors_info[local_param]
if local_chunk.keep_gathered:
return param_info.offset, param_info.end
begin = max(0, param_info.offset - local_chunk.shard_begin) begin = max(0, param_info.offset - local_chunk.shard_begin)
end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin) end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
return begin, end return begin, end
......
import torch from functools import partial
import colossalai
import pytest import pytest
import torch.multiprocessing as mp import torch
import torch.distributed as dist import torch.distributed as dist
from functools import partial import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use, parameterize
from colossalai.utils import free_port, get_current_device import colossalai
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ColoParameter
from colossalai.gemini import TensorState from colossalai.gemini import TensorState
from colossalai.gemini.chunk import Chunk from colossalai.gemini.chunk import Chunk
from colossalai.tensor import ColoParameter
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
def dist_sum(x): def dist_sum(x):
...@@ -42,6 +44,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): ...@@ -42,6 +44,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
process_group=pg, process_group=pg,
dtype=torch.float32, dtype=torch.float32,
init_device=init_device, init_device=init_device,
cpu_shard_init=True,
keep_gathered=keep_gathered, keep_gathered=keep_gathered,
pin_memory=pin_memory) pin_memory=pin_memory)
......
...@@ -40,7 +40,8 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): ...@@ -40,7 +40,8 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
def exam_gpt_fwd_bwd(placement_policy): @parameterize('keep_gather', [False, True])
def exam_gpt_fwd_bwd(placement_policy, keep_gather):
set_seed(42) set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2') get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
...@@ -55,7 +56,7 @@ def exam_gpt_fwd_bwd(placement_policy): ...@@ -55,7 +56,7 @@ def exam_gpt_fwd_bwd(placement_policy):
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False config_dict[world_size]['keep_gathered'] = keep_gather
chunk_manager = ChunkManager(config_dict) chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager) gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True) model = ZeroDDP(model, gemini_manager, pin_memory=True)
...@@ -101,4 +102,4 @@ def test_gpt(world_size): ...@@ -101,4 +102,4 @@ def test_gpt(world_size):
if __name__ == '__main__': if __name__ == '__main__':
test_gpt(1) test_gpt(4)
...@@ -9,7 +9,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP ...@@ -9,7 +9,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai import colossalai
from colossalai.amp import convert_to_apex_amp from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP from colossalai.nn.parallel import ZeroDDP
...@@ -98,10 +98,55 @@ def exam_gpt_fwd_bwd(placement_policy): ...@@ -98,10 +98,55 @@ def exam_gpt_fwd_bwd(placement_policy):
check_param(model, torch_model) check_param(model, torch_model)
@parameterize('placement_policy', ['cuda', 'cpu'])
def exam_tiny_example(placement_policy):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
model.eval()
torch_model.eval()
set_seed(dist.get_rank() * 3 + 128)
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 2:
break
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
# debug_print([0], zero_logits, torch_logits)
zero_optim.step()
torch_optim.step()
check_param(model, torch_model)
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
config = {} config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd() exam_gpt_fwd_bwd()
exam_tiny_example()
@pytest.mark.dist @pytest.mark.dist
...@@ -113,4 +158,4 @@ def test_gpt(world_size): ...@@ -113,4 +158,4 @@ def test_gpt(world_size):
if __name__ == '__main__': if __name__ == '__main__':
test_gpt(1) test_gpt(2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment