"tools/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "0685f4ffc9b3171b02315155b44e4a700964dd1a"
Unverified Commit 340e59f9 authored by HELSON's avatar HELSON Committed by GitHub
Browse files

[utils] add synchronized cuda memory monitor (#740)

parent e6212f56
from cgitb import Hook
from colossalai.registry import HOOKS from colossalai.registry import HOOKS
from torch import Tensor from torch import Tensor
from colossalai.trainer.hooks import BaseHook from colossalai.trainer.hooks import BaseHook
from colossalai.utils.memory_tracer import AsyncMemoryMonitor from colossalai.utils.memory_tracer import AsyncMemoryMonitor
from ._metric_hook import LearningRateMetric, MetricHook
@HOOKS.register_module @HOOKS.register_module
class MemTraceHook(BaseHook): class MemTraceHook(BaseHook):
...@@ -11,6 +10,7 @@ class MemTraceHook(BaseHook): ...@@ -11,6 +10,7 @@ class MemTraceHook(BaseHook):
This hook is used to record memory usage info, and pass to trainer.states This hook is used to record memory usage info, and pass to trainer.states
You can use it as other trainer hook and fetch data from trainer.states['metrics'][mode] You can use it as other trainer hook and fetch data from trainer.states['metrics'][mode]
""" """
def __init__( def __init__(
self, self,
priority: int = 0, priority: int = 0,
...@@ -36,9 +36,9 @@ class MemTraceHook(BaseHook): ...@@ -36,9 +36,9 @@ class MemTraceHook(BaseHook):
def before_test_iter(self, trainer):
    """Start the memory monitor before a test iteration runs.

    Args:
        trainer: the trainer driving this hook; forwarded unchanged to the
            parent hook.

    Returns:
        Whatever the parent class's ``before_test_iter`` returns.
    """
    self._memory_monitor.start()
    # Delegate to the parent hook of the same name. The original called
    # ``super().before_test`` here, which does not match this method and is
    # inconsistent with ``after_test_iter`` delegating to
    # ``super().after_test_iter``.
    return super().before_test_iter(trainer)
def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
    """Stop the memory monitor after a test iteration and publish its stats.

    The monitor's state is written to ``trainer.states['metrics']['train']``
    and ``trainer.states['metrics']['test']``.

    Args:
        trainer: the trainer driving this hook; its ``states['metrics']``
            mapping is updated in place.
        output: forwarded unchanged to the parent hook.
        label: forwarded unchanged to the parent hook.
        loss: forwarded unchanged to the parent hook.

    Returns:
        Whatever the parent class's ``after_test_iter`` returns.
    """
    monitor = self._memory_monitor
    monitor.finish()
    # NOTE(review): the 'train' slot is also written from a test iteration;
    # kept exactly as in the original -- confirm this is intended.
    for mode in ('train', 'test'):
        # ``state_dict`` may be a property (already a dict) or a plain method
        # depending on the monitor implementation; call it when it is a
        # method so the stored value is always the stats dict rather than a
        # bound method.
        state = monitor.state_dict
        trainer.states['metrics'][mode] = state() if callable(state) else state
    return super().after_test_iter(trainer, output, label, loss)
\ No newline at end of file
from .async_memtracer import AsyncMemoryMonitor from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
from .memstats_collector import MemStatsCollector from .memstats_collector import MemStatsCollector
__all__ = ['AsyncMemoryMonitor', 'MemStatsCollector'] __all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector']
from concurrent.futures import ThreadPoolExecutor from abc import abstractmethod
from time import sleep, time from concurrent.futures import ThreadPoolExecutor
import pickle from time import sleep, time
import json
import torch
import torch
from colossalai.utils.memory import colo_device_memory_used
from colossalai.utils import get_current_device from colossalai.utils.memory import colo_device_memory_used
from colossalai.utils import get_current_device
class AsyncMemoryMonitor:
""" class MemoryMonitor:
An Async Memory Monitor running during computing. Sampling memory usage of the current GPU """Base class for all types of memory monitor.
at interval of `1/(10**power)` sec. All monitors should have a list called `time_stamps` and a list called `mem_stats`.
"""
The idea comes from Runtime Memory Tracer of PatrickStar
`PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_ def __init__(self):
self.time_stamps = []
Usage:: self.mem_stats = []
async_mem_monitor = AsyncMemoryMonitor() def __len__(self):
input = torch.randn(2, 20).cuda() return len(self.mem_stats)
OP1 = torch.nn.Linear(20, 30).cuda()
OP2 = torch.nn.Linear(30, 40).cuda() @abstractmethod
def start(self):
async_mem_monitor.start() pass
output = OP1(input)
async_mem_monitor.finish() @abstractmethod
async_mem_monitor.start() def finish(self):
output = OP2(output) pass
async_mem_monitor.finish()
async_mem_monitor.save('log.pkl') def state_dict(self):
return {
"time_stamps": self.time_stamps,
Args: "mem_stats": self.mem_stats,
power (int, optional): the power of time interva. Defaults to 10. }
.. _PatrickStar\: Parallel Training of Pre-trained Models via Chunk-based Memory Management: def save(self, filename):
https://arxiv.org/abs/2108.05818 with open(filename, "w") as f:
""" json.dump(self.state_dict(), f)
def __init__(self, power: int = 10): def clear(self):
self.keep_measuring = False self.mem_stats.clear()
self.time_stamps.clear()
current_device = get_current_device()
def _set_cuda_device(): class AsyncMemoryMonitor(MemoryMonitor):
torch.cuda.set_device(current_device) """
An Async Memory Monitor running during computing. Sampling memory usage of the current GPU
self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device) at interval of `1/(10**power)` sec.
self.monitor_thread = None
self.interval = 1 / (10**power) The idea comes from Runtime Memory Tracer of PatrickStar
self.time_stamps = [] `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
self.mem_stats = []
Usage::
def __len__(self):
return len(self.mem_stats) async_mem_monitor = AsyncMemoryMonitor()
input = torch.randn(2, 20).cuda()
def set_interval(self, power: int): OP1 = torch.nn.Linear(20, 30).cuda()
self.clear() OP2 = torch.nn.Linear(30, 40).cuda()
self.interval = 1 / (10**power)
async_mem_monitor.start()
def is_measuring(self): output = OP1(input)
return self.keep_measuring async_mem_monitor.finish()
async_mem_monitor.start()
def start(self): output = OP2(output)
self.keep_measuring = True async_mem_monitor.finish()
self.monitor_thread = self.executor.submit(self._measure_usage) async_mem_monitor.save('log.pkl')
def finish(self): Args:
if self.keep_measuring is False: power (int, optional): the power of time interva. Defaults to 10.
return 0
self.keep_measuring = False .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
max_usage = self.monitor_thread.result() https://arxiv.org/abs/2108.05818
self.monitor_thread = None """
self.time_stamps.append(time())
self.mem_stats.append(max_usage) def __init__(self, power: int = 10):
return max_usage super().__init__()
self.keep_measuring = False
def _measure_usage(self):
max_usage = 0 current_device = get_current_device()
while self.keep_measuring:
max_usage = max( def _set_cuda_device():
max_usage, torch.cuda.set_device(current_device)
colo_device_memory_used(get_current_device()),
) self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
sleep(self.interval) self.monitor_thread = None
return max_usage self.interval = 1 / (10**power)
@property def set_interval(self, power: int):
def state_dict(self): self.clear()
return { self.interval = 1 / (10**power)
"time_stamps": self.time_stamps,
"mem_stats": self.mem_stats, def is_measuring(self):
} return self.keep_measuring
def save(self, filename): def start(self):
with open(filename, "wb") as f: self.keep_measuring = True
pickle.dump(self.state_dict(), f) self.monitor_thread = self.executor.submit(self._measure_usage)
def clear(self): def finish(self):
self.mem_stats.clear() if self.keep_measuring is False:
self.time_stamps.clear() return 0
self.keep_measuring = False
max_usage = self.monitor_thread.result()
self.monitor_thread = None
self.time_stamps.append(time())
self.mem_stats.append(max_usage)
return max_usage
def _measure_usage(self):
    """Poll device memory while measuring is enabled and return the peak.

    As long as ``self.keep_measuring`` is true, this samples
    ``colo_device_memory_used(get_current_device())`` and then sleeps for
    ``self.interval`` seconds before checking the flag again.

    Returns:
        The largest sample seen, or 0 if ``self.keep_measuring`` was already
        false the first time it was checked.
    """
    peak = 0
    while self.keep_measuring:
        sample = colo_device_memory_used(get_current_device())
        if sample > peak:
            peak = sample
        sleep(self.interval)
    return peak
class SyncCudaMemoryMonitor(MemoryMonitor):
    """Synchronized CUDA memory monitor.

    Records only the maximum CUDA memory allocated between a ``start()`` call
    and the following ``finish()`` call.
    """

    def __init__(self, power: int = 10):
        # ``power`` is accepted but not used by this monitor.
        super().__init__()

    def start(self):
        """Synchronize CUDA, then reset the peak-memory statistics."""
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

    def finish(self):
        """Synchronize CUDA and record the peak allocation.

        Appends the current time to ``self.time_stamps`` and the value of
        ``torch.cuda.max_memory_allocated()`` to ``self.mem_stats``.

        Returns:
            The value that was appended to ``self.mem_stats``.
        """
        torch.cuda.synchronize()
        self.time_stamps.append(time())
        peak_allocated = torch.cuda.max_memory_allocated()
        self.mem_stats.append(peak_allocated)
        return peak_allocated
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_device_memory_used from colossalai.utils.memory import colo_device_memory_used
from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor from colossalai.utils.memory_tracer import AsyncMemoryMonitor
import torch import torch
import time import time
from typing import List from typing import List
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment