Commit e532679c authored by oahzxl's avatar oahzxl
Browse files

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

parents c1492e50 7d5640b9
import torch
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
from colossalai.utils import get_current_device
import torch
from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState
from colossalai.tensor import ColoTensor
from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk
from colossalai.utils import get_current_device
class ChunkManager:
......@@ -16,13 +17,13 @@ class ChunkManager:
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
"""
def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None:
def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
self.device = init_device or get_current_device()
self.size_config: Dict[int, int] = dict()
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
self.kwargs_config = chunk_configuration
for k, v in self.kwargs_config.items():
self.size_config[k] = v.pop('chunk_size')
self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
v['init_device'] = self.device
self.chunk_groups: Dict[str, Deque] = dict()
......@@ -31,20 +32,28 @@ class ChunkManager:
self.accessed_mem: int = 0
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None:
"""Append a tensor to a chunk.
def register_tensor(self,
tensor: ColoTensor,
group_type: str,
config_key: int,
cpu_offload: bool = False,
pin_memory: bool = False) -> None:
"""
Register a tensor to the chunk manager.
Then, the tensor should be accessed by `get_chunks`.
Args:
tensor: the tensor appended to the chunk
group_type: the data type of the group
config_key: the key of the group's name, usually the size of the dp world
group_type: the data type of the group.
config_key: the key of the group's name, the size of the dp world
cpu_offload: if True, the chunk will be closed on CPU
pin_memory: whether the chunk is pinned in the cpu memory
"""
assert tensor not in self.tensor_chunk_map
assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
assert config_key in self.size_config
assert config_key in self.dp_degree_chunk_size_dict
chunk_size = self.size_config[config_key]
chunk_size = self.dp_degree_chunk_size_dict[config_key]
chunk_kwargs = self.kwargs_config[config_key]
group_name = "{}_{}".format(group_type, config_key)
chunk_group = self.__get_chunk_group(group_name)
......@@ -67,6 +76,7 @@ class ChunkManager:
chunk_size=chunk_size,
process_group=tensor.process_group,
dtype=tensor.dtype,
cpu_shard_init=cpu_offload,
pin_memory=pin_memory,
**chunk_kwargs,
)
......@@ -206,9 +216,8 @@ class ChunkManager:
return self.chunk_groups[group_name]
def __close_one_chunk(self, chunk: Chunk):
device = get_current_device() if chunk.keep_gathered else self.device # keep gathered chunk in cuda
self.__sub_memroy_usage(chunk.memory_usage)
chunk.close_chunk(device)
chunk.close_chunk()
self.__add_memory_usage(chunk.memory_usage)
def __sub_memroy_usage(self, usage: Dict[str, int]):
......
import math
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch.nn as nn
from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
from colossalai.tensor import ColoParameter
......@@ -12,7 +13,8 @@ def in_ddp(param: nn.Parameter) -> bool:
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
"""Filter those parameters whose size is too large from others.
"""
Filter those parameters whose size is too large (more than 3x standard deviations) from others.
"""
params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
params_size_arr = np.array(params_size)
......@@ -39,11 +41,20 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
return left + acc
def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
"""Clasify each parameter by its size of DP group.
def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int, List[ColoParameter]]:
"""classify_params_by_dp_degree
Classify the parameters by their dp degree
Args:
param_order (OrderedParamGenerator): the order of param be visied
Returns:
Dict[int, List[ColoParameter]]: a dict contains the classification results.
The keys are dp_degrees and the values are parameters.
"""
params_dict: Dict[int, List[ColoParameter]] = dict()
for param in model.parameters():
for param in param_order.generate():
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
if not in_ddp(param):
continue
......@@ -62,24 +73,45 @@ def search_chunk_configuration(
search_range_mb: float,
search_interval_byte: int, # hidden size is the best value for the interval
min_chunk_size_mb: float = 32,
filter_exlarge_params: bool = True) -> Tuple[Dict, int]:
filter_exlarge_params: bool = True,
memstas: Optional[MemStats] = None) -> Tuple[Dict, int]:
"""search_chunk_configuration
Args:
model (nn.Module): torch module
search_range_mb (float): searching range in mega byte.
search_interval_byte (int): searching interval in byte.
filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
Returns:
Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte.
"""
if memstas is not None:
param_order = memstas.param_order()
else:
# build the param visited order right now
param_order = OrderedParamGenerator()
for p in model.parameters():
param_order.append(p)
search_range_byte = round(search_range_mb * 1024**2)
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
assert search_range_byte >= 0
params_dict = clasify_params(model)
params_dict = classify_params_by_dp_degree(param_order)
config_dict: Dict[int, Dict] = dict()
size_dict: Dict[int, List[int]] = dict()
for key in params_dict:
params_list = params_dict[key]
for dp_degree in params_dict:
params_list = params_dict[dp_degree]
size_list = [p.numel() for p in params_list]
# let small parameters keep gathered in CUDA all the time
total_size = sum(size_list)
if total_size < min_chunk_size_byte:
config_dict[key] = dict(chunk_size=total_size, keep_gathered=True)
config_dict[dp_degree] = dict(chunk_size=total_size, keep_gathered=True)
else:
size_dict[key] = size_list
size_dict[dp_degree] = size_list
if filter_exlarge_params:
_filter_exlarge_params(model, size_dict)
......@@ -100,9 +132,9 @@ def search_chunk_configuration(
min_chunk_waste = temp_waste
best_chunk_size = chunk_size
for key in params_dict:
if key in config_dict:
for dp_degree in params_dict:
if dp_degree in config_dict:
continue
config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False)
config_dict[dp_degree] = dict(chunk_size=best_chunk_size, keep_gathered=False)
return config_dict, min_chunk_waste
......@@ -7,6 +7,7 @@ import torch.nn as nn
from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
from colossalai.gemini.memory_tracer import MemStats
def init_chunk_manager(model: nn.Module,
......@@ -37,13 +38,13 @@ def init_chunk_manager(model: nn.Module,
total_size = sum(params_sizes) / 1024**2
dist.barrier()
begine = time()
begin = time()
config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)
dist.barrier()
end = time()
span_s = end - begine
span_s = end - begin
wasted_size /= 1024**2
if dist.get_rank() == 0:
......
import torch
import functools
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from typing import List, Optional, Tuple
from time import time
from typing import List, Optional, Tuple
import torch
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.gemini.memory_tracer import MemStats
from .memory_tracer import ChunkMemStatsCollector
from .placement_policy import PlacementPolicyFactory
......@@ -21,13 +25,20 @@ class GeminiManager:
If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
Note that 'auto' policy can only work well when no other processes use CUDA during your training.
chunk_manager (ChunkManager): A ``ChunkManager`` instance.
memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration.
"""
def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
assert placement_policy in PlacementPolicyFactory.get_policy_names()
self.policy_name = placement_policy
policy_cls = PlacementPolicyFactory.create(placement_policy)
self._chunk_manager = chunk_manager
self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
self._premade_memstats_ = memstats is not None
self._memstats = memstats
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
self._memstats) if policy_cls.need_mem_stats else None
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
self._compute_list: List[Tuple[Chunk, ...]] = []
self._compute_idx: int = -1
......@@ -39,7 +50,20 @@ class GeminiManager:
self._warmup = True
self._comp_cuda_demand_time = 0
def pre_iter(self):
def memstats(self):
"""memstats
get the memory statistics during training.
The stats could be collected by a runtime memory tracer, or collected by the GeminiManager.
Note, for the latter, you can not access the memstats before warmup iteration finishes.
"""
if self._premade_memstats_:
return self._memstats
else:
assert not self._warmup, "Gemini Manager has memstats after warm up! Now is during warmup."
return self._mem_stats_collector._memstats
def pre_iter(self, *args):
if self._mem_stats_collector and self._warmup:
self._mem_stats_collector.start_collection()
......@@ -57,7 +81,7 @@ class GeminiManager:
self._comp_cuda_demand_time = 0
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
""" Adjust the layout of statefuil tensor according to the information provided
""" Adjust the layout of stateful tensors according to the information provided
by mem_stats_collector, which should belongs to a Sharded Model.
"""
# find stateful tensor in state COMPUTE
......@@ -109,9 +133,9 @@ class GeminiManager:
if self._mem_stats_collector:
self._mem_stats_collector.sample_overall_data()
def sample_model_data(self):
def record_model_data_volume(self):
if self._mem_stats_collector:
self._mem_stats_collector.sample_model_data()
self._mem_stats_collector.record_model_data_volume()
@property
def chunk_manager(self):
......
from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
from .memstats_collector import MemStatsCollector
from .param_runtime_order import OrderedParamGenerator # isort:skip
from .memory_stats import MemStats # isort:skip
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
from .memstats_collector import MemStatsCollector # isort:skip
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
from .static_memstats_collector import StaticMemStatsCollector # isort:skip
__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER']
__all__ = [
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
'StaticMemStatsCollector', 'MemStats', 'OrderedParamGenerator'
]
from typing import Optional
from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.memory_tracer import MemStats
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from .memstats_collector import MemStatsCollector
class ChunkMemStatsCollector(MemStatsCollector):
def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
"""
Memory Statistic Collector for Chunks.
Args:
chunk_manager (ChunkManager): the chunk manager.
memstats (Optional[MemStats], optional): memory statistics collected by RMT. Defaults to None.
"""
super().__init__(memstats)
self._chunk_manager = chunk_manager
# override
def record_model_data_volume(self) -> None:
"""
record model data volumn on cuda and cpu.
"""
if self._start_flag and not self.use_outside_memstats:
cuda_mem = self._chunk_manager.total_mem['cuda']
self._memstats.record_max_cuda_model_data(cuda_mem)
@property
def cuda_margin_mem(self) -> float:
return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda
import json
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from time import sleep, time
import json
import torch
from colossalai.utils import colo_device_memory_used
from colossalai.utils import get_current_device
from colossalai.utils import colo_device_memory_used, get_current_device
class MemoryMonitor:
......@@ -134,7 +133,13 @@ class SyncCudaMemoryMonitor(MemoryMonitor):
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
def finish(self):
def finish(self) -> int:
"""
return max gpu memory used since latest `start()`.
Returns:
int: max GPU memory
"""
torch.cuda.synchronize()
self.time_stamps.append(time())
max_usage = torch.cuda.max_memory_allocated()
......
from typing import Any, Dict, List, Optional
import torch
from colossalai.gemini.memory_tracer import OrderedParamGenerator
class MemStats(object):
def __init__(self) -> None:
"""
Store the non model data statistics used for Gemini and ZeroOptimizer.
"""
# (preop_step, List[param])
self._step_param_dict = dict()
# (param, List[preop_step])
self._param_step_dict = dict()
# (preop_step, non_model_data) non model data used during preop_step ~ (preop_step+1)
self._step_nmd_dict = dict()
self._param_runtime_order = OrderedParamGenerator()
self._preop_step = 0
self._prev_overall_cuda = -1
self._max_overall_cuda = 0
self._prev_md_cuda = -1
# old version
self._model_data_cuda_list = []
self._model_data_cpu_list = []
self._overall_cuda_list = []
self._overall_cpu_list = []
self._non_model_data_cuda_list = []
self._non_model_data_cpu_list = []
def calc_max_cuda_non_model_data(self):
if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1:
max_cuda_non_model_data = self._prev_overall_cuda - self._prev_md_cuda
self._step_nmd_dict[self._preop_step - 1] = max_cuda_non_model_data
# compatibility of the old version.
self._non_model_data_cuda_list.append(max_cuda_non_model_data)
def record_max_cuda_model_data(self, val):
self._prev_md_cuda = val
def record_max_cuda_overall_data(self, val):
self._prev_overall_cuda = val
self._max_overall_cuda = max(self._max_overall_cuda, val)
@property
def max_overall_cuda(self):
return self._max_overall_cuda
def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
"""
the time step is increased. param list is used between current and the next
time step.
Args:
param_list (List[torch.nn.Parameter]): a list of torch paramters.
"""
for p in param_list:
if p not in self._param_step_dict:
self._param_step_dict[p] = [self._preop_step]
else:
self._param_step_dict[p].append(self._preop_step)
self._param_runtime_order.append(p)
self._step_param_dict[self._preop_step] = param_list
self._preop_step += 1
def param_used_step(self, param: torch.nn.Parameter) -> Optional[List[int]]:
"""param_used_step
get the timestep list using the param
Args:
param (torch.nn.Parameter): a torch param
Returns:
Optional[List[int]]: a list of int indicates the time step of preop hook.
"""
if param not in self._param_step_dict:
return None
else:
return self._param_step_dict[param]
def param_order(self):
if self._param_runtime_order.is_empty():
raise RuntimeError
else:
return self._param_runtime_order
def non_model_data_list(self, device_type: str) -> List[int]:
if device_type == 'cuda':
return self._non_model_data_cuda_list
elif device_type == 'cpu':
return self._non_model_data_cpu_list
else:
raise TypeError
def max_non_model_data(self, device_type: str) -> float:
if device_type == 'cuda':
return max(self._non_model_data_cuda_list)
elif device_type == 'cpu':
return max(self._non_model_data_cpu_list)
else:
raise TypeError
def clear(self):
self._model_data_cuda_list = []
self._overall_cuda_list = []
self._model_data_cpu_list = []
self._overall_cpu_list = []
self._non_model_data_cpu_list = []
self._non_model_data_cuda_list = []
self._param_runtime_order.clear()
self._step_param_dict.clear()
self._param_step_dict.clear()
self._step_nmd_dict.clear()
self._preop_step = 0
self._prev_overall_cuda = -1
self._prev_md_cuda = -1
import time
from typing import List, Optional
import torch
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity
from colossalai.utils import get_current_device
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini.chunk import ChunkManager
from colossalai.utils.memory import colo_device_memory_used
import torch
import time
from typing import List
from .memory_stats import MemStats
class MemStatsCollector:
......@@ -21,48 +22,22 @@ class MemStatsCollector:
It has a Sampling counter which is reset after DNN training iteration.
"""
def __init__(self) -> None:
def __init__(self, memstats: Optional[MemStats] = None) -> None:
self._mem_monitor = SyncCudaMemoryMonitor()
self._model_data_cuda_list = []
self._overall_cuda_list = []
self._model_data_cpu_list = []
self._overall_cpu_list = []
self._non_model_data_cuda_list = []
self._non_model_data_cpu_list = []
self._sampling_time = []
self._start_flag = False
self._step_idx = 0
self._step_total = 0
def overall_mem_stats(self, device_type: str) -> List[int]:
if device_type == 'cuda':
return self._overall_cuda_list
elif device_type == 'cpu':
return self._overall_cpu_list
else:
raise TypeError
def model_data_list(self, device_type: str) -> List[int]:
if device_type == 'cuda':
return self._model_data_cuda_list
elif device_type == 'cpu':
return self._model_data_cpu_list
else:
raise TypeError
def non_model_data_list(self, device_type: str) -> List[int]:
if device_type == 'cuda':
return self._non_model_data_cuda_list
elif device_type == 'cpu':
return self._non_model_data_cpu_list
if memstats is not None:
self.use_outside_memstats = True
self._memstats = memstats
else:
raise TypeError
self.use_outside_memstats = False
self._memstats = MemStats()
def next_period_non_model_data_usage(self, device_type: str) -> int:
"""Get max non model data memory usage of current sampling period
"""Maximum non model data memory usage during the next Op run
Args:
device_type (str): device type, can be 'cpu' or 'cuda'.
......@@ -72,7 +47,10 @@ class MemStatsCollector:
"""
assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
next_non_model_data = self.non_model_data_list(device_type)[self._step_idx]
assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \
f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\
f"step total {self._step_total}"
next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
self._step_idx = (self._step_idx + 1) % self._step_total
return next_non_model_data
......@@ -86,67 +64,37 @@ class MemStatsCollector:
def finish_collection(self):
self.sample_overall_data()
self._step_total = len(self._sampling_time)
# self._step_total = len(self._sampling_time)
self._step_total = len(self._memstats.non_model_data_list('cuda'))
self._start_flag = False
self._mem_monitor.finish()
print(f'finish_collection {self._step_total}')
def sample_model_data(self) -> None:
"""Sampling model data statistics.
# deprecated
def record_model_data_volume(self) -> None:
"""
if self._start_flag:
Sampling model data statistics.
"""
if self._start_flag and not self.use_outside_memstats:
# The following code work for ZeroInitContext, which is deprecated in v0.1.12
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
self._model_data_cuda_list.append(cuda_mem)
self._model_data_cpu_list.append(cpu_mem)
self._memstats.record_max_cuda_model_data(cuda_mem)
def sample_overall_data(self) -> None:
"""Sampling non model data statistics.
"""
if self._start_flag:
# overall data recording is after model data recording
if len(self._model_data_cuda_list) == 0:
return
self._overall_cuda_list.append(self._mem_monitor.finish())
self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu')))
Sampling overall and non model data cuda memory statistics.
"""
if self._start_flag and not self.use_outside_memstats:
cuda_overall = self._mem_monitor.finish()
self._memstats.record_max_cuda_overall_data(cuda_overall)
self._memstats.calc_max_cuda_non_model_data()
assert len(self._model_data_cuda_list) == len(self._overall_cuda_list)
self._mem_monitor.start()
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
if self._start_flag:
self._sampling_time.append(time.time())
self._mem_monitor.start()
def clear(self) -> None:
self._model_data_cuda_list = []
self._overall_cuda_list = []
self._model_data_cpu_list = []
self._overall_cpu_list = []
self._non_model_data_cpu_list = []
self._non_model_data_cuda_list = []
self._memstats.clear()
self._start_flag = False
self._step_idx = 0
self._step_total = 0
class MemStatsCollectorV2(MemStatsCollector):
def __init__(self, chunk_manager: ChunkManager) -> None:
super().__init__()
self._chunk_manager = chunk_manager
def sample_model_data(self) -> None:
"""Sampling model data statistics.
"""
if self._start_flag:
cuda_mem = self._chunk_manager.total_mem['cuda']
cpu_mem = self._chunk_manager.total_mem['cpu']
self._model_data_cuda_list.append(cuda_mem)
self._model_data_cpu_list.append(cpu_mem)
@property
def cuda_margin_mem(self) -> float:
return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
from abc import ABC
import torch
class ParamGenerator(ABC):
def append(self, param: torch.nn.Parameter):
pass
def generate(self):
pass
def clear(self):
pass
class OrderedParamGenerator(ParamGenerator):
"""OrderedParamGenerator
Contain the order of parameters visited during runtime.
"""
def __init__(self) -> None:
self.param_visited_order = []
def append(self, param: torch.nn.Parameter):
self.param_visited_order.append(param)
def generate(self):
visited_set = set()
for p in self.param_visited_order:
if p not in visited_set:
yield p
visited_set.add(p)
del visited_set
def is_empty(self):
return len(self.param_visited_order) == 0
def clear(self):
self.param_visited_order = []
import torch.nn
from colossalai.gemini.memory_tracer import MemStats
from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
__all__ = ['RuntimeMemTracer']
class RuntimeMemTracer():
"""RuntimeMemTracer for the module training using ColoParameter.
Trace non-model memory usage during fwd+bwd process.
It is obtained by using a tensor with the same shape as the training process as the inputs
and running an single fwd+bwd to trace the statistics.
NOTE()
1. The premise to use this tracer is that the target DNN execute the same operations at each iterations,
2. Module buffers are viewed as non-model data.
"""
def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
super().__init__()
self.module = module
self.dtype = dtype
self._gradstat = GradMemStats()
self._memstats = MemStats()
self.param_op_hook = ParamMemTracerHook(self._memstats, self._gradstat)
self.grad_hook = GradMemTracerHook(self._gradstat)
self.cpu_param_data_dict = {}
for p in module.parameters():
p.data = p.data.to(dtype)
self._cast_buffers_to_cuda_dtype()
def parameters_in_runtime_order(self):
return self._memstats._param_runtime_order.generate()
def memstats(self):
return self._memstats
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def _backup_params(self):
"""
The function is called before forward. Backup model params on cpu.
"""
for p in self.module.parameters():
self.cpu_param_data_dict[p] = torch.empty(p.data.shape, dtype=self.dtype, device="cpu")
self.cpu_param_data_dict[p].copy_(p.data)
def _restore_params(self):
"""
This function is called after backward. Restore model params.
"""
for p in self.module.parameters():
p.data = torch.empty(p.data.shape, dtype=self.dtype, device="cpu", requires_grad=p.data.requires_grad)
p.data.copy_(self.cpu_param_data_dict[p])
self.cpu_param_data_dict.clear()
def _pre_forward(self):
self._clear_cuda_mem_info()
self._backup_params()
self.grad_hook.register_grad_hook(self.module)
self.param_op_hook.mem_monitor.start()
def forward(self, *args, **kwargs):
args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype)
self.module.zero_grad(set_to_none=True)
self._pre_forward()
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
return outputs
def backward(self, loss):
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
loss.backward()
self._post_backward()
def _post_backward(self):
cuda_volume = self.param_op_hook.mem_monitor.finish()
self._memstats.record_max_cuda_overall_data(cuda_volume)
# calc the last Op non model data
self._memstats.calc_max_cuda_non_model_data()
self.grad_hook.remove_grad_hook()
self._restore_params()
def _clear_cuda_mem_info(self):
self._memstats.clear()
self._gradstat.clear()
def _cast_buffers_to_cuda_dtype(self):
for buffer in self.module.buffers():
buffer.data = buffer.cuda()
if torch.is_floating_point(buffer):
buffer.data = buffer.data.to(self.dtype)
from typing import Optional
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta
from colossalai.gemini.chunk import ChunkManager
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
from .chunk_memstats_collector import ChunkMemStatsCollector
class ModuleInfos:
def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str,
parent_module: torch.nn.Module):
self.module = module
self.module_name = module_name
self.module_full_name = module_full_name
self.parent_module = parent_module
class StaticMemStatsCollector(ChunkMemStatsCollector):
"""
A Static Memory statistic collector.
"""
def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
super().__init__(chunk_manager)
self.module = module
self.module_info_list = []
def init_mem_stats(self, *inputs):
self.register_opnodes_recursively(self.module)
self.refactor_module()
self.module = self.module.cpu()
self.module.train()
data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
gm = symbolic_trace(self.module)
interp = MetaInfoProp(gm)
interp.propagate(*data)
total_mem = 0
for inp in inputs:
total_mem += inp.numel() * inp.element_size()
last_node = None
module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
for node in gm.graph.nodes:
total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
if node.op == "call_module":
if node.name.endswith("_0") and node.name[:-2] in module_name_list:
self._non_model_data_cuda_list.append(total_mem)
last_node = node
self._non_model_data_cuda_list.append(total_mem)
self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]
cur_module_mem_fwd = 0
cur_module_mem_bwd = 0
grad_module_out = last_node.meta["fwd_mem_out"]
for node in gm.graph.nodes.__reversed__():
cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
if node.op == "call_module":
if node.name.endswith("_0") and node.name[:-2] in module_name_list:
self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
total_mem = total_mem - cur_module_mem_fwd
cur_module_mem_fwd = 0
cur_module_mem_bwd = 0
grad_module_out = node.meta["bwd_mem_out"]
self._step_total = len(self._non_model_data_cuda_list)
self.recover_module()
def refactor_module(self):
for modInfo in self.module_info_list:
temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)
def recover_module(self):
for modInfo in self.module_info_list:
modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
def register_opnodes_recursively(self,
module: torch.nn.Module,
name: str = "",
full_name: str = "",
parent_module: Optional[torch.nn.Module] = None):
assert isinstance(module, torch.nn.Module)
for child_name, child in module.named_children():
self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)
# Early return on modules with no parameters.
if len(list(module.parameters(recurse=False))) == 0:
return
self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
from colossalai.context.singleton_meta import SingletonMeta
from typing import Optional, Tuple
import torch
from typing import Tuple, Optional
from colossalai.logging import DistributedLogger
def colo_model_optimizer_usage(optim) -> Tuple[int, int]:
......@@ -58,52 +57,3 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
cpu_mem_usage += t_cpu
return cuda_mem_usage, cpu_mem_usage
class ModelDataTracer(metaclass=SingletonMeta):
"""
A tracer singleton to trace model data usage during runtime.
You have to register a model on the singleton first.
"""
def __init__(self) -> None:
self._logger = DistributedLogger("ModelDataTracer")
self._model = None
self._opitimizer = None
def _get_mem_usage(self) -> Tuple[int, int]:
"""
get the memory usage of the model registered.
Returns:
Tuple[int, int]: cuda, cpu mem usage
"""
cuda_use_opt, cpu_use_opt = colo_model_optimizer_usage(self._opitimizer)
cuda_use_model, cpu_use_model = colo_model_mem_usage(self._model)
return cuda_use_opt + cuda_use_model, cpu_use_opt + cpu_use_model
def register_model(self, model) -> None:
if self._model is not None:
self._logger.warning("ModelDataTracer has already registered a model")
self._model = model
def register_optimizer(self, optimizer) -> None:
if self._opitimizer is not None:
self._logger.warning("ModelDataTracer has already registered an optimizer")
self._opitimizer = optimizer
@property
def cpu_usage(self):
_, cpu_usage = self._get_mem_usage()
return cpu_usage
@property
def cuda_usage(self):
cuda_usage, _ = self._get_mem_usage()
return cuda_usage
@property
def both_mem_usage(self):
return self._get_mem_usage()
GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
from .utils import register_ophooks_recursively, BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
from .utils import BaseOpHook, register_ophooks_recursively
__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]
__all__ = ["BaseOpHook", "register_ophooks_recursively"]
import json
import pickle
from pathlib import Path
from colossalai.context.parallel_mode import ParallelMode
import torch
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.registry import OPHOOKS
from colossalai.logging import get_dist_logger
from colossalai.core import global_context as gpc
from typing import Union
import math
@OPHOOKS.register_module
class MemTracerOpHook(BaseOpHook):
"""
Collect GPU memory usage information
Args:
warmup (int): This parameter indicates how many iterations to truncate before profiling, defaults to 50.
refreshrate (int): This parameter decides the frequency of write file, defaults to 10.
data_prefix (string): The prefix of the stats data file, defaults to "memstats".
"""
def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
from colossalai.gemini.memory_tracer import AsyncMemoryMonitor
super().__init__()
self.async_mem_monitor = AsyncMemoryMonitor()
self._curiter = 0
self._logger = get_dist_logger()
self._count = 0
self._warmup = warmup
self._refreshrate = refreshrate
self._data_prefix = data_prefix
# in distributed environment
if gpc.is_initialized(ParallelMode.GLOBAL):
self._rank = gpc.get_global_rank()
else:
self._rank = 0
def _isvalid(self, module) -> bool:
assert isinstance(module, torch.nn.Module)
return module.training
def _resample(self):
# calculate the average iteration time
total_time = (self.async_mem_monitor.time_stamps[-1] - self.async_mem_monitor.time_stamps[0])
avg_it_time = total_time / self.warmup
self._logger.debug(f"total time for {self.warmup} iterations is {total_time}s")
# adjust the sampling power
power: int = round(-math.log(avg_it_time, 10)) + 1
self._logger.debug(f"the power is {power}")
self.async_mem_monitor.set_interval(power)
@property
def refreshrate(self) -> int:
return self._refreshrate
@property
def warmup(self) -> int:
return self._warmup
@property
def curiter(self) -> int:
return self._curiter
@property
def valid_iter(self) -> int:
return self.curiter - self.warmup
def pre_fwd_exec(self, module: torch.nn.Module, *args):
if self._isvalid(module):
self.async_mem_monitor.finish()
self.async_mem_monitor.start()
def post_fwd_exec(self, module: torch.nn.Module, *args):
if self._isvalid(module):
self.async_mem_monitor.finish()
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
if self._isvalid(module):
self.async_mem_monitor.finish()
self.async_mem_monitor.start()
def post_bwd_exec(self, module: torch.nn.Module, input):
if self._isvalid(module):
self.async_mem_monitor.finish()
def pre_iter(self):
pass
def post_iter(self):
self.async_mem_monitor.finish()
# in the warmup stage
if self.curiter < self.warmup:
pass
# adjust the sampling rate
elif self.curiter == self.warmup:
# use adaptive sample rate
self._resample()
# record data to log file
else:
# every `refreshrate` times, refresh the file
if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
# output file info
self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl")
home_dir = Path.home()
with open(home_dir.joinpath(f".cache/colossal/mem-{self._rank}.pkl"), "wb") as f:
pickle.dump(self.async_mem_monitor.state_dict, f)
self._count += 1
self._logger.debug(f"data file has been refreshed {self._count} times")
# finish a iteration
self._curiter += 1
def save_results(self, data_file: Union[str, Path]):
with open(data_file, "w") as f:
f.write(json.dumps(self.async_mem_monitor.state_dict))
import torch
from colossalai.registry import OPHOOKS
from . import BaseOpHook
@OPHOOKS.register_module
class ShardGradHook(BaseOpHook):
class ShardGradMemTracerHook(BaseOpHook):
"""
A hook to process sharded param before and afther FWD and BWD operator executing.
"""
......
from contextlib import contextmanager
from enum import Enum
from functools import partial
from typing import List
import torch
from colossalai.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
from colossalai.gemini.tensor_utils import alloc_storage, free_storage
from colossalai.tensor.param_op_hook import ColoParamOpHook
class TrainingPhase(Enum):
FORWARD = 0
BACKWARD = 1
class GradMemStats():
def __init__(self) -> None:
self.unreleased_grad_flag = {}
self.unreleased_grad_volume = 0
def clear(self):
self.unreleased_grad_flag.clear()
self.unreleased_grad_volume = 0
class GradMemTracerHook():
def __init__(self, grad_stats: GradMemStats):
self.grad_hook_list = []
self._grad_stats = grad_stats
def grad_handle(self, p, grad):
assert self._grad_stats.unreleased_grad_flag[p]
free_storage(grad)
self._grad_stats.unreleased_grad_volume -= grad.numel() * grad.element_size()
self._grad_stats.unreleased_grad_flag[p] = False
def register_grad_hook(self, module: torch.nn.Module):
for p in module.parameters():
if p.requires_grad:
self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
self._grad_stats.unreleased_grad_flag[p] = False
def remove_grad_hook(self):
for hook in self.grad_hook_list:
hook.remove()
class ParamMemTracerHook(ColoParamOpHook):
def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None:
super().__init__()
self._training_phase = TrainingPhase.FORWARD
self._memstats = memstats
self._grad_stats = gradstats
self.mem_monitor = SyncCudaMemoryMonitor()
def _free_cuda_params(self, params):
for p in params:
if p.data.device.type == "cpu":
raise NotImplementedError("Only free cuda memory")
free_storage(p.data)
def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]):
"""
move params to cuda
Args:
params (List[torch.nn.Parameter]): target params
Raises:
NotImplementedError: raise error when param has cpu grad
"""
for p in params:
cur_dev = p.data.device.type
if cur_dev == "cpu":
if p.grad is not None and p.grad.device.type == "cpu":
raise NotImplementedError("Only run in forward propagation")
p.data = torch.empty(p.data.shape,
device="cuda",
dtype=p.data.dtype,
requires_grad=p.data.requires_grad)
elif cur_dev == "cuda":
alloc_storage(p.data)
def record_model_data_volume(self, params):
"""
get cuda model data used by params
"""
data_volume = self._grad_stats.unreleased_grad_volume
for p in params:
cur_model_data_volume = p.data.numel() * p.data.element_size()
data_volume += cur_model_data_volume
if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad:
# add param.grad, actually param.grad is None in this time
data_volume += cur_model_data_volume
if not self._grad_stats.unreleased_grad_flag[p]:
self._grad_stats.unreleased_grad_volume += cur_model_data_volume
self._grad_stats.unreleased_grad_flag[p] = True
# record max non model data used for this Op
self._memstats.record_max_cuda_model_data(data_volume)
def pre_op(self, params):
max_cuda_used_pre_op = self.mem_monitor.finish()
# record max cuda overall data for prev OP.
self._memstats.record_max_cuda_overall_data(max_cuda_used_pre_op)
# record max cuda non model data for prev OP.
self._memstats.calc_max_cuda_non_model_data()
self._allocate_params_on_cuda(params)
# record max cuda model data for current OP
self.record_model_data_volume(params)
self.mem_monitor.start()
self._memstats.increase_preop_step(params)
def post_op(self, params):
self._free_cuda_params(params)
def pre_forward(self, params: List[torch.Tensor]) -> None:
self.pre_op(params)
def post_forward(self, params: List[torch.Tensor]) -> None:
self.post_op(params)
def pre_backward(self, params: List[torch.Tensor]) -> None:
self.pre_op(params)
def post_backward(self, params: List[torch.Tensor]) -> None:
self.post_op(params)
@contextmanager
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
old_training_phase = self._training_phase
try:
self._training_phase = training_phase
yield
finally:
self._training_phase = old_training_phase
switch_to_backward = switch_training_phase
switch_to_forward = partial(switch_to_backward, training_phase=TrainingPhase.FORWARD)
import functools
from abc import ABC, abstractmethod
from time import time
from typing import List, Optional, Tuple, Dict
from typing import Dict, List, Optional, Tuple, Type
import torch
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2
from typing import Type
import functools
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
class PlacementPolicy(ABC):
need_mem_stats: bool = False
def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
self.chunk_manager = chunk_manager
self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
@abstractmethod
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
......@@ -29,7 +31,9 @@ class PlacementPolicy(ABC):
class CPUPlacementPolicy(PlacementPolicy):
def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
......@@ -44,7 +48,9 @@ class CPUPlacementPolicy(PlacementPolicy):
class CUDAPlacementPolicy(PlacementPolicy):
def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
......@@ -65,7 +71,9 @@ class AutoPlacementPolicy(PlacementPolicy):
_warmup_non_model_data_ratio: float = 0.8
_steady_cuda_cap_ratio: float = 0.9
def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
def evict_tensors(self,
......@@ -154,7 +162,9 @@ class ConstPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = False
_accessed_memory_boundary = 512 * 1024**2
def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
def evict_tensors(self,
......@@ -226,7 +236,7 @@ class PlacementPolicyFactory:
return PlacementPolicyFactory.policies[policy_name]
@staticmethod
def get_polocy_names():
def get_policy_names():
return tuple(PlacementPolicyFactory.policies.keys())
@staticmethod
......
......@@ -3,6 +3,20 @@ from colossalai.gemini.stateful_tensor import StatefulTensor
from typing import Union, Tuple
def is_storage_empty(tensor: torch.Tensor) -> bool:
return tensor.storage().size() == 0
def free_storage(tensor: torch.Tensor) -> None:
if not is_storage_empty(tensor):
tensor.storage().resize_(0)
def alloc_storage(tensor: torch.Tensor) -> None:
if is_storage_empty(tensor):
tensor.storage().resize_(tensor.numel())
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
if isinstance(tensor, StatefulTensor):
t = tensor.payload
......
......@@ -22,7 +22,9 @@ class TensorParallelEnv(object):
depth_3d: int = None,
input_group_3d=None,
weight_group_3d=None,
output_group_3d=None):
output_group_3d=None,
input_x_weight_group_3d=None,
output_x_weight_group_3d=None):
self.mode = mode
self.vocab_parallel = vocab_parallel
self.parallel_input_1d = parallel_input_1d
......@@ -33,6 +35,8 @@ class TensorParallelEnv(object):
self.input_group_3d = input_group_3d
self.weight_group_3d = weight_group_3d
self.output_group_3d = output_group_3d
self.input_x_weight_group_3d = input_x_weight_group_3d
self.output_x_weight_group_3d = output_x_weight_group_3d
def save(self):
return dict(mode=self.mode,
......@@ -44,7 +48,9 @@ class TensorParallelEnv(object):
depth_3d=self.depth_3d,
input_group_3d=self.input_group_3d,
weight_group_3d=self.weight_group_3d,
output_group_3d=self.output_group_3d)
output_group_3d=self.output_group_3d,
input_x_weight_group_3d=self.input_x_weight_group_3d,
output_x_weight_group_3d=self.output_x_weight_group_3d)
tensor_parallel_env = TensorParallelEnv()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment