"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "6df89c763f239f7dec9468ec4f341f5861ce7232"
Unverified commit 1f894e03, authored by ver217, committed by GitHub

[gemini] zero supports gemini (#1093)

* add placement policy

* add gemini mgr

* update mem stats collector

* update zero

* update zero optim

* fix bugs

* zero optim monitor os

* polish unit test

* polish unit test

* add assert
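At a glance, this change threads a GeminiManager between the ChunkManager and the ZeRO components (ColoDDPV2, ZeROHookV2, ZeroOptimizer). A minimal wiring sketch, condensed from the tests at the bottom of this diff (the chunk size and optimizer hyperparameters are illustrative, and `model` is assumed to be an fp16 GPT-like module defined elsewhere):

from colossalai.tensor import ChunkManager
from colossalai.nn.parallel import ColoDDPV2
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.gemini.gemini_mgr import GeminiManager

chunk_manager = ChunkManager(38 * 1024**2, enable_distributed_storage=True)
gemini_manager = GeminiManager('cuda', chunk_manager)    # only 'cuda' is accepted for now
model = ColoDDPV2(model, gemini_manager)                 # ColoDDPV2 now takes the manager, not the ChunkManager
optim = ZeroOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, initial_scale=32)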
colossalai.gemini.gemini_mgr (new module):

import functools
from time import time
from typing import List, Optional, Tuple

from colossalai.tensor.chunk import Chunk, ChunkManager

from .memory_tracer.memstats_collector import MemStatsCollectorV2
from .placement_policy import PlacementPolicy, PlacementPolicyFactory


class GeminiManager:
    """Stateful tensor manager, inspired by PatrickStar.

    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    """

    def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
        # TODO: remove assert
        assert placement_policy == 'cuda', 'placement_policy can only be "cuda" now'
        assert placement_policy in PlacementPolicyFactory.get_policy_names()
        policy_cls = PlacementPolicyFactory.create(placement_policy)
        self._chunk_manager = chunk_manager
        # memory stats are only collected when the policy needs them (currently only 'auto')
        self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
        self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
        self._compute_list: List[Tuple[Chunk, ...]] = []
        self._compute_idx: int = -1
        self._cpu_gpu_move_volume = 0
        self._layout_time = 0
        self._evict_time = 0
        self._warmup = True
    def pre_iter(self):
        if self._mem_stats_collector and self._warmup:
            self._mem_stats_collector.start_collection()

    def post_iter(self):
        """This function must be called when each iteration finishes."""
        if self._mem_stats_collector and self._warmup:
            self._mem_stats_collector.finish_collection()
        self._warmup = False
        self._compute_idx = -1
        self._cpu_gpu_move_volume = 0
        self._layout_time = 0
        self._evict_time = 0
    def adjust_layout(self, chunks: Tuple[Chunk, ...], group_name: str) -> None:
        """Adjust the layout of stateful tensors according to the information provided
        by the mem_stats_collector, which should belong to a sharded model.
        """
        # find stateful tensors in state COMPUTE
        start = time()
        self._record_chunks_order(chunks)
        cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_name)
        self._layout_time += time() - start
        vol, evict_time = self._placement_policy.evict_tensors(hold_cuda_tensor_list,
                                                               cuda_demand=cuda_demand,
                                                               warmup=self._warmup,
                                                               compute_list=self._compute_list,
                                                               compute_idx=self._compute_idx)
        self._cpu_gpu_move_volume += vol
        self._evict_time += evict_time
        # moving COMPUTE tensors to CUDA also counts towards the move volume
        self._cpu_gpu_move_volume += cuda_demand

    @property
    def cpu_gpu_move_volume(self):
        return self._cpu_gpu_move_volume
    # @functools.lru_cache(maxsize=None)
    # TODO: test lru
    def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_name: str):
        # volume of chunks that must be brought to CUDA for the upcoming compute
        cuda_demand = 0
        for chunk in chunks:
            if chunk.device_type == 'cpu' or chunk.is_free:
                cuda_demand += chunk.mem
        # chunks currently resident on CUDA that are allowed to be evicted
        can_evict_chunks = []
        for chunk in self._chunk_manager.chunk_groups[group_name]:
            if not chunk.is_free and chunk.device_type == 'cuda' and chunk.can_move_device:
                can_evict_chunks.append(chunk)
        return cuda_demand, can_evict_chunks

    def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
        self._compute_idx += 1
        if self._warmup and self._placement_policy.need_mem_stats:
            self._compute_list.append(chunks)

    @property
    def default_device(self):
        return self._placement_policy.get_default_device()

    def sample_overall_data(self):
        if self._mem_stats_collector:
            self._mem_stats_collector.sample_overall_data()

    def sample_model_data(self):
        if self._mem_stats_collector:
            self._mem_stats_collector.sample_model_data()

    @property
    def chunk_manager(self):
        return self._chunk_manager

    @property
    def cuda_margin_mem(self) -> Optional[float]:
        if self._mem_stats_collector:
            return self._mem_stats_collector.cuda_margin_mem
        return None

    @property
    def is_cuda_margin_mem_avail(self) -> bool:
        return self._placement_policy.need_mem_stats
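For orientation, this is the per-iteration protocol the manager expects; a minimal sketch assuming `mgr` is a GeminiManager and `steps` is a hypothetical list of chunk tuples touched in compute order (in practice these calls are issued by ColoDDPV2 and ZeROHookV2 below, not by user code):

mgr.pre_iter()                                    # begin stats collection during warmup
for chunks in steps:
    mgr.sample_overall_data()                     # sample non-model CUDA/CPU usage
    mgr.adjust_layout(chunks, 'fp16_param')       # evict cold chunks to fit the CUDA demand
    for chunk in chunks:
        mgr.chunk_manager.access_chunk(chunk)     # materialize the chunk on CUDA
    mgr.sample_model_data()                       # sample model-data usage
mgr.post_iter()                                   # end collection; warmup is over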
In colossalai.gemini.memory_tracer.memstats_collector:

 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
-from colossalai.utils.memory import colo_device_memory_used
+from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity
+from colossalai.utils import get_current_device
 from colossalai.gemini.stateful_tensor import StatefulTensor
 from colossalai.tensor import ChunkManager

@@ -145,3 +146,7 @@ class MemStatsCollectorV2(MemStatsCollector):
         cpu_mem = self._chunk_manager.total_mem['cpu']
         self._model_data_cuda_list.append(cuda_mem)
         self._model_data_cpu_list.append(cpu_mem)
+
+    @property
+    def cuda_margin_mem(self) -> float:
+        return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
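In words: the CUDA margin is the device capacity minus the peak overall CUDA usage sampled during warmup. A toy calculation with made-up numbers:

capacity = 40 * 1024**3        # 40 GiB device, illustrative
peak_overall = 34 * 1024**3    # max(self.overall_mem_stats('cuda')) observed in warmup
cuda_margin_mem = capacity - peak_overall    # 6 GiB left over, e.g. for fp32 shards

ZeroOptimizer later claims gpu_margin_mem_ratio of this margin (see below).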
colossalai.gemini.placement_policy (new module):

import functools
from abc import ABC, abstractmethod
from time import time
from typing import Dict, List, Optional, Tuple, Type

import torch

from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2
from colossalai.tensor.chunk import Chunk, ChunkManager
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity


class PlacementPolicy(ABC):
    need_mem_stats: bool = False

    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
        self.chunk_manager = chunk_manager
        self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector

    @abstractmethod
    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
        """Returns the evicted volume in bytes and the time spent on eviction."""
        raise NotImplementedError

    @staticmethod
    def get_default_device() -> torch.device:
        return torch.device('cpu')
class CPUPlacementPolicy(PlacementPolicy):

    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
        # keep everything on CPU: evict every chunk that can be moved
        volume = 0
        for chunk in can_evict_chunks:
            self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
            volume += chunk.mem
        return volume, 0


class CUDAPlacementPolicy(PlacementPolicy):

    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
        assert torch.cuda.is_available(), 'Cannot use CUDAPlacementPolicy when CUDA is not available'
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
        # everything stays on CUDA: nothing to evict
        return 0, 0

    @staticmethod
    def get_default_device() -> torch.device:
        return get_current_device()
class AutoPlacementPolicy(PlacementPolicy):
    need_mem_stats: bool = True

    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
        # during warmup, model data may use at most (1 - _warmup_non_model_data_ratio) of CUDA memory
        # TODO(ver217): make these args configurable
        self._warmup_non_model_data_ratio: float = 0.8
        self._steady_cuda_cap_ratio: float = 0.9

    def evict_tensors(self,
                      can_evict_chunks: List[Chunk],
                      cuda_demand: int = 0,
                      warmup: bool = True,
                      compute_list: List[Tuple[Chunk, ...]] = [],
                      compute_idx: int = 0,
                      **kwargs) -> Tuple[int, float]:
        """Evict chunks from the CUDA device until the CUDA demand can be satisfied.

        Args:
            can_evict_chunks (List[Chunk]): the chunks currently on CUDA that may be evicted.
            cuda_demand (int, optional): the volume of data needed on the CUDA device. Defaults to 0.
            warmup (bool, optional): whether we are still in the warmup phase. Defaults to True.
            compute_list (List[Tuple[Chunk, ...]], optional): the recorded order in which chunks are computed. Defaults to [].
            compute_idx (int, optional): the index of the current compute step in compute_list. Defaults to 0.

        Raises:
            RuntimeError: if not enough CUDA memory can be freed to satisfy the demand.

        Returns:
            Tuple[int, float]: the volume of memory that is evicted and the time spent on eviction.
        """
        start = time()
        cuda_capacity = colo_device_memory_capacity(get_current_device())
        used_cuda_model_data = self.chunk_manager.total_mem['cuda']
        if warmup:
            # We designate a part of CUDA memory for model data in warmup iterations.
            max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
        else:
            # max non-model-data CUDA memory consumption between this sampling moment and the next one
            max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
            cuda_capacity *= self._steady_cuda_cap_ratio
        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
        freed_cuda_model_data = 0
        end = time()
        if avail_cuda_model_data < cuda_demand:
            # we must free at least (cuda_demand - avail_cuda_model_data) bytes of model data
            to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
            to_free_chunks = can_evict_chunks
            if not warmup:
                # evict the chunks whose next use lies farthest in the future first
                to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
                # print(self._sort_can_evict_chunks.cache_info())
            end = time()
            for chunk in to_free_chunks:
                if freed_cuda_model_data >= to_free_cuda_model_data:
                    break
                freed_cuda_model_data += chunk.mem
                self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
            if freed_cuda_model_data < to_free_cuda_model_data:
                raise RuntimeError(
                    f"Adjust layout failed! Not enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
                )
        return freed_cuda_model_data, end - start
    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
        # for each evictable chunk, find the next step at which it will be computed again
        next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
        for i in range(len(compute_list) - 1, compute_idx, -1):
            for chunk in compute_list[i]:
                if chunk in next_compute_idx:
                    next_compute_idx[chunk] = i
        # evict the chunks with the most distant next use first
        next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
        return [t for (t, idx) in next_compute_idx]
class PlacementPolicyFactory:
    policies: Dict[str, Type[PlacementPolicy]] = {
        'cpu': CPUPlacementPolicy,
        'cuda': CUDAPlacementPolicy,
        'auto': AutoPlacementPolicy
    }

    @staticmethod
    def create(policy_name: str) -> Type[PlacementPolicy]:
        if policy_name not in PlacementPolicyFactory.policies:
            raise TypeError(f"Unknown tensor placement policy {policy_name}")
        return PlacementPolicyFactory.policies[policy_name]

    @staticmethod
    def get_policy_names() -> Tuple[str, ...]:
        return tuple(PlacementPolicyFactory.policies.keys())

    @staticmethod
    def get_default_device(policy_name: str) -> torch.device:
        policy_cls = PlacementPolicyFactory.create(policy_name)
        return policy_cls.get_default_device()
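The eviction order is a Belady-style heuristic: among evictable chunks, those whose next use lies farthest in the recorded compute order are evicted first. A small self-contained illustration (strings stand in for Chunk objects, which works here because the helper only hashes and compares its arguments):

compute_list = (('a',), ('b', 'c'), ('a', 'c'), ('b',))    # recorded compute order
can_evict = ('a', 'b', 'c')
order = AutoPlacementPolicy._sort_can_evict_chunks(can_evict, 0, compute_list)
# after step 0, 'a' is next used at step 2, while 'b' and 'c' are next used at step 1,
# so order == ['a', 'b', 'c']: 'a' is the first eviction candidate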
In colossalai.nn.parallel:

@@ -4,8 +4,11 @@ from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
 from functools import partial
 from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
-from colossalai.tensor.chunk import ChunkManager, TensorState
+from colossalai.tensor.chunk import ChunkManager, TensorState, Chunk
 from colossalai.tensor.param_op_hook import use_param_op_hooks
+from colossalai.gemini.gemini_mgr import GeminiManager
+from typing import Dict
+from colossalai.logging import get_dist_logger


 def free_storage(data: torch.Tensor) -> None:
@@ -89,12 +92,14 @@ class ColoDDP(torch.nn.Module):

 class ColoDDPV2(ColoDDP):

-    def __init__(self, module: torch.nn.Module, chunk_manager: ChunkManager) -> None:
+    def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
         super().__init__(module)
-        self.chunk_manager = chunk_manager
-        self.param_op_hook = ZeROHookV2(chunk_manager)
+        self.gemini_manager = gemini_manager
+        self.chunk_manager = gemini_manager.chunk_manager
+        self.param_op_hook = ZeROHookV2(gemini_manager)
         self.fp32_params = []
         self.overflow_counter = 0
+        self.grads_device: Dict[torch.Tensor, torch.device] = {}
         # TODO: get param order and filter unused params
         for p in module.parameters():
             assert p.dtype == torch.half
@@ -102,22 +107,32 @@ class ColoDDPV2(ColoDDP):
             self.chunk_manager.append_tensor(p, 'fp16_param')
             self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
             self.fp32_params.append(fp32_p)
+            self.grads_device[p] = self.gemini_manager.default_device
+        self._logger = get_dist_logger()

     def forward(self, *args, **kwargs):
         self.module.zero_grad(set_to_none=True)
+        self.gemini_manager.pre_iter()
         with use_param_op_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         self.chunk_manager.exec_lazy_release()
         return outputs

-    def _post_backward(self):
-        self.chunk_manager.exec_lazy_release()
+    def _setup_grads_ptr(self):
         for p in self.module.parameters():
             if self.chunk_manager.get_chunk(p).is_free or not p.requires_grad:
                 p.grad = None
             else:
                 p.grad = p.data

+    def _post_backward(self):
+        self.chunk_manager.exec_lazy_release()
+        self._setup_grads_ptr()
+        self._logger.info(
+            f'layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, PCIE move vol: {self.gemini_manager._cpu_gpu_move_volume}B'
+        )
+        self.gemini_manager.post_iter()
+
     def backward(self, loss: torch.Tensor):
         with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
             loss.backward()
@@ -141,7 +156,12 @@ class ColoDDPV2(ColoDDP):
                 self.chunk_manager.release_chunk(chunk)
             if reduced and not chunk.is_free:
                 self.overflow_counter += chunk.has_inf_or_nan
+                self.chunk_manager.move_chunk(chunk, self.grads_device[p])
         return empty_grad

     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)
+
+    def _set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
+        for tensor in chunk.get_tensors():
+            self.grads_device[tensor] = device
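The new grads_device map decides where each parameter's reduced gradient chunk lives: it is initialized to the placement policy's default device and may be overridden per chunk via _set_chunk_grad_device. A hedged sketch of the default (this helper is illustrative, not part of the diff):

import torch

def default_grads_device(placement_policy: str) -> torch.device:
    # mirrors PlacementPolicy.get_default_device: the 'cuda' policy keeps grads
    # on the current GPU, while 'cpu' and 'auto' default to offloading them to CPU
    return torch.device('cuda') if placement_policy == 'cuda' else torch.device('cpu')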
In colossalai.tensor.chunk:

@@ -178,6 +178,9 @@ class Chunk:
     def __eq__(self, __o: object) -> bool:
         return self is __o

+    def get_tensors(self) -> List[torch.Tensor]:
+        return list(self.tensors_info.keys())
+

 class ChunkManager:
@@ -234,6 +237,10 @@ class ChunkManager:
     def access_chunk(self, chunk: Chunk) -> None:
         if chunk in self.accessed_chunks:
+            if chunk.device_type != 'cuda':
+                self.total_mem[chunk.device_type] -= chunk.mem
+                chunk.move_device(get_current_device())
+                self.total_mem[chunk.device_type] += chunk.mem
             return
         if not chunk.is_free:
             self.total_mem[chunk.device_type] -= chunk.mem
In colossalai.zero.utils.zero_hook_v2:

@@ -5,6 +5,7 @@ from enum import Enum
 from typing import List
 from contextlib import contextmanager
 from functools import partial
+from colossalai.gemini.gemini_mgr import GeminiManager


 class TrainingPhase(Enum):
@@ -14,9 +15,10 @@ class TrainingPhase(Enum):

 class ZeROHookV2(ParamOpHook):

-    def __init__(self, chunk_manager: ChunkManager) -> None:
+    def __init__(self, gemini_manager: GeminiManager) -> None:
         super().__init__()
-        self._chunk_manager = chunk_manager
+        self._gemini_manager = gemini_manager
+        self._chunk_manager = gemini_manager.chunk_manager
         self._training_phase = TrainingPhase.FORWARD

     def pre_op(self, params):
@@ -24,9 +26,11 @@ class ZeROHookV2(ParamOpHook):
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._chunk_manager.exec_lazy_release()
-        # TODO: evict chunks
+        self._gemini_manager.sample_overall_data()
+        self._gemini_manager.adjust_layout(chunks, 'fp16_param')
         for chunk in chunks:
             self._chunk_manager.access_chunk(chunk)
+        self._gemini_manager.sample_model_data()

     def post_op(self, params):
         for p in params:
In colossalai.zero (ZeroOptimizer):

@@ -7,6 +7,7 @@ from typing import Dict
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.utils import get_current_device, disposable


 class OptimState(Enum):
@@ -19,6 +20,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
     def __init__(self,
                  optim: Optimizer,
                  module: ColoDDPV2,
+                 gpu_margin_mem_ratio: float = 0.0,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
                  growth_factor: float = 2,
@@ -29,6 +31,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
         super().__init__(optim)
         assert isinstance(module, ColoDDPV2)
         self.module = module
+        self.gemini_manager = module.gemini_manager
+        self.chunk_manager = self.gemini_manager.chunk_manager
         self.optim_state = OptimState.UNSCALED
         self.fp16_param_to_fp32_param: Dict[torch.Tensor, torch.Tensor] = {}
         for p, fp32_p in zip(module.parameters(), module.fp32_params):
@@ -45,6 +49,18 @@ class ZeroOptimizer(ColossalaiOptimizer):
         self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device())
         self._logger = get_dist_logger()

+        self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
+        assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, 'gpu_margin_mem_ratio must be in [0.0, 1.0]'
+        # Only move fp32 shards from CPU to GPU when the user allows it and the inner optimizer is valid.
+        # The inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
+        # and it must set `num_fp32_shards_per_param` correctly.
+        self._should_move_fp32_params_h2d: bool = self.gemini_manager.is_cuda_margin_mem_avail and self.gpu_margin_mem_ratio > 0.0 and getattr(
+            optim, 'num_fp32_shards_per_param', 0) >= 2
+        if self.gpu_margin_mem_ratio > 0.0 and not self.gemini_manager.is_cuda_margin_mem_avail:
+            self._logger.warning('gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])
+
+        self._register_states = disposable(self._register_states_)
+
     def _update_params_ptr(self):
         for group in self.optim.param_groups:
             for p in group['params']:
@@ -82,6 +98,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
         return self.optim.zero_grad(set_to_none=True)

     def step(self, *args, **kwargs):
+        self._maybe_move_fp32_params()
         # unscale grads if scaled
         if self.optim_state == OptimState.SCALED:
             self._unscale_grads()
@@ -94,6 +111,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
             return
         self._update_params_ptr()
         ret = self.optim.step(*args, **kwargs)
+        self._register_states()
         self._update_fp16_params()
         return ret
@@ -109,3 +127,29 @@ class ZeroOptimizer(ColossalaiOptimizer):

     def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
         self.module.backward_by_grad(tensor, grad)
+
+    def _maybe_move_fp32_params(self):
+        if self._should_move_fp32_params_h2d:
+            self._should_move_fp32_params_h2d = False
+            available_cuda_margin_mem = self.gemini_manager.cuda_margin_mem * self.gpu_margin_mem_ratio
+            fp32_params_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
+            fp32_params_used_cuda_margin_mem = 0
+            for fp16_param_chunk, fp32_param_chunk in zip(self.chunk_manager.chunk_groups['fp16_param'],
+                                                          self.chunk_manager.chunk_groups['fp32_param']):
+                if fp32_param_chunk.is_free:
+                    continue
+                if fp32_params_used_cuda_margin_mem + fp32_param_chunk.mem < fp32_params_available_cuda_margin_mem:
+                    self.chunk_manager.move_chunk(fp32_param_chunk, get_current_device())
+                    # the fp16 chunk stores the grad now
+                    self.chunk_manager.move_chunk(fp16_param_chunk, get_current_device())
+                    self.module._set_chunk_grad_device(fp16_param_chunk, get_current_device())
+                    fp32_params_used_cuda_margin_mem += fp32_param_chunk.mem
+            self.module._setup_grads_ptr()
+
+    def _register_states_(self):
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                state = self.optim.state[p]
+                for val in state.values():
+                    if isinstance(val, torch.Tensor):
+                        self.chunk_manager.add_extern_static_tensor(val)
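To make the margin arithmetic in _maybe_move_fp32_params concrete, a toy example with made-up numbers:

cuda_margin_mem = 6 * 1024**3     # reported by GeminiManager after warmup
gpu_margin_mem_ratio = 0.4        # user-chosen fraction of the margin to spend
num_fp32_shards_per_param = 2     # reported by the inner optimizer
budget = cuda_margin_mem * gpu_margin_mem_ratio / num_fp32_shards_per_param  # ~1.2 GiB
# fp32 param chunks are moved to CUDA greedily while their total size stays below
# `budget`; the paired fp16 chunks (which hold the grads) and grads_device follow.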
In the ColoDDPV2 GPT test:

@@ -14,6 +14,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 from colossalai.nn.parallel import ColoDDPV2
 from colossalai.testing import parameterize
+from colossalai.gemini.gemini_mgr import GeminiManager


 def check_param_equal(model, torch_model):
@@ -44,7 +45,8 @@ def run_gpt(use_chunk, use_zero):
     model = model.half()
     chunk_size = 38 * 1024**2 if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
-    model = ColoDDPV2(model, chunk_manager)
+    gemini_manager = GeminiManager('cuda', chunk_manager)
+    model = ColoDDPV2(model, gemini_manager)
     torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
     print(chunk_manager)
     check_param_equal(model, torch_model)
In the ZeroOptimizer GPT test:

@@ -18,6 +18,7 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
+from colossalai.gemini.gemini_mgr import GeminiManager


 def check_param_equal(model, torch_model):
@@ -53,7 +54,8 @@ def run_gpt(use_chunk, use_zero):
     chunk_size = 38 * 1024**2 if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
-    model = ColoDDPV2(model, chunk_manager)
+    gemini_manager = GeminiManager('cuda', chunk_manager)
+    model = ColoDDPV2(model, gemini_manager)
     optim = HybridAdam(model.parameters(), lr=1e-3)
     optim = ZeroOptimizer(optim, model, initial_scale=32)