"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "16122d5fac6deb065c5ff713c5fc84ee946f8c0e"
Unverified Commit d7e0303d authored by ver217's avatar ver217 Committed by GitHub
Browse files

[zero] use GeminiMemoryManager when sampling model data (#850)

parent 232142f4
from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.utils.memory import colo_device_memory_used from colossalai.utils.memory import colo_device_memory_used
from colossalai.gemini.stateful_tensor import StatefulTensor
import torch import torch
import time import time
...@@ -92,7 +92,8 @@ class MemStatsCollector: ...@@ -92,7 +92,8 @@ class MemStatsCollector:
"""Sampling model data statistics. """Sampling model data statistics.
""" """
if self._start_flag: if self._start_flag:
cuda_mem, cpu_mem = GLOBAL_MODEL_DATA_TRACER.both_mem_usage cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
self._model_data_cuda_list.append(cuda_mem) self._model_data_cuda_list.append(cuda_mem)
self._model_data_cpu_list.append(cpu_mem) self._model_data_cpu_list.append(cpu_mem)
...@@ -114,24 +115,6 @@ class MemStatsCollector: ...@@ -114,24 +115,6 @@ class MemStatsCollector:
self._sampling_time.append(time.time()) self._sampling_time.append(time.time())
self._mem_monitor.start() self._mem_monitor.start()
def sample_memstats(self) -> None:
    """Sample memory statistics.

    When ``self._start_flag`` is truthy, append one entry to each of the
    model-data, overall and non-model-data lists for both CUDA and CPU,
    record the sampling time, and call ``start()`` on the memory monitor.
    When the flag is falsy, do nothing.
    """
    if not self._start_flag:
        return

    # CUDA: model-data usage is read from the global tracer; the overall
    # figure is whatever the monitor's finish() returns.
    model_cuda = GLOBAL_MODEL_DATA_TRACER.cuda_usage
    self._model_data_cuda_list.append(model_cuda)
    overall_cuda = self._mem_monitor.finish()
    self._overall_cuda_list.append(overall_cuda)
    self._non_model_data_cuda_list.append(overall_cuda - model_cuda)

    # CPU: same bookkeeping, but the overall figure comes from a helper
    # rather than from the monitor.
    model_cpu = GLOBAL_MODEL_DATA_TRACER.cpu_usage
    self._model_data_cpu_list.append(model_cpu)
    # FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
    overall_cpu = colo_device_memory_used(torch.device('cpu'))
    self._overall_cpu_list.append(overall_cpu)
    self._non_model_data_cpu_list.append(overall_cpu - model_cpu)

    # Timestamp this sample, then start the monitor again.
    self._sampling_time.append(time.time())
    self._mem_monitor.start()
def clear(self) -> None: def clear(self) -> None:
self._model_data_cuda_list = [] self._model_data_cuda_list = []
self._overall_cuda_list = [] self._overall_cuda_list = []
......
...@@ -7,7 +7,6 @@ from colossalai.utils.memory import colo_device_memory_capacity ...@@ -7,7 +7,6 @@ from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini.memory_tracer import MemStatsCollector from colossalai.gemini.memory_tracer import MemStatsCollector
from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
from typing import Type from typing import Type
...@@ -79,7 +78,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy): ...@@ -79,7 +78,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
""" """
volume = 0 volume = 0
cuda_capacity = colo_device_memory_capacity(get_current_device()) cuda_capacity = colo_device_memory_capacity(get_current_device())
used_cuda_model_data = GLOBAL_MODEL_DATA_TRACER.cuda_usage used_cuda_model_data = StatefulTensor.GST_MGR.total_mem['cuda']
if warmup: if warmup:
# We designate a part of CUDA memory for model data in warmup iterations. # We designate a part of CUDA memory for model data in warmup iterations.
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
......
...@@ -13,8 +13,6 @@ from colossalai.engine.paramhooks import BaseParamHookMgr ...@@ -13,8 +13,6 @@ from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device, disposable from colossalai.utils import get_current_device, disposable
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.gemini.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_device_memory_capacity from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
...@@ -106,7 +104,6 @@ class ShardedModelV2(nn.Module): ...@@ -106,7 +104,6 @@ class ShardedModelV2(nn.Module):
self._use_memory_tracer = tensor_placement_policy == 'auto' self._use_memory_tracer = tensor_placement_policy == 'auto'
if self._use_memory_tracer: if self._use_memory_tracer:
GLOBAL_MODEL_DATA_TRACER.register_model(self)
self._memstats_collector = MemStatsCollector() self._memstats_collector = MemStatsCollector()
self._start_collect_memstats = disposable(self._memstats_collector.start_collection) self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection) self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
......
...@@ -10,10 +10,7 @@ from colossalai.context.parallel_mode import ParallelMode ...@@ -10,10 +10,7 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.gemini.memory_tracer.model_data_memtracer import \ from colossalai.gemini.tensor_utils import (colo_model_data_tensor_move_inline, colo_tensor_mem_usage)
GLOBAL_MODEL_DATA_TRACER
from colossalai.gemini.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from torch import Tensor from torch import Tensor
...@@ -130,8 +127,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer): ...@@ -130,8 +127,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0]) f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0])
self._use_memory_tracer = self.model.use_memory_tracer self._use_memory_tracer = self.model.use_memory_tracer
if self._use_memory_tracer:
GLOBAL_MODEL_DATA_TRACER.register_optimizer(self)
@property @property
def loss_scale(self): def loss_scale(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment