"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f7076cd346f48aee850b8c54e6e129c33a404308"
Unverified Commit fb7fca71 authored by Stas Bekman, committed by GitHub

[trainer metrics] fix cpu mem metrics; reformat runtime metric (#10937)



* fix cpu mem metrics; reformat runtime metric

* adjust dependency

* extend docs

* soft dependency

* cleanup

* fix the runtime metric issue

* restore

* move docs, cross reference from 2 places, improve

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 5057213b
@@ -16,6 +16,7 @@
Torch utilities for the Trainer class.
"""
import datetime
import json
import math
import os
@@ -615,6 +616,15 @@ def _get_learning_rate(self):
return last_lr
def _secs2timedelta(secs):
"""
convert seconds to hh:mm:ss.msec, msecs rounded to 2 decimals
"""
msec = int(abs(secs - int(secs)) * 100)
return f"{datetime.timedelta(seconds=int(secs))}.{msec:02d}"
def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
"""
Reformat Trainer metrics values to a human-readable format
@@ -631,6 +641,8 @@ def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
for k, v in metrics_copy.items():
if "_mem_" in k:
metrics_copy[k] = f"{ v >> 20 }MB"
elif "_runtime" in k:
metrics_copy[k] = _secs2timedelta(v)
elif k == "total_flos":
metrics_copy[k] = f"{ int(v) >> 30 }GF"
elif type(metrics_copy[k]) == float:
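# A rough end-to-end sketch of the formatting above, with hypothetical values and assuming
# `trainer` is a Trainer instance (memory deltas are in bytes, total_flos in floating point operations):
#
#   metrics = {
#       "train_mem_cpu_alloc_delta": 1345 * 2**20,
#       "train_runtime": 125.3456,
#       "total_flos": 2000 * 2**30,
#   }
#   trainer.metrics_format(metrics)
#   # -> {'train_mem_cpu_alloc_delta': '1345MB', 'train_runtime': '0:02:05.34', 'total_flos': '2000GF'}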
@@ -650,6 +662,72 @@ def log_metrics(self, split, metrics):
Mode/split name: one of ``train``, ``eval``, ``test``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict
Notes on memory reports:
To get a memory usage report you need to install ``psutil``. You can do that with ``pip install psutil``.
When this method is run, you will see a report that includes: ::
init_mem_cpu_alloc_delta = 1301MB
init_mem_cpu_peaked_delta = 154MB
init_mem_gpu_alloc_delta = 230MB
init_mem_gpu_peaked_delta = 0MB
train_mem_cpu_alloc_delta = 1345MB
train_mem_cpu_peaked_delta = 0MB
train_mem_gpu_alloc_delta = 693MB
train_mem_gpu_peaked_delta = 7MB
**Understanding the reports:**
- the first segment, e.g., ``train_``, tells you which stage the metrics are for. Reports starting with ``init_``
will be added to the first stage that gets run. So if only evaluation is run, the memory usage for the
``__init__`` will be reported along with the ``eval_`` metrics.
- the third segment is either ``cpu`` or ``gpu`` and tells you whether the metric refers to general RAM or to gpu0
memory.
- ``*_alloc_delta`` - is the difference in the used/allocated memory counter between the end and the start of the
stage - it can be negative if a function released more memory than it allocated.
- ``*_peaked_delta`` - is any extra memory that was consumed and then freed - relative to the current allocated
memory counter - it is never negative. When you look at the metrics of any stage, add up ``alloc_delta`` +
``peaked_delta`` to know how much memory was needed to complete that stage.
The reporting happens only for the process of rank 0 and gpu 0 (if there is a gpu). Typically this is enough since
the main process does the bulk of the work, but it may not hold if model parallelism is used, since other GPUs may
then use a different amount of gpu memory. It is also not the same under DataParallel, where gpu0 may require much
more memory than the rest since it stores the gradient and optimizer states for all participating GPUs. Perhaps in
the future these reports will evolve to measure those too.
The CPU RAM metric measures RSS (Resident Set Size), which includes both the memory unique to the process and the
memory shared with other processes. Note that it does not include swapped out memory, so the reports could be
imprecise.
The CPU peak memory is measured using a sampling thread. Due to python's GIL it may miss some of the peak memory if
that thread didn't get a chance to run when the highest memory was used, so the report can understate the real
peak. Using ``tracemalloc`` would have reported the exact peak memory, but it doesn't report memory allocations
outside of python, so if some C++ CUDA extension allocated its own memory it wouldn't be reported. It was therefore
dropped in favor of the memory sampling approach, which reads the current process memory usage.
The GPU allocated and peak memory reporting is done with ``torch.cuda.memory_allocated()`` and
``torch.cuda.max_memory_allocated()``. This metric reports only "deltas" for pytorch-specific allocations, as the
``torch.cuda`` memory management system doesn't track any memory allocated outside of pytorch. For example, the
very first cuda call typically loads CUDA kernels, which may take from 0.5 to 2GB of GPU memory.
Note that this tracker doesn't account for memory allocations outside of :class:`~transformers.Trainer`'s
``__init__``, ``train``, ``evaluate`` and ``predict`` calls.
Because evaluation calls may happen during ``train``, we can't handle nested invocations:
``torch.cuda.max_memory_allocated`` is a single counter, so if it gets reset by a nested eval call, ``train``'s
tracker will report incorrect info. If this `pytorch issue <https://github.com/pytorch/pytorch/issues/16266>`__
gets resolved, it will be possible to change this class to be re-entrant. Until then we only track the outer level
of the ``train``, ``evaluate`` and ``predict`` methods, which means that if ``eval`` is called during ``train``,
it's the latter that will account for its memory usage and that of the former.
This also means that if any other tool used alongside the :class:`~transformers.Trainer` calls
``torch.cuda.reset_peak_memory_stats``, the gpu peak memory stats could be invalid. Conversely, the
:class:`~transformers.Trainer` will disrupt the normal behavior of any such tool that relies on calling
``torch.cuda.reset_peak_memory_stats`` itself.
For best performance you may want to consider turning the memory profiling off for production runs.
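As a worked example of reading the sample report above, the memory needed to complete the ``train`` stage is
roughly the sum of the two deltas for each device: ::

    train cpu: 1345MB (alloc) + 0MB (peaked) = 1345MB
    train gpu:  693MB (alloc) + 7MB (peaked) =  700MB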
"""
if not self.is_world_process_zero():
return
@@ -675,6 +753,10 @@ def save_metrics(self, split, metrics, combined=True):
The metrics returned from train/evaluate/predict
combined (:obj:`bool`, `optional`, defaults to :obj:`True`):
Creates combined metrics by updating ``all_results.json`` with metrics of this call
To understand the metrics please read the docstring of :meth:`~transformers.Trainer.log_metrics`. The only
difference is that raw unformatted numbers are saved in the current method.
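A typical call pattern, assuming ``trainer`` is a :class:`~transformers.Trainer` instance (sketch only): ::

    train_result = trainer.train()
    metrics = train_result.metrics

    trainer.log_metrics("train", metrics)   # print the formatted, human-readable values
    trainer.save_metrics("train", metrics)  # save the raw values; combined=True also updates all_results.json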
"""
if not self.is_world_process_zero():
return
@@ -22,14 +22,15 @@ import inspect
import os
import random
import re
import threading
import time
import tracemalloc
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
import numpy as np
from .file_utils import (
ExplicitEnum,
is_psutil_available,
is_sagemaker_distributed_available,
is_tf_available,
is_torch_available,
@@ -258,6 +259,8 @@ class TrainerMemoryTracker:
"""
A helper class that tracks cpu and gpu memory.
This class will silently skip unless ``psutil`` is available. Install with ``pip install psutil``.
When a stage completes, it can pass metrics dict to update with the memory metrics gathered during this stage.
Example ::
@@ -268,37 +271,9 @@ class TrainerMemoryTracker:
metrics = {"train_runtime": 10.5}
self._memory_tracker.stop_and_update_metrics(metrics)
At the moment gpu tracking is only for pytorch, but can be extended to support tensorflow.
Understanding the reports:
- ``*_alloc_delta`` - is the difference in the used/allocated memory counter between the end and the start of the
stage - it can be negative if a function released more memory than it allocated.
- ``*_peaked_delta`` - is any extra memory that was consumed and then freed - relative to the current allocated
memory counter - it is never negative.
So when you look at the metrics of any stage you add up ``alloc_delta`` + ``peaked_delta`` and you know how much
memory was needed to complete that stage.
At the moment GPU tracking is only for ``pytorch``, but can be extended to support ``tensorflow``.
The reporting happens only for process of rank 0 and gpu 0 (if there is a gpu). Typically this is enough since the
main process does the bulk of work, but it could be not quite so if model parallel is used and then other gpus may
use a different amount of gpu RAM. Perhaps in the future this tracker will evolve to measure those too.
Note that this tracker doesn't account for memory allocations outside of :class:`~transformers.Trainer`'s
``__init__``, ``train``, ``evaluate`` and ``predict`` calls.
Because ``evaluation`` calls may happen during ``train``, we can't handle nested invocations because
``torch.cuda.max_memory_allocated`` is a single counter, so if it gets reset by a nested eval call, ``train``'s
tracker will report incorrect info. If this `pytorch issue <https://github.com/pytorch/pytorch/issues/16266>`__
gets resolved it will be possible to change this class to be re-entrant. Until then we will only track the outer
level of ``train``, ``evaluate`` and ``predict`` methods. Which means that if ``eval`` is called during ``train``,
it's the latter that will account for its memory usage and that of the former.
This also means that if any other tool that is used along the :class:`~transformers.Trainer` calls
``torch.cuda.reset_peak_memory_stats``, the gpu peak memory stats could be invalid. And the
:class:`~transformers.Trainer` will disrupt the normal behavior of any such tools that rely on calling
``torch.cuda.reset_peak_memory_stats`` themselves.
To understand this class' intricacies please read the documentation of :meth:`~transformers.Trainer.log_metrics`.
"""
@@ -311,6 +286,18 @@ class TrainerMemoryTracker:
}
def __init__(self, skip_memory_metrics=False):
self.skip_memory_metrics = skip_memory_metrics
if not is_psutil_available():
# soft dependency on psutil
self.skip_memory_metrics = True
if self.skip_memory_metrics:
return
import psutil # noqa
if is_torch_cuda_available():
import torch
@@ -319,10 +306,11 @@ class TrainerMemoryTracker:
else:
self.torch = None
self.process = psutil.Process()
self.cur_stage = None
self.cpu = {}
self.init_reported = False
self.skip_memory_metrics = skip_memory_metrics
def derive_stage(self):
""" derives the stage/caller name automatically """
@@ -334,6 +322,22 @@ class TrainerMemoryTracker:
f"was called from {caller}, but only expect to be called from one of {self.stages.keys()}"
)
def cpu_mem_used(self):
""" get resident set size memory for the current process """
return self.process.memory_info().rss
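# For reference, `memory_info().rss` is reported in bytes; a quick interactive check
# (assuming psutil is installed) could look like:
#
#   import psutil
#   psutil.Process().memory_info().rss >> 20   # current RSS in MB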
def peak_monitor_func(self):
self.cpu_mem_used_peak = -1
while True:
self.cpu_mem_used_peak = max(self.cpu_mem_used(), self.cpu_mem_used_peak)
# can't sleep or will not catch the peak right (this comment is here on purpose)
# time.sleep(0.001) # 1msec
if not self.peak_monitoring:
break
def start(self):
""" start tracking for the caller's stage """
if self.skip_memory_metrics:
@@ -346,21 +350,23 @@ class TrainerMemoryTracker:
self.cur_stage = stage
gc.collect()
if self.torch is not None:
self.torch.cuda.reset_peak_memory_stats()
self.torch.cuda.empty_cache()
gc.collect()
# gpu
if self.torch is not None:
self.gpu[self.cur_stage] = {}
self.gpu[self.cur_stage]["alloc"] = self.torch.cuda.memory_allocated()
self.gpu[self.cur_stage]["peaked"] = 0
self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
# cpu
self.cpu[self.cur_stage] = {}
tracemalloc.start()
self.cpu_mem_used_at_start = self.cpu_mem_used()
self.peak_monitoring = True
peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
peak_monitor_thread.daemon = True
peak_monitor_thread.start()
def stop(self, stage):
""" stop tracking for the passed stage """
@@ -369,24 +375,35 @@ class TrainerMemoryTracker:
if self.cur_stage is not None and self.cur_stage != stage:
return
# this sends a signal to peak_monitor_func to complete its loop
self.peak_monitoring = False
# first ensure all objects get collected and their memory is freed
gc.collect()
if self.torch is not None:
self.torch.cuda.empty_cache()
gc.collect()
# concepts:
# - alloc_delta: the difference of allocated memory between the end and the start
# - peaked_delta: the difference between the peak memory and the current memory
# in order to know how much memory the measured code consumed one needs to sum these two
# gpu
if self.torch is not None:
mem_cur = self.torch.cuda.memory_allocated()
# this is the difference between the start and the end allocated memory
self.gpu[self.cur_stage]["alloc"] = mem_cur - self.gpu[self.cur_stage]["alloc"] # can be negative
# this is the difference if any between the start and the peak
self.gpu[self.cur_stage]["peaked"] = max(0, self.torch.cuda.max_memory_allocated() - mem_cur)
self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
self.gpu[self.cur_stage] = dict(
alloc=(self.gpu_mem_used_now - self.gpu_mem_used_at_start),
peaked=max(0, self.gpu_mem_used_peak - self.gpu_mem_used_now),
)
# cpu
cpu_mem_used_delta, cpu_mem_used_peak = tracemalloc.get_traced_memory()
tracemalloc.stop() # reset accounting
self.cpu[self.cur_stage]["alloc"] = cpu_mem_used_delta # can be negative
self.cpu[self.cur_stage]["peaked"] = max(0, cpu_mem_used_peak - cpu_mem_used_delta)
self.cpu_mem_used_now = self.cpu_mem_used()
self.cpu[self.cur_stage] = dict(
alloc=(self.cpu_mem_used_now - self.cpu_mem_used_at_start),
peaked=max(0, self.cpu_mem_used_peak - self.cpu_mem_used_now),
)
# reset - cycle finished
self.cur_stage = None