Unverified commit c4739a72, authored by Jiarui Fang, committed by GitHub

[Gemini] polish memstats collector (#1962)

parent fea3cb66
@@ -6,7 +6,7 @@ import torch

 from colossalai.gemini.chunk import Chunk, ChunkManager
-from .memory_tracer.memstats_collector import MemStatsCollectorV2, MemStatsCollectorStatic
+from .memory_tracer import ChunkMemStatsCollector, StaticMemStatsCollector
 from .placement_policy import PlacementPolicyFactory

@@ -26,7 +26,8 @@ class GeminiManager:
         chunk_manager (ChunkManager): A ``ChunkManager`` instance.
     """

-    def __init__(self, placement_policy: str,
+    def __init__(self,
+                 placement_policy: str,
                  chunk_manager: ChunkManager,
                  module: Optional[torch.nn.Module] = None,
                  use_static_memstats: bool = False) -> None:

@@ -35,14 +36,14 @@ class GeminiManager:
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
-        # self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
+        # self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
         self.use_static_memstats = use_static_memstats
         if policy_cls.need_mem_stats:
             if use_static_memstats:
                 assert module is not None
-                self._mem_stats_collector = MemStatsCollectorStatic(module, chunk_manager)
+                self._mem_stats_collector = StaticMemStatsCollector(module, chunk_manager)
             else:
-                self._mem_stats_collector = MemStatsCollectorV2(chunk_manager)
+                self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager)
         else:
             self._mem_stats_collector = None
......
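For context, a minimal sketch (not part of this diff) of how a caller might construct a GeminiManager so that the renamed collectors are used. The import path and the 'auto' policy name are assumed from the surrounding code; `model` and `chunk_manager` are placeholders the caller is expected to have built already.

from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.gemini_mgr import GeminiManager    # import path assumed

def build_gemini_manager(model, chunk_manager: ChunkManager) -> GeminiManager:
    # use_static_memstats=True requires the module, as asserted in __init__ above;
    # the default (False) falls back to ChunkMemStatsCollector.
    return GeminiManager('auto',
                         chunk_manager=chunk_manager,
                         module=model,
                         use_static_memstats=True)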
-from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
-from .memstats_collector import MemStatsCollector
+from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor    # isort:skip
+from .memstats_collector import MemStatsCollector    # isort:skip
+from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER    # isort:skip
+from .chunk_memstats_collector import ChunkMemStatsCollector    # isort:skip
+from .static_memstats_collector import StaticMemStatsCollector    # isort:skip

-__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER']
+__all__ = [
+    'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
+    'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER'
+]
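With the reshuffled package `__init__`, downstream code imports the collectors from the package root rather than from the `memstats_collector` module; a small sketch of the new import surface (names taken from the `__all__` above):

from colossalai.gemini.memory_tracer import (
    ChunkMemStatsCollector,
    MemStatsCollector,
    StaticMemStatsCollector,
)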
+from colossalai.gemini.chunk import ChunkManager
+from colossalai.utils import get_current_device
+from colossalai.utils.memory import colo_device_memory_capacity
+
+from .memstats_collector import MemStatsCollector
+
+
+class ChunkMemStatsCollector(MemStatsCollector):
+
+    def __init__(self, chunk_manager: ChunkManager) -> None:
+        super().__init__()
+        self._chunk_manager = chunk_manager
+
+    def sample_model_data(self) -> None:
+        """Sampling model data statistics.
+        """
+        if self._start_flag:
+            cuda_mem = self._chunk_manager.total_mem['cuda']
+            cpu_mem = self._chunk_manager.total_mem['cpu']
+            self._model_data_cuda_list.append(cuda_mem)
+            self._model_data_cpu_list.append(cpu_mem)
+
+    @property
+    def cuda_margin_mem(self) -> float:
+        return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
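A rough usage sketch of the new ChunkMemStatsCollector. The start_collection(), finish_collection() and sample_overall_data() calls are assumed to come from the MemStatsCollector base class (only sample_model_data() and cuda_margin_mem are defined in this file), and `run_warmup_step` is a placeholder for one training iteration:

from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.memory_tracer import ChunkMemStatsCollector

def profile_warmup(chunk_manager: ChunkManager, run_warmup_step) -> float:
    collector = ChunkMemStatsCollector(chunk_manager)
    collector.start_collection()        # sets _start_flag so sampling is recorded
    run_warmup_step()
    collector.sample_overall_data()     # assumed base-class hook: overall CUDA/CPU usage
    collector.sample_model_data()       # records chunk_manager.total_mem per device
    collector.finish_collection()
    # spare CUDA memory after warmup: device capacity minus peak overall CUDA usage
    return collector.cuda_margin_mem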
-from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
-from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity
-from colossalai.utils import get_current_device
-from colossalai.gemini.stateful_tensor import StatefulTensor
-from colossalai.gemini.chunk import ChunkManager
-import torch
-import torch.nn as nn
 import time
-from typing import List, Optional
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size)
-from torch.fx import symbolic_trace
-
-if is_compatible_with_meta():
-    from colossalai.fx.profiler import MetaTensor
+from typing import List
+
+import torch
+
+from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
+from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.utils.memory import colo_device_memory_used


 class MemStatsCollector:

@@ -138,121 +129,3 @@ class MemStatsCollector:
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
-
-
-class MemStatsCollectorV2(MemStatsCollector):
-
-    def __init__(self, chunk_manager: ChunkManager) -> None:
-        super().__init__()
-        self._chunk_manager = chunk_manager
-
-    def sample_model_data(self) -> None:
-        """Sampling model data statistics.
-        """
-        if self._start_flag:
-            cuda_mem = self._chunk_manager.total_mem['cuda']
-            cpu_mem = self._chunk_manager.total_mem['cpu']
-            self._model_data_cuda_list.append(cuda_mem)
-            self._model_data_cpu_list.append(cpu_mem)
-
-    @property
-    def cuda_margin_mem(self) -> float:
-        return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
-
-
-class MemStatsCollectorStatic(MemStatsCollectorV2):
-    """
-    A Static Memory statistic collector.
-    """
-
-    def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
-        super().__init__(chunk_manager)
-        self.module = module
-        self.module_info_list = []
-
-    def init_mem_stats(self, *inputs):
-
-        self.register_opnodes_recursively(self.module)
-        self.refactor_module()
-
-        self.module = self.module.cpu()
-        self.module.train()
-
-        data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
-        gm = symbolic_trace(self.module)
-        interp = MetaInfoProp(gm)
-        interp.propagate(*data)
-
-        total_mem = 0
-        for inp in inputs:
-            total_mem += inp.numel() * inp.element_size()
-
-        last_node = None
-        module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
-        for node in gm.graph.nodes:
-            total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
-            if node.op == "call_module":
-                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
-                    self._non_model_data_cuda_list.append(total_mem)
-                last_node = node
-        self._non_model_data_cuda_list.append(total_mem)
-        self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]
-
-        cur_module_mem_fwd = 0
-        cur_module_mem_bwd = 0
-        grad_module_out = last_node.meta["fwd_mem_out"]
-        for node in gm.graph.nodes.__reversed__():
-            cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
-            cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
-            if node.op == "call_module":
-                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
-                    self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
-                    total_mem = total_mem - cur_module_mem_fwd
-                    cur_module_mem_fwd = 0
-                    cur_module_mem_bwd = 0
-                    grad_module_out = node.meta["bwd_mem_out"]
-
-        self._step_total = len(self._non_model_data_cuda_list)
-        self.recover_module()
-
-    def refactor_module(self):
-        for modInfo in self.module_info_list:
-            temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
-            modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)
-
-    def recover_module(self):
-        for modInfo in self.module_info_list:
-            modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
-
-    def register_opnodes_recursively(self,
-                                     module: torch.nn.Module,
-                                     name: str = "",
-                                     full_name: str = "",
-                                     parent_module: Optional[torch.nn.Module] = None):
-
-        assert isinstance(module, torch.nn.Module)
-
-        for child_name, child in module.named_children():
-            self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)
-
-        # Early return on modules with no parameters.
-        if len(list(module.parameters(recurse=False))) == 0:
-            return
-
-        self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
-
-
-class ModuleInfos:
-
-    def __init__(self,
-                 module: torch.nn.Module,
-                 module_name: str,
-                 module_full_name: str,
-                 parent_module: torch.nn.Module):
-        self.module = module
-        self.module_name = module_name
-        self.module_full_name = module_full_name
-        self.parent_module = parent_module
\ No newline at end of file
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch.fx import symbolic_trace
+
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta
+from colossalai.gemini.chunk import ChunkManager
+
+if is_compatible_with_meta():
+    from colossalai.fx.profiler import MetaTensor
+
+from .chunk_memstats_collector import ChunkMemStatsCollector
+
+
+class ModuleInfos:
+
+    def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str,
+                 parent_module: torch.nn.Module):
+        self.module = module
+        self.module_name = module_name
+        self.module_full_name = module_full_name
+        self.parent_module = parent_module
+
+
+class StaticMemStatsCollector(ChunkMemStatsCollector):
+    """
+    A Static Memory statistic collector.
+    """
+
+    def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
+        super().__init__(chunk_manager)
+        self.module = module
+        self.module_info_list = []
+
+    def init_mem_stats(self, *inputs):
+
+        self.register_opnodes_recursively(self.module)
+        self.refactor_module()
+
+        self.module = self.module.cpu()
+        self.module.train()
+
+        data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
+        gm = symbolic_trace(self.module)
+        interp = MetaInfoProp(gm)
+        interp.propagate(*data)
+
+        total_mem = 0
+        for inp in inputs:
+            total_mem += inp.numel() * inp.element_size()
+
+        last_node = None
+        module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
+        for node in gm.graph.nodes:
+            total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
+            if node.op == "call_module":
+                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
+                    self._non_model_data_cuda_list.append(total_mem)
+                last_node = node
+        self._non_model_data_cuda_list.append(total_mem)
+        self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]
+
+        cur_module_mem_fwd = 0
+        cur_module_mem_bwd = 0
+        grad_module_out = last_node.meta["fwd_mem_out"]
+        for node in gm.graph.nodes.__reversed__():
+            cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
+            cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
+            if node.op == "call_module":
+                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
+                    self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
+                    total_mem = total_mem - cur_module_mem_fwd
+                    cur_module_mem_fwd = 0
+                    cur_module_mem_bwd = 0
+                    grad_module_out = node.meta["bwd_mem_out"]
+
+        self._step_total = len(self._non_model_data_cuda_list)
+        self.recover_module()
+
+    def refactor_module(self):
+        for modInfo in self.module_info_list:
+            temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
+            modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)
+
+    def recover_module(self):
+        for modInfo in self.module_info_list:
+            modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
+
+    def register_opnodes_recursively(self,
+                                     module: torch.nn.Module,
+                                     name: str = "",
+                                     full_name: str = "",
+                                     parent_module: Optional[torch.nn.Module] = None):
+
+        assert isinstance(module, torch.nn.Module)
+
+        for child_name, child in module.named_children():
+            self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)
+
+        # Early return on modules with no parameters.
+        if len(list(module.parameters(recurse=False))) == 0:
+            return
+
+        self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
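A sketch of how the static collector might be driven (assumed usage, not shown in the diff). It relies on the torch.fx-based profiler, so `is_compatible_with_meta()` must hold; `model` and `chunk_manager` are placeholders, and the sample input only needs a realistic shape and dtype:

import torch

from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.memory_tracer import StaticMemStatsCollector

def precompute_non_model_data(model: torch.nn.Module, chunk_manager: ChunkManager) -> StaticMemStatsCollector:
    collector = StaticMemStatsCollector(model, chunk_manager)
    sample_input = torch.rand(4, 512)    # stands in for one real training batch
    # traces the model once with MetaInfoProp and fills _non_model_data_cuda_list
    collector.init_mem_stats(sample_input)
    return collector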
+import functools
 from abc import ABC, abstractmethod
 from time import time
-from typing import List, Optional, Tuple, Dict
+from typing import Dict, List, Optional, Tuple, Type

 import torch

-from colossalai.utils import get_current_device
-from colossalai.utils.memory import colo_device_memory_capacity
-from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2
-from typing import Type
-import functools
 from colossalai.gemini.chunk import Chunk, ChunkManager
+from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
+from colossalai.utils import get_current_device
+from colossalai.utils.memory import colo_device_memory_capacity


 class PlacementPolicy(ABC):
     need_mem_stats: bool = False

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         self.chunk_manager = chunk_manager
-        self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector
+        self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector

     @abstractmethod
     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:

@@ -29,7 +31,9 @@ class PlacementPolicy(ABC):

 class CPUPlacementPolicy(PlacementPolicy):

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:

@@ -44,7 +48,9 @@ class CPUPlacementPolicy(PlacementPolicy):

 class CUDAPlacementPolicy(PlacementPolicy):

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

@@ -65,7 +71,9 @@ class AutoPlacementPolicy(PlacementPolicy):
     _warmup_non_model_data_ratio: float = 0.8
     _steady_cuda_cap_ratio: float = 0.9

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

     def evict_tensors(self,

@@ -154,7 +162,9 @@ class ConstPlacementPolicy(PlacementPolicy):
     need_mem_stats: bool = False
     _accessed_memory_boundary = 512 * 1024**2

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

     def evict_tensors(self,
......
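To illustrate what the widened annotations admit, a hypothetical policy subclass that consults ChunkMemStatsCollector.cuda_margin_mem before evicting. The PlacementPolicy import path and the meaning of the (volume, time) return pair are assumptions, and the eviction body is deliberately left trivial:

from typing import List, Optional, Tuple

from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
from colossalai.gemini.placement_policy import PlacementPolicy    # import path assumed

class MarginAwarePolicy(PlacementPolicy):
    need_mem_stats: bool = True

    def __init__(self,
                 chunk_manager: ChunkManager,
                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
        # skip eviction while the collector still reports spare CUDA memory
        if self.mem_stats_collector is not None and self.mem_stats_collector.cuda_margin_mem > 0:
            return 0, 0.0
        # a real policy would move chunks from can_evict_chunks to CPU here
        return 0, 0.0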
 import functools
+import itertools
 from collections import OrderedDict
-from typing import Any, Optional, Iterator, Tuple
 from copy import deepcopy
-import itertools
+from typing import Any, Iterator, Optional, Tuple

 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter

 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector
 from colossalai.gemini.ophooks import register_ophooks_recursively
-from colossalai.zero.utils import ZeroHook
 from colossalai.gemini.paramhooks import BaseParamHookMgr
+from colossalai.gemini.stateful_tensor import TensorState
+from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
+from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device, disposable
-from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector, MemStatsCollectorStatic
+from colossalai.utils import disposable, get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
-from torch.distributed import ProcessGroup
-from torch.nn.parameter import Parameter
-from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
-from colossalai.gemini.stateful_tensor import TensorState
-from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
+from colossalai.zero.utils import ZeroHook

-from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
-                     get_gradient_predivide_factor)
+from ._utils import (
+    cast_float_arguments,
+    cast_tensor_to_fp16,
+    cast_tensor_to_fp32,
+    chunk_and_pad,
+    free_storage,
+    get_gradient_predivide_factor,
+)

 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX

@@ -116,7 +124,7 @@ class ShardedModelV2(nn.Module):
         self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
             if self.user_static_memstats:
-                self._memstats_collector = MemStatsCollectorStatic(self.module)
+                self._memstats_collector = StaticMemStatsCollector(self.module)
             else:
                 self._memstats_collector = MemStatsCollector()
             self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
......