Unverified Commit ad92220c authored by Quentin Duval, committed by GitHub

[feat] layer memory tracking (#808)



* [feat] layer memory tracking

* [feat] layer memory tracking (add tests in CI)

* [feat] layer memory tracking: doc typos

* [feat] layer memory tracking: mypy fixes

* [feat] layer memory tracking: fixes for FSDP all gather tracking on pytorch 1.9 and above

* [feat] layer memory tracking: lint

* [feat] layer memory tracking: mypy
Co-authored-by: QuentinDuval <QuentinDuval@users.noreply.github.com>
parent 51e43b61
@@ -55,6 +55,7 @@ modules and easy to use APIs.
tutorials/offload_model
tutorials/adascale
tutorials/pipe
tutorials/layer_memory_tracking
|
|
Efficient memory usage using Activation Checkpointing
=====================================================
Adapted from `torch.utils.checkpoint`, this is a friendlier wrapper for performing activation checkpointing.
Compared to the PyTorch version, this version wraps a `nn.Module` and allows for all subsequent calls to be
checkpointed.
Tooling to diagnose and fix memory problems
===========================================
FairScale comes with some experimental tooling to help track, visualize, and suggest fixes for memory issues occurring during the forward/backward pass of your models.
Visualizing the memory profile
------------------------------
To track and visualize the memory profile of a model, you can use the `LayerwiseMemoryTracker`:
.. code-block:: python
from fairscale.experimental.tooling.layer_memory_tracker import LayerwiseMemoryTracker
import torch
import torchvision.models
# Create a model
model = torchvision.models.resnet50().cuda()
criterion = torch.nn.CrossEntropyLoss()
# Create some dummy inputs
batch_size = 16
x = torch.randn(size=(batch_size, 3, 224, 224)).cuda()
y = torch.tensor(list(range(batch_size)), dtype=torch.int64).cuda()
# Start monitoring the model
tracker = LayerwiseMemoryTracker()
tracker.monitor(model)
# Do a forward/backward with dummy inputs
criterion(model(x), y).backward()
# Stop monitoring the model
tracker.stop()
# Show some useful default plots
tracker.show_plots()
The resulting graphs will include:
- a graph of the memory profile (memory allocated and reserved) during the forward/backward
.. image:: _static/img/layer_memory_profiles.png
:width: 1000px
:align: center
- a graph of the amount of memory allocated for activations during the forward/backward
.. image:: _static/img/layer_memory_activations.png
:width: 1000px
:align: center
- a graph of the amount of memory used for parameters by each of the layers traversed during the forward/backward
.. image:: _static/img/layer_memory_parameters.png
:width: 500px
:align: center
In all these graphs:
- the blue part of the curve corresponds to the forward pass and the orange part to the backward pass
- the X axis only orders the computational steps (it does not represent the index of the layer in the model)
How to use those graphs?
------------------------
It is not always obvious how much memory a model will use. These graphs allow you to visualize:
- the main cause of memory consumption: in the graphs above, it is the activation memory
- the layers that are worth sharding: in the case above, those at the end of the convolutional network
- where to place activation checkpoints to reduce memory consumption
If those graphs are not useful to you, you can always use the raw data collected by the `LayerwiseMemoryTracker` instead,
or use any of the other utility functions provided by the tool:
.. code-block:: python
# Access all raw traces / forward traces only / backward traces only
tracker.memory_traces
tracker.forward_traces
tracker.backward_traces
# Access a quick summary of the traces with information on:
# - the peak memory usage
# - the top layers in terms of memory consumption
tracker.summary
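
The raw traces are plain named tuples, so they can be post-processed or saved to disk for later analysis.
Here is a minimal sketch (the aggregation logic and the file name are illustrative; `save_traces` and `load`
are utility functions of the tracker):

.. code-block:: python

    from collections import defaultdict

    # Aggregate the activation memory allocated by each layer during the forward pass
    activations_per_layer = defaultdict(int)
    for trace in tracker.forward_traces:
        activations_per_layer[trace.module_name] += trace.event.memory_activations

    # Persist the traces to a JSON file and reload them later
    tracker.save_traces("memory_traces.json")
    reloaded = LayerwiseMemoryTracker.load("memory_traces.json")
    assert len(reloaded.memory_traces) == len(tracker.memory_traces)
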
Activation checkpoint suggestions
---------------------------------
In addition to visualization, the `LayerwiseMemoryTracker` traces can be used to suggest activation checkpoint
locations, which can reduce the memory consumption of the forward/backward pass at the cost of some extra compute:
.. code-block:: python
from fairscale.experimental.tooling.layer_memory_tracker import suggest_checkpoint_location
suggestion = suggest_checkpoint_location(tracker.memory_traces, num_checkpoints=0)
print(suggestion.max_memory) # Outputs: 1435630080
suggestion = suggest_checkpoint_location(tracker.memory_traces, num_checkpoints=2)
print(suggestion.max_memory) # Outputs: 485095936
print(suggestion.split_modules) # Outputs: ['layer1.1.bn3', 'layer2.2.conv3']
This sample code tells us that we can reduce the memory consumption due to activations from roughly 1.4 GB to around 500 MB by
checkpointing activations at the locations `layer1.1.bn3` and `layer2.2.conv3`.
These locations can serve as first guesses and might not always be practical given the structure of the model code. In the case of a
torchvision resnet, we can adapt those locations by checkpointing around layer1 and layer2 with fairscale's `checkpoint_wrapper`:
.. code-block:: python
model = torchvision.models.resnet50().cuda()
model.layer1 = checkpoint_wrapper(model.layer1)
model.layer3 = checkpoint_wrapper(torch.nn.Sequential(model.layer2, model.layer3))
model.layer2 = torch.nn.Identity()
This leads to the following memory profile, saving around 400 MB of activation memory at the cost of more compute:
.. image:: _static/img/layer_memory_profile_optimized.png
:width: 500px
:align: center
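
The traces collected before and after adding the checkpoints can also be compared side by side with the
`compare_memory_traces_in_plot` helper from the same module. A minimal sketch, assuming `baseline_traces` and
`checkpointed_traces` hold the `memory_traces` collected on each variant of the model:

.. code-block:: python

    from fairscale.experimental.tooling.layer_memory_tracker import compare_memory_traces_in_plot

    # Requires matplotlib and Pillow to be installed
    compare_memory_traces_in_plot({
        "baseline": baseline_traces,
        "checkpointed": checkpointed_traces,
    })
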
Dedicated features for FSDP distributed training
--------------------------------------------------
When training a big model with `FullyShardedDataParallel`, you can use the `LayerwiseMemoryTracker` to track the
amount of memory exchanged by FSDP to consolidate sharded layers:
.. code-block:: python
from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.experimental.tooling.layer_memory_tracker import ProcessGroupTracker
# Create a process group for FSDP
group = torch.distributed.new_group()
group = ProcessGroupTracker(group)
# Create a FSDP model
model = torchvision.models.resnet50().cuda()
model.layer1 = FSDP(model.layer1, process_group=group)
model.layer2 = FSDP(model.layer2, process_group=group)
model.layer3 = FSDP(model.layer3, process_group=group)
model.layer4 = FSDP(model.layer4, process_group=group)
model = FSDP(model, process_group=group)
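
Then monitor the model and run a forward/backward pass as in the previous examples (a sketch, assuming the same
dummy inputs and criterion as before):

.. code-block:: python

    tracker = LayerwiseMemoryTracker()
    tracker.monitor(model)
    criterion(model(x), y).backward()
    tracker.stop()
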
Now, the `LayerwiseMemoryTracker` will provide an additional graph where we can see:
- the memory spikes (in blue for forward, in orange for backward) of the `all_gather` calls
- an estimation (in green) of cumulative parameter memory (only available for the forward pass)
.. image:: _static/img/all_gathered_memory.png
:width: 500px
:align: center
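
The same information is also available programmatically through the `all_gathered` and `cumul_all_gathered`
fields of the raw traces, for instance:

.. code-block:: python

    for trace in tracker.memory_traces:
        if trace.all_gathered > 0:
            print(trace.module_name, trace.all_gathered, trace.cumul_all_gathered)
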
Limitations
------------
The `LayerwiseMemoryTracker` has a number of limitations that are important to be aware of:
1. It only works on GPU models: the model cannot sit on the CPU
2. Some of the GPU memory might not be tracked by PyTorch (for example some NCCL buffers) and therefore will not be tracked by this tooling either
3. Aside from allocated and cached memory, which are reported by PyTorch, the results are based on heuristics and might miss some memory in some cases
4. Some features (such as the cumulative all-gathered memory for FSDP) do not work in the backward pass
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
from functools import lru_cache
from typing import Any, Callable, Dict, Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle
from fairscale.nn import FullyShardedDataParallel
class TraceForwardEvent(NamedTuple):
"""
Complementary trace event collected during the forward pass
to trace the memory increase and the memory taken by activations
"""
memory_diff: int
memory_activations: int
def to_dict(self) -> Dict[str, Any]:
return {
"memory_diff": self.memory_diff,
"memory_activations": self.memory_activations,
}
@classmethod
def from_dict(cls, serialized: Dict[str, Any]) -> "TraceForwardEvent":
return TraceForwardEvent(
memory_diff=serialized["memory_diff"], memory_activations=serialized["memory_activations"],
)
class TraceBackwardEvent(NamedTuple):
"""
    Complementary trace event collected during the backward pass
to trace the memory taken by activations
"""
memory_activations: int
def to_dict(self) -> Dict[str, Any]:
return {"memory_activations": self.memory_activations}
@classmethod
def from_dict(cls, serialized: Dict[str, Any]) -> "TraceBackwardEvent":
return TraceBackwardEvent(memory_activations=serialized["memory_activations"])
class LayerMemoryTrace(NamedTuple):
"""
Trace event providing the current memory usage at a point
    occurring during the forward or backward
module_name: name of the module under processing
module_params: size of the module parameters
allocated: state of the PyTorch allocated memory
reserved: state of the PyTorch reserved memory
is_forward: whether the trace was collected during forward
all_gathered: memory gathered since last event by FSDP
cumul_all_gathered: total amount of memory currently gathered by FSDP
event: additional information on the trace
"""
module_name: str
module_params: int
allocated: int
reserved: int
is_forward: bool
all_gathered: int
cumul_all_gathered: int
event: Union[TraceForwardEvent, TraceBackwardEvent]
def to_dict(self) -> Dict[str, Any]:
return {
"module_name": self.module_name,
"module_params": self.module_params,
"allocated": self.allocated,
"reserved": self.reserved,
"is_forward": self.is_forward,
"all_gathered": self.all_gathered,
"cumul_all_gathered": self.cumul_all_gathered,
"event": self.event.to_dict(),
}
@classmethod
def from_dict(cls, serialized: Dict[str, Any]) -> "LayerMemoryTrace":
if serialized["is_forward"]:
event: Union[TraceForwardEvent, TraceBackwardEvent] = TraceForwardEvent.from_dict(serialized["event"])
else:
event = TraceBackwardEvent.from_dict(serialized["event"])
return LayerMemoryTrace(
module_name=serialized["module_name"],
module_params=serialized["module_params"],
allocated=serialized["allocated"],
reserved=serialized["reserved"],
is_forward=serialized["is_forward"],
all_gathered=serialized["all_gathered"],
cumul_all_gathered=serialized["cumul_all_gathered"],
event=event,
)
@dataclass
class LayerwiseMemoryTrackerSummary:
"""
Summary of the memory allocation during forward/backward
- max_memory_allocated: the peak of memory allocated
- max_memory_cached: the peak of memory cached by PyTorch
- total_activation_allocations: cumulative count of activations allocations
- total_forward_allocations: cumulative count of forward pass allocations
- top_forward_activation_producers: layers that allocated the most activations
"""
max_memory_allocated: int
max_memory_cached: int
total_activation_allocations: int
total_forward_allocations: int
top_forward_activation_producers: List[LayerMemoryTrace]
class ProcessGroupTrackingEvent(Enum):
"""
Types of events that can be tracked in the process group:
- allgather: will track calls to ProcessGroup.allgather
"""
allgather = auto()
class ProcessGroupTracker:
"""
To be used as a wrapper around a ProcessGroup to track
the calls to specific ProcessGroup function such as
"allgather" calls.
The tracker will send a notification to the listener
when such calls occur.
Best used in conjunction with LayerwiseMemoryTracker:
```
# wrap the group used for FSDP
group = ProcessGroupTracker(group)
# use this group when creating FSDP blocks
    model = FullyShardedDataParallel(model, process_group=group)
# monitor the model as before
tracker = LayerwiseMemoryTracker()
tracker.monitor(model)
# the detailed traces will now contain information
# about the amount of all gathered data
tracker.memory_traces
```
"""
def __init__(self, group: Any, listener: Optional[Callable] = None):
self.group = group
self.listener = listener
def __getattr__(self, item: str) -> Any:
        # Intercept the collective calls we want to trace; forward everything else to the wrapped group
if item == "allgather":
# For PyTorch 1.8 and below
return self._build_wrapper(fct=self.group.allgather)
elif item == "_allgather_base":
# For PyTorch 1.9 and above
return self._build_wrapper(fct=getattr(self.group, item))
return getattr(self.group, item)
def _build_wrapper(self, fct: Callable) -> Callable:
def wrapper(
output_tensors: Union[torch.Tensor, Sequence[torch.Tensor]],
input_tensors: Union[torch.Tensor, Sequence[torch.Tensor]],
*args: list,
**kwargs: dict,
) -> Any:
if self.listener is not None:
self.listener(ProcessGroupTrackingEvent.allgather, output_tensors, input_tensors)
return fct(output_tensors, input_tensors, *args, **kwargs)
return wrapper
class LayerwiseMemoryTracker:
"""
Observe a module to get the graph of the memory consumption during
the forward and backward, layer by layer, with:
- a breakdown of the memory used (activations memory estimation)
- additional details such as amount of data exchanged with all gather
Requires the model to be on a CUDA device to track its memory
Example usage (no FSDP):
```
# create your model
model = models.resnet50().cuda()
# monitor the model
tracker = LayerwiseMemoryTracker()
tracker.monitor(model)
# Do a forward/backward
criterion(model(input), target).backward()
# show the plots
tracker.show_plots()
# get the detailed traces
tracker.memory_traces
# print a summary
print(tracker.summary)
```
Advanced usage (for FSDP):
```
# wrap the group used for FSDP
group = ProcessGroupTracker(group)
# use this group when creating FSDP blocks
    model = FullyShardedDataParallel(model, process_group=group)
# monitor the model as before
tracker = LayerwiseMemoryTracker()
tracker.monitor(model)
# the detailed traces will now contain information
# about the amount of all gathered data
tracker.memory_traces
```
"""
def __init__(self) -> None:
self.memory_traces: List[LayerMemoryTrace] = []
self._hooks: List[RemovableHandle] = []
self._previous_module_name: Optional[str] = None
self._last_all_gather_memory = 0
self._cumul_all_gather_memory: List[int] = []
self._memory_pre_forward = 0
self._traced_module_names: Set[str] = set()
def monitor(self, model: nn.Module) -> None:
"""
Install hooks on the model to track its memory usage
"""
for name, m in model.named_modules():
h1 = m.register_forward_pre_hook(self._create_pre_forward_hook(name))
h2 = m.register_forward_hook(self._create_post_forward_hook(name))
h3 = m.register_backward_hook(self._create_backward_hook(name))
self._hooks.extend([h1, h2, h3])
if isinstance(m, FullyShardedDataParallel):
if isinstance(m.process_group, ProcessGroupTracker):
m.process_group.listener = self._handle_process_group_call
torch.cuda.empty_cache()
def clear_traces(self) -> None:
"""
Clear all the traces: new traces will be written on a clean slate
"""
self.memory_traces.clear()
def stop(self) -> None:
"""
Stop any form of tracking (removes the hooks used to monitor the model)
"""
for h in self._hooks:
h.remove()
self._hooks.clear()
self._previous_module_name = None
self._memory_pre_forward = 0
self._last_all_gather_memory = 0
self._cumul_all_gather_memory.clear()
@property
def forward_traces(self) -> List[LayerMemoryTrace]:
"""
Get the part of the traces which corresponds to the forward pass
"""
return [t for t in self.memory_traces if t.is_forward]
@property
def backward_traces(self) -> List[LayerMemoryTrace]:
"""
Get the part of the traces which corresponds to the backward pass
"""
return [t for t in self.memory_traces if not t.is_forward]
@property
def max_memory_allocated(self) -> int:
"""
Peak memory allocated during the forward/backward pass
"""
return max(t.allocated for t in self.memory_traces)
@property
def max_memory_cached(self) -> int:
"""
Peak memory cached during the forward/backward pass
"""
return max(t.reserved for t in self.memory_traces)
@property
def summary(self) -> LayerwiseMemoryTrackerSummary:
"""
A quick summary of interesting statistics on the memory usage
during the forward/backward pass
"""
total_diff = sum(t.event.memory_diff for t in self.forward_traces) # type: ignore
total_act = sum(t.event.memory_activations for t in self.forward_traces)
top_act_producers = self.top_forward_activation_producers(top=10)
return LayerwiseMemoryTrackerSummary(
max_memory_allocated=self.max_memory_allocated,
max_memory_cached=self.max_memory_cached,
total_activation_allocations=total_act,
total_forward_allocations=total_diff,
top_forward_activation_producers=top_act_producers,
)
def top_forward_activation_producers(self, top: int = 10) -> List[LayerMemoryTrace]:
"""
What are the top activation producers during the forward pass
"""
return sorted(self.forward_traces, key=lambda a: a.event.memory_activations, reverse=True)[:top]
def show_plots(self, figsize: Tuple[int, int] = (16, 20), capture: bool = False) -> Optional[Any]:
"""
Show useful memory plots. Use "capture=True" to return an image
rather than displaying the plots.
"""
return compare_memory_traces_in_plot({"run": self.memory_traces}, figsize=figsize, capture=capture)
def save_traces(self, path: str) -> None:
"""
Save the traces in a JSON file
"""
import json
with open(path, "w") as f:
json_traces = [t.to_dict() for t in self.memory_traces]
json.dump({"traces": json_traces}, f)
@classmethod
def load(cls, path: str) -> "LayerwiseMemoryTracker":
import json
out = cls()
with open(path, "r") as f:
traces = json.load(f)["traces"]
out.memory_traces = [LayerMemoryTrace.from_dict(t) for t in traces]
return out
def _create_pre_forward_hook(self, name: str) -> Callable:
def _pre_forward_hook(module: nn.Module, inputs: Any) -> None:
torch.cuda.synchronize()
allocated, reserved = self._capture_memory()
self._previous_module_name = name
self._memory_pre_forward = allocated
if isinstance(module, FullyShardedDataParallel):
self._cumul_all_gather_memory.append(0)
return _pre_forward_hook
def _handle_process_group_call(self, event: ProcessGroupTrackingEvent, *args: Sequence[Any]) -> None:
torch.cuda.synchronize()
if event == ProcessGroupTrackingEvent.allgather:
outputs, inputs = args
output_size = self._get_module_output_size(outputs)
self._last_all_gather_memory += output_size
if self._cumul_all_gather_memory:
self._cumul_all_gather_memory[-1] += output_size
def _create_post_forward_hook(self, name: str) -> Callable:
def _post_forward_hook(
module: nn.Module, inputs: Sequence[torch.Tensor], outputs: Sequence[torch.Tensor]
) -> None:
torch.cuda.synchronize()
if isinstance(module, FullyShardedDataParallel):
self._cumul_all_gather_memory.pop()
# Only if it is a leaf module
if name == self._previous_module_name:
allocated, reserved = self._capture_memory()
self._traced_module_names.add(name)
# Get the memory allocated for output activations
ys = self._filter_allocated_output(inputs, outputs)
activations = sum(self._get_module_output_size(y) for y in ys)
# Compute the memory diff + memory taken by the activations
self.memory_traces.append(
LayerMemoryTrace(
module_name=name,
module_params=self._get_parameter_size(module),
allocated=allocated,
reserved=reserved,
is_forward=True,
all_gathered=self._last_all_gather_memory,
cumul_all_gathered=sum(self._cumul_all_gather_memory),
event=TraceForwardEvent(
memory_diff=allocated - self._memory_pre_forward, memory_activations=activations,
),
)
)
self._last_all_gather_memory = 0
# Clean previous forward call values
self._previous_module_name = None
self._memory_pre_forward = 0
return _post_forward_hook
def _create_backward_hook(self, name: str) -> Callable:
def _backward_hook(module: nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor) -> None:
torch.cuda.synchronize()
if name not in self._traced_module_names:
return
ys = self._filter_allocated_output(grad_input, grad_output)
memory = sum(self._get_module_output_size(y) for y in ys)
allocated, reserved = self._capture_memory()
self.memory_traces.append(
LayerMemoryTrace(
module_name=name,
module_params=self._get_parameter_size(module),
allocated=allocated,
reserved=reserved,
is_forward=False,
all_gathered=self._last_all_gather_memory,
cumul_all_gathered=0,
event=TraceBackwardEvent(memory_activations=memory),
)
)
# Cleaning accumulated values since last call
self._last_all_gather_memory = 0
return _backward_hook
@staticmethod
def _capture_memory() -> Tuple[int, int]:
torch.cuda.synchronize()
allocated_mb = torch.cuda.memory_allocated()
reserved_mb = torch.cuda.memory_reserved() # type: ignore
return allocated_mb, reserved_mb
@classmethod
def _get_parameter_size(cls, module: nn.Module) -> int:
return sum(p.numel() * cls._get_dtype_size(p) for p in module.parameters())
@classmethod
def _get_module_output_size(cls, xs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> int:
"""
Return the minimum memory requirement to store the tensors
provided as parameters
"""
if isinstance(xs, torch.Tensor):
x = xs
p = cls._get_dtype_size(x)
for d in x.shape:
p *= d
return p
elif isinstance(xs, tuple) or isinstance(xs, list):
return sum(cls._get_module_output_size(x) for x in xs)
return 0
@classmethod
def _get_dtype_size(cls, x: torch.Tensor) -> int:
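        # Size heuristic: 2 bytes per element for float16 tensors, otherwise assume 4 bytes per element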
return 2 if x.dtype == torch.float16 else 4
@classmethod
def _filter_allocated_output(
cls, inputs: Union[torch.Tensor, Sequence[torch.Tensor]], outputs: Union[torch.Tensor, Sequence[torch.Tensor]]
) -> List[torch.Tensor]:
"""
Only return the outputs that are allocated and not views, reshape
or stride of the inputs
"""
xs = cls._collect_tensors(inputs)
ys = cls._collect_tensors(outputs)
return [y for y in ys if all(not cls._is_same_storage(x, y) for x in xs)]
@staticmethod
def _is_same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
"""
Indicate if x and y share the same storage, meaning that one of them
is a view, reshape or stride of the other or from a common tensor
"""
return x.storage().data_ptr() == y.storage().data_ptr() # type: ignore
@staticmethod
def _collect_tensors(module_io_tensors: Union[torch.Tensor, Sequence[torch.Tensor]]) -> List[torch.Tensor]:
"""
Extract the tensors out of the provided input or output of a nn.Module
"""
tensors = []
to_visit = [module_io_tensors]
while to_visit:
x = to_visit.pop()
if isinstance(x, torch.Tensor):
tensors.append(x)
elif isinstance(x, tuple) or isinstance(x, list):
                to_visit.extend(x)
return tensors
def find_best_reset_points(activation_sizes: List[int], num_checkpoints: int) -> Tuple[int, List[int]]:
"""
Assuming constant memory requirement from the model, its gradients
and the associated optimizer state (realistic for small models
or models that are sharded enough to be considered small), this
function computes the ideal placement for the checkpoints by
returning the limits at which we should reset memory.
"""
n = len(activation_sizes)
@lru_cache(maxsize=None)
def visit(pos: int, remaining: int) -> Tuple[int, List[int]]:
if pos == n:
return 0, []
if remaining == 0:
return sum(activation_sizes[pos:]), []
min_val = float("inf")
allocation = []
current_chunk = 0
for curr_pos in range(pos, n):
current_chunk += activation_sizes[curr_pos]
sub_result, sub_alloc = visit(curr_pos + 1, remaining - 1)
result = max(current_chunk, sub_result)
if result < min_val:
min_val = result
allocation = list(sub_alloc)
allocation.append(curr_pos + 1)
return int(min_val), allocation
best_score, best_allocation = visit(0, num_checkpoints)
return best_score, best_allocation[::-1]
@dataclass
class SuggestedCheckpoints:
max_memory: int
split_modules: List[str]
all_modules: List[str]
def suggest_checkpoint_location(
traces: List[LayerMemoryTrace], num_checkpoints: int, num_skipped_layers: int = 0
) -> SuggestedCheckpoints:
"""
Given a trace of a model, collected with or without checkpoint,
return the best places to insert a reset of activation memory.
The names of the returned modules are the boundaries of the
suggested checkpoint_wrapper wrappings
"""
# From the traces, extract how much activation memory
# is generated during the forward pass, layer by layer
visited = set()
modules, allocations = [], []
for t in traces:
if t.is_forward:
name = t.module_name
memory = t.event.memory_activations
if name not in visited:
visited.add(name)
modules.append(name)
allocations.append(memory)
# To skip some layers where we do not want activations
if num_skipped_layers:
modules = modules[num_skipped_layers:]
allocations = allocations[num_skipped_layers:]
# Compute the best positions to reset the memory
max_memory, reset_indices = find_best_reset_points(allocations, num_checkpoints=num_checkpoints)
# Then map it back to module names
return SuggestedCheckpoints(
max_memory=max_memory, split_modules=[modules[i] for i in reset_indices], all_modules=modules,
)
def _assert_visualisation_library_installed() -> None:
try:
import PIL # NOQA
import matplotlib # NOQA
except ImportError:
install_matplotlib = "pip install matplotlib"
install_pil = "pip install Pillow"
error_message = "Visualizing memory plots requires matplotlib and Pillow installed"
assert False, f"{error_message}: {install_matplotlib}, {install_pil}"
def compare_memory_traces_in_plot(
memory_traces_by_job: Dict[str, List[LayerMemoryTrace]], figsize: Tuple[int, int] = (16, 20), capture: bool = False,
) -> Optional[Any]:
"""
Create a plot of the memory allocation over time during the forward/backward
passes, with a breakdown of the memory used for activation VS parameters
"""
_assert_visualisation_library_installed()
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=figsize, ncols=2, nrows=3)
graph_creator = _MemoryGraphCreator()
ax[0, 0].set_title("memory allocated")
for job_name, memory_traces in memory_traces_by_job.items():
graph_creator.allocated_memory_curve(ax[0, 0], job_name, memory_traces)
if len(memory_traces_by_job) > 1:
ax[0, 0].legend()
ax[0, 1].set_title("memory reserved")
for job_name, memory_traces in memory_traces_by_job.items():
graph_creator.reserved_memory_curve(ax[0, 1], job_name, memory_traces)
if len(memory_traces_by_job) > 1:
ax[0, 1].legend()
ax[1, 0].set_title("activation allocations")
for job_name, memory_traces in memory_traces_by_job.items():
graph_creator.activation_allocations(ax[1, 0], job_name, memory_traces)
if len(memory_traces_by_job) > 1:
ax[1, 0].legend()
ax[1, 1].set_title("cumulative forward activations")
for job_name, memory_traces in memory_traces_by_job.items():
graph_creator.cumulative_activations(ax[1, 1], job_name, memory_traces)
if len(memory_traces_by_job) > 1:
ax[1, 1].legend()
ax[2, 0].set_title("all gathered memory")
for job_name, memory_traces in memory_traces_by_job.items():
graph_creator.all_gathered_memory(ax[2, 0], job_name, memory_traces)
if len(memory_traces_by_job) > 1:
ax[2, 0].legend()
ax[2, 1].set_title("parameter memory")
for job_name, memory_traces in memory_traces_by_job.items():
graph_creator.module_parameters(ax[2, 1], job_name, memory_traces)
if len(memory_traces_by_job) > 1:
ax[2, 1].legend()
if not capture:
plt.show()
return None
else:
return matplotlib_figure_to_image(fig)
class _MemoryGraphCreator:
"""
Helper class to create graphs to display memory
"""
def __init__(self) -> None:
import matplotlib
self.font = {
"family": matplotlib.rcParams["font.family"],
"weight": "normal",
"size": 12,
}
def allocated_memory_curve(self, ax: Any, job_name: str, memory_traces: List[LayerMemoryTrace]) -> None:
allocated_memory = [t.allocated for t in memory_traces]
x, y_forward, y_backward = self._split_forward_backward(memory_traces, allocated_memory)
ax.plot(x, y_forward, x, y_backward, label=job_name)
max_index = np.argmax(allocated_memory)
max_trace = memory_traces[max_index] # type: ignore
max_module = ".".join([n for n in max_trace.module_name.split(".") if not n.startswith("_")])
max_phase = "fwd" if max_trace.is_forward else "bwd"
ax.set_ylim([None, max_trace.allocated * 1.1])
x_text, y_text = max(0, max_index * 0.8), max_trace.allocated * 1.04 # type: ignore
ax.text(x_text, y_text, f"{max_module} ({max_phase})", fontdict=self.font)
self._y_axis_in_gigabytes(ax)
def reserved_memory_curve(self, ax: Any, job_name: str, memory_traces: List[LayerMemoryTrace]) -> None:
reserved_memory = [t.reserved for t in memory_traces]
x, y_forward, y_backward = self._split_forward_backward(memory_traces, reserved_memory)
ax.plot(x, y_forward, x, y_backward, label=job_name)
self._y_axis_in_gigabytes(ax)
def activation_allocations(self, ax: Any, job_name: str, memory_traces: List[LayerMemoryTrace]) -> None:
event_allocations = [t.event.memory_activations for t in memory_traces]
x, y_forward, y_backward = self._split_forward_backward(memory_traces, event_allocations)
ax.plot(x, y_forward, x, y_backward, label=job_name)
self._y_axis_in_gigabytes(ax)
def cumulative_activations(self, ax: Any, job_name: str, memory_traces: List[LayerMemoryTrace]) -> None:
event_allocations = [t.event.memory_activations for t in memory_traces]
x, y_forward, y_backward = self._split_forward_backward(memory_traces, event_allocations)
cumulative_forward_activations = np.cumsum(y_forward)
ax.plot(x, cumulative_forward_activations, label=job_name)
self._y_axis_in_gigabytes(ax)
def all_gathered_memory(self, ax: Any, job_name: str, memory_traces: List[LayerMemoryTrace]) -> None:
# Plot the all_gathered and cumulative all_gathered memory
gathered_memory = [t.all_gathered for t in memory_traces]
cumul_gathered_memory = [t.cumul_all_gathered for t in memory_traces]
x, y_forward, y_backward = self._split_forward_backward(memory_traces, gathered_memory)
ax.plot(x, y_forward, x, y_backward, label=job_name)
ax.plot(x, cumul_gathered_memory, label=job_name)
self._y_axis_in_gigabytes(ax)
# Adding the name of the layer with max cumulative all_gathered memory
max_index = np.argmax(cumul_gathered_memory)
max_trace = memory_traces[max_index] # type: ignore
max_module = ".".join([n for n in max_trace.module_name.split(".") if not n.startswith("_")])
ax.set_ylim([None, max_trace.cumul_all_gathered * 1.1])
x_text, y_text = max(0, max_index * 0.8), max_trace.cumul_all_gathered * 1.04 # type: ignore
ax.text(x_text, y_text, f"{max_module} (fwd)", fontdict=self.font)
def module_parameters(self, ax: Any, job_name: str, memory_traces: List[LayerMemoryTrace]) -> None:
module_parameters = [t.module_params for t in memory_traces]
x, y_forward, y_backward = self._split_forward_backward(memory_traces, module_parameters)
ax.plot(x, y_forward, x, y_backward, label=job_name)
self._y_axis_in_gigabytes(ax)
@staticmethod
def _y_axis_in_gigabytes(ax: Any) -> None:
ax.ticklabel_format(axis="y", style="sci", scilimits=(9, 9))
@classmethod
def _split_forward_backward(
cls, memory_traces: List[LayerMemoryTrace], values: List[Any]
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
x_values = np.array(list(range(len(memory_traces))))
mask_forwards, mask_backwards = cls._mask_forward_backward(memory_traces)
return (
x_values,
np.ma.masked_where(mask_backwards, values), # type: ignore
np.ma.masked_where(mask_forwards, values), # type: ignore
)
@classmethod
def _mask_forward_backward(cls, memory_traces: List[LayerMemoryTrace]) -> Tuple[np.ndarray, np.ndarray]:
mask_forwards = np.array([t.is_forward for t in memory_traces])
return mask_forwards, ~mask_forwards
@contextmanager
def null_context() -> Iterator[None]:
yield
def matplotlib_figure_to_image(fig: Any) -> Any:
"""
Convert a matplotlib figure to an image in RGB format, for instance
to save it on disk
"""
import io
from PIL import Image
buf = io.BytesIO()
fig.savefig(buf)
buf.seek(0)
return Image.open(buf).convert("RGB")
@@ -226,7 +226,7 @@ def spawn_for_all_world_sizes(test_func: Callable, world_sizes: List[int] = get_
def worker_process(
rank: int, world_size: int, filename: str, filename_rpc: str, func: Callable, args: Any, error_queue: Any
) -> None:
"""Main function for unit tests launced with torch_spawn"""
"""Main function for unit tests launched with torch_spawn"""
if not dist_init(rank, world_size, filename, filename_rpc):
logging.warning("failed initializing torch distributed")
@@ -46,3 +46,4 @@ tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
tests/experimental/nn/test_offload.py
tests/experimental/nn/test_auto_shard.py
tests/experimental/optim/test_dynamic_loss_scaler.py
tests/experimental/tooling/test_layer_memory_tracker.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from fairscale.experimental.tooling.layer_memory_tracker import (
LayerwiseMemoryTracker,
ProcessGroupTracker,
find_best_reset_points,
)
from fairscale.nn import FullyShardedDataParallel
from fairscale.utils.testing import GPT2, dist_init, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx
@skip_if_no_cuda()
def test_memory_tracking_traces():
"""
Minimal test case to check that we can collect memory traces
outside of the context of distributed training (DDP or FSDP)
"""
# Create a model with a hierarchy of modules
torch.manual_seed(0)
model = nn.Sequential(
nn.Sequential(
nn.Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1), bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=False),
nn.AdaptiveAvgPool2d(output_size=(1, 1)),
),
nn.Flatten(start_dim=1),
nn.Sequential(nn.Linear(64, 2), nn.ReLU(inplace=True)),
).cuda()
# Track a fake forward / backward
tracker = LayerwiseMemoryTracker()
tracker.monitor(model)
x = torch.randn(size=(2, 3, 224, 224)).cuda()
target = torch.LongTensor([0, 1]).cuda()
criterion = nn.CrossEntropyLoss()
criterion(model(x), target).backward()
# Verify that only leaf modules are tracked and that the order
# of the traces is consistent with backward/forward
tracked_names = [t.module_name for t in tracker.memory_traces]
expected_names = ["0.0", "0.1", "0.2", "0.3", "1", "2.0", "2.1"]
assert set(expected_names) == set(tracked_names)
assert tracked_names == (expected_names + expected_names[::-1])
# Verify that memory tracking for ReLU is sound
assert (
2 * 64 * 224 * 224 * 4 == tracker.forward_traces[2].event.memory_activations
), "ReLU(inplace=False) should allocate activations"
assert 0 == tracker.forward_traces[6].event.memory_activations, "ReLU(inplace=True) should NOT allocate activations"
# Verify that overall memory tracking is sound
summary = tracker.summary
assert summary.total_forward_allocations >= summary.total_activation_allocations
# Verify that the identification of top memory activation producer works:
# these are the first layers, all allocating (2, 64, 224, 224) feature maps
top_act_producers = summary.top_forward_activation_producers[:3]
assert "0.0" == top_act_producers[0].module_name
assert "0.1" == top_act_producers[1].module_name
assert "0.2" == top_act_producers[2].module_name
assert 3 * 3 * 64 * 3 * 4 == top_act_producers[0].module_params
assert 64 * 2 * 4 == top_act_producers[1].module_params
assert 0 == top_act_producers[2].module_params
for trace in top_act_producers:
assert 2 * 64 * 224 * 224 * 4 == trace.event.memory_activations
@skip_if_no_cuda
def test_memory_tracking_nlp_model():
"""
Check that we can collect memory traces of a realistic model
outside of the context of distributed training (DDP or FSDP)
"""
    BATCH_SIZE = 10
INPUT_DIM = 16
model = GPT2(
embed_dim=256, num_heads=2, num_layers=6, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
).cuda()
tracker = LayerwiseMemoryTracker()
tracker.monitor(model)
    input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).cuda()
output = model(input_tensor)
output.sum().backward()
    assert len(tracker.memory_traces) > 0, "failed to collect memory traces"
assert len(tracker.forward_traces) > 0, "failed to collect forward memory traces"
assert len(tracker.backward_traces) > 0, "failed to collect backward memory traces"
assert tracker.summary.total_activation_allocations == 12462080
@skip_if_single_gpu
def test_memory_tracking_ddp():
"""
Check that we can collect memory traces of a simplistic model
in the context of DDP distributed training
"""
with temp_files_ctx(num=2) as sync_files:
world_size = 2
mp.spawn(
_layer_memory_tracking_ddp_worker, (sync_files, world_size), nprocs=world_size,
)
def _layer_memory_tracking_ddp_worker(gpu_id: int, sync_files: Tuple[str, str], world_size: int):
dist_init(world_size=world_size, rank=gpu_id, filename=sync_files[0], filename_rpc=sync_files[1])
torch.backends.cudnn.deterministic = True
# Create different inputs on each GPU
batch_size = 16
torch.manual_seed(gpu_id)
fake_inputs = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
fake_targets = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
fake_criterion = nn.MSELoss()
# Create a simple model
torch.manual_seed(0)
torch.cuda.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 10),)
model = model.cuda(gpu_id)
ddp_model = DistributedDataParallel(model, device_ids=[gpu_id])
# Track the model on a forward / backward pass
tracker = LayerwiseMemoryTracker()
tracker.monitor(ddp_model)
fake_criterion(ddp_model(fake_inputs), fake_targets).backward()
tracker.stop()
# Check the overall structure of the collected traces
forward_names = [f"module.{i}" for i in range(5)]
backward_names = [f"module.{i}" for i in reversed(range(5))]
trace_names = [t.module_name for t in tracker.memory_traces]
assert trace_names == (forward_names + backward_names)
@skip_if_single_gpu
def test_memory_tracking_fsdp():
"""
Check that we can collect memory traces of a simplistic model
in the context of FSDP distributed training
"""
with temp_files_ctx(num=2) as sync_files:
world_size = 2
mp.spawn(
_layer_memory_tracking_fsdp_worker, (sync_files, world_size), nprocs=world_size,
)
def _layer_memory_tracking_fsdp_worker(gpu_id: int, sync_files: Tuple[str, str], world_size: int):
dist_init(world_size=world_size, rank=gpu_id, filename=sync_files[0], filename_rpc=sync_files[1])
torch.backends.cudnn.deterministic = True
# Create different inputs on each GPU
batch_size = 16
torch.manual_seed(gpu_id)
fake_inputs = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
fake_targets = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
fake_criterion = nn.MSELoss()
# Create a global group and a tracker around it
group = dist.new_group()
group = ProcessGroupTracker(group)
# Create a simple model
torch.manual_seed(0)
torch.cuda.manual_seed(0)
model = nn.Sequential(
nn.Linear(10, 10).cuda(gpu_id),
nn.ReLU(),
FullyShardedDataParallel(nn.Linear(10, 10).cuda(gpu_id), flatten_parameters=False, process_group=group,),
nn.ReLU(),
FullyShardedDataParallel(nn.Linear(10, 10).cuda(gpu_id), flatten_parameters=True, process_group=group,),
)
model = model.cuda(gpu_id)
dist_model = FullyShardedDataParallel(model, flatten_parameters=False, process_group=group)
# Track the model on a forward / backward pass
tracker = LayerwiseMemoryTracker()
tracker.monitor(dist_model)
fake_criterion(dist_model(fake_inputs), fake_targets).backward()
tracker.stop()
    # Check the results of all-gather tracking (a feature specific to FSDP)
all_gathered_traces = [
(t.module_name, t.all_gathered, t.cumul_all_gathered) for t in tracker.memory_traces if t.all_gathered > 0
]
assert all_gathered_traces == [
("_fsdp_wrapped_module._fpw_module.0", 440, 440),
("_fsdp_wrapped_module._fpw_module.2._fsdp_wrapped_module._fpw_module", 440, 880),
("_fsdp_wrapped_module._fpw_module.4._fsdp_wrapped_module._fpw_module", 440, 880),
("_fsdp_wrapped_module._fpw_module.4._fsdp_wrapped_module._fpw_module", 440, 0),
("_fsdp_wrapped_module._fpw_module.2._fsdp_wrapped_module._fpw_module", 440, 0),
], all_gathered_traces
def test_find_best_reset_points():
"""
Verify that the reset points are correctly computed
"""
activations = [10, 8, 8, 9, 7, 7, 5, 4, 4]
# Check boundary condition: no checkpoints
memory, split_points = find_best_reset_points(activations, num_checkpoints=0)
assert memory == sum(activations)
# Check boundary condition: checkpoints everywhere
memory, split_points = find_best_reset_points(activations, num_checkpoints=len(activations))
assert memory == max(activations)
# Check one checkpoint allocation
memory, split_points = find_best_reset_points(activations, num_checkpoints=1)
assert memory == 35
assert split_points == [4]
assert sum(activations[: split_points[0]]) == 35
assert sum(activations[split_points[0] :]) == 27
# Check multiple checkpoint allocation
memory, split_points = find_best_reset_points(activations, num_checkpoints=2)
assert memory == 24
delimiters = [0] + split_points + [len(activations)]
splits_memory = [sum(activations[i:j]) for i, j in zip(delimiters[:-1], delimiters[1:])]
assert max(splits_memory) == memory