Unverified commit ad92220c authored by Quentin Duval, committed by GitHub

[feat] layer memory tracking (#808)



* [feat] layer memory tracking

* [feat] layer memory tracking (add tests in CI)

* [feat] layer memory tracking: doc typos

* [feat] layer memory tracking: mypy fixes

* [feat] layer memory tracking: fixes for FSDP all gather tracking on pytorch 1.9 and above

* [feat] layer memory tracking: lint

* [feat] layer memory tracking: mypy
Co-authored-by: QuentinDuval <QuentinDuval@users.noreply.github.com>
parent 51e43b61
...@@ -55,6 +55,7 @@ modules and easy to use APIs.
tutorials/offload_model
tutorials/adascale
tutorials/pipe
tutorials/layer_memory_tracking
|
|
...
Efficient memory usage using Activation Checkpointing
=====================================================

Adapted from `torch.utils.checkpoint`, this is a friendlier wrapper for performing activation checkpointing.
Compared to the PyTorch version, this version wraps a `nn.Module` and allows for all subsequent calls to be
checkpointed.
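A minimal usage sketch, assuming `checkpoint_wrapper` is exported from `fairscale.nn` (as used later in the layer memory tracking tutorial):

.. code-block:: python

    import torch.nn as nn
    from fairscale.nn import checkpoint_wrapper  # assumption: exported at this path

    # Wrap a sub-module so that all subsequent forward calls are checkpointed:
    # activations are recomputed during the backward pass instead of being stored.
    block = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
    block = checkpoint_wrapper(block)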
...
Tooling to diagnose and fix memory problems
===========================================

FairScale comes with some experimental tooling to help track and visualize memory usage, and to suggest fixes for memory issues occurring during the forward/backward pass of your models.

Visualizing the memory profile
------------------------------

To track and visualize the memory profile of a model, you can use the `LayerwiseMemoryTracker`:

.. code-block:: python

    from fairscale.experimental.tooling.layer_memory_tracker import LayerwiseMemoryTracker
    import torch
    import torchvision.models

    # Create a model
    model = torchvision.models.resnet50().cuda()
    criterion = torch.nn.CrossEntropyLoss()

    # Create some dummy inputs
    batch_size = 16
    x = torch.randn(size=(batch_size, 3, 224, 224)).cuda()
    y = torch.tensor(list(range(batch_size)), dtype=torch.int64).cuda()

    # Start monitoring the model
    tracker = LayerwiseMemoryTracker()
    tracker.monitor(model)

    # Do a forward/backward with dummy inputs
    criterion(model(x), y).backward()

    # Stop monitoring the model
    tracker.stop()

    # Show some useful default plots
    tracker.show_plots()
The resulting graphs include:

- a graph of the memory profile (memory allocated and reserved) during the forward/backward

.. image:: _static/img/layer_memory_profiles.png
    :width: 1000px
    :align: center

- a graph of the amount of memory allocated for activations during the forward/backward

.. image:: _static/img/layer_memory_activations.png
    :width: 1000px
    :align: center

- a graph of the amount of memory used for parameters by each of the layers traversed during the forward/backward

.. image:: _static/img/layer_memory_parameters.png
    :width: 500px
    :align: center

In all these graphs:

- the blue part of the curve corresponds to the forward pass and the orange part to the backward pass
- the X axis only orders the computational steps (it does not represent the index of the layer in the model)
How to use those graphs?
------------------------

It is not always obvious how much memory a model will use. These graphs help you visualize:

- the main cause of memory consumption: the activation memory in the graph above
- which layers are worth sharding: those at the end of the convolutional net in the case above
- where to place activation checkpoints to reduce memory consumption

If those graphs are not useful to you, you can always use the raw data collected by the `LayerwiseMemoryTracker` instead,
or use any of the other utility functions provided in the tool:

.. code-block:: python

    # Access all raw traces / forward traces only / backward traces only
    tracker.memory_traces
    tracker.forward_traces
    tracker.backward_traces

    # Access a quick summary of the traces with information on:
    # - the peak memory usage
    # - the top layers in terms of memory consumption
    tracker.summary
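For example, here is a minimal sketch of working with the raw data, using the trace attributes exercised by the unit tests shipped with the tool (`module_name`, `event.memory_activations`, `top_forward_activation_producers`):

.. code-block:: python

    # List the layers allocating the most activation memory during the forward pass
    # (attribute names as used in the fairscale unit tests for this tool)
    for trace in tracker.summary.top_forward_activation_producers[:3]:
        print(trace.module_name, trace.event.memory_activations)

    # Or aggregate the raw forward traces directly
    total_fwd_activations = sum(t.event.memory_activations for t in tracker.forward_traces)
    print(f"forward activation allocations: {total_fwd_activations} bytes")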
Activation checkpoint suggestions
---------------------------------

In addition to visualization, the `LayerwiseMemoryTracker` traces can be used to suggest activation checkpoint
locations, which can reduce the memory consumption of the forward/backward at the cost of some extra compute:

.. code-block:: python

    from fairscale.experimental.tooling.layer_memory_tracker import suggest_checkpoint_location

    suggestion = suggest_checkpoint_location(tracker.memory_traces, num_checkpoints=0)
    print(suggestion.max_memory)  # Outputs: 1435630080

    suggestion = suggest_checkpoint_location(tracker.memory_traces, num_checkpoints=2)
    print(suggestion.max_memory)  # Outputs: 485095936
    print(suggestion.split_modules)  # Outputs: ['layer1.1.bn3', 'layer2.2.conv3']
This sample code tells us that we can reduce the memory consumption due to activations from 1.4 GB to around 500 MB by
checkpointing activations at the locations `layer1.1.bn3` and `layer2.2.conv3`.

These locations are first guesses and might not always be practical given how the model code is structured. In the case of a
torchvision resnet, we can adapt those locations by checkpointing around layer1 and layer2:

.. code-block:: python

    from fairscale.nn import checkpoint_wrapper

    model = torchvision.models.resnet50().cuda()
    model.layer1 = checkpoint_wrapper(model.layer1)
    model.layer3 = checkpoint_wrapper(torch.nn.Sequential(model.layer2, model.layer3))
    model.layer2 = torch.nn.Identity()
This leads to the following memory profile, saving around 400 MB of activation memory at the cost of more compute:

.. image:: _static/img/layer_memory_profile_optimized.png
    :width: 500px
    :align: center
Dedicated features for FSDP distributed training
------------------------------------------------

When training a big model with `FullyShardedDataParallel`, you can use the `LayerwiseMemoryTracker` to track the
amount of memory exchanged by FSDP to consolidate sharded layers:

.. code-block:: python

    from fairscale.nn import FullyShardedDataParallel as FSDP
    from fairscale.experimental.tooling.layer_memory_tracker import ProcessGroupTracker

    # Create a process group for FSDP
    group = torch.distributed.new_group()
    group = ProcessGroupTracker(group)

    # Create a FSDP model
    model = torchvision.models.resnet50().cuda()
    model.layer1 = FSDP(model.layer1, process_group=group)
    model.layer2 = FSDP(model.layer2, process_group=group)
    model.layer3 = FSDP(model.layer3, process_group=group)
    model.layer4 = FSDP(model.layer4, process_group=group)
    model = FSDP(model, process_group=group)
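Monitoring then works exactly as before; a minimal sketch, assuming the process group setup above and reusing the dummy inputs `x`, `y` and `criterion` from the first example:

.. code-block:: python

    # Same tracking workflow as in the non-distributed example above
    tracker = LayerwiseMemoryTracker()
    tracker.monitor(model)
    criterion(model(x), y).backward()
    tracker.stop()
    tracker.show_plots()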
Now, the `LayerwiseMemoryTracker` will provide an additional graph where we can see:

- the memory spikes (in blue for forward, in orange for backward) of the `all_gather` calls
- an estimation (in green) of cumulative parameter memory (only available for the forward pass)

.. image:: _static/img/all_gathered_memory.png
    :width: 500px
    :align: center
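If you prefer raw numbers over the plot, the same information is available on the traces themselves; a minimal sketch, using the `all_gathered` and `cumul_all_gathered` attributes checked by the unit tests:

.. code-block:: python

    # Per-layer all-gather volume (bytes) and cumulative all-gathered memory,
    # as exposed on the traces and verified in the accompanying unit tests
    all_gather_info = [
        (t.module_name, t.all_gathered, t.cumul_all_gathered)
        for t in tracker.memory_traces
        if t.all_gathered > 0
    ]
    print(all_gather_info)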
Limitations
-----------

The `LayerwiseMemoryTracker` has a number of limitations to be aware of:

1. It only works for GPU models: the model cannot sit on the CPU
2. Some GPU memory might not be tracked by PyTorch (for example some NCCL buffers) and therefore will not be tracked by this tooling either
3. Beyond allocated and cached memory, which are reported by PyTorch, the results are based on heuristics and might miss some memory in some cases
4. Some features (such as cumulative all-gathered memory for FSDP) do not work in the backward pass
...@@ -226,7 +226,7 @@ def spawn_for_all_world_sizes(test_func: Callable, world_sizes: List[int] = get_
def worker_process(
    rank: int, world_size: int, filename: str, filename_rpc: str, func: Callable, args: Any, error_queue: Any
) -> None:
    """Main function for unit tests launched with torch_spawn"""
    if not dist_init(rank, world_size, filename, filename_rpc):
        logging.warning("failed initializing torch distributed")
...
...@@ -46,3 +46,4 @@ tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
tests/experimental/nn/test_offload.py
tests/experimental/nn/test_auto_shard.py
tests/experimental/optim/test_dynamic_loss_scaler.py
tests/experimental/tooling/test_layer_memory_tracker.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

from fairscale.experimental.tooling.layer_memory_tracker import (
    LayerwiseMemoryTracker,
    ProcessGroupTracker,
    find_best_reset_points,
)
from fairscale.nn import FullyShardedDataParallel
from fairscale.utils.testing import GPT2, dist_init, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx

@skip_if_no_cuda()
def test_memory_tracking_traces():
    """
    Minimal test case to check that we can collect memory traces
    outside of the context of distributed training (DDP or FSDP)
    """
    # Create a model with a hierarchy of modules
    torch.manual_seed(0)
    model = nn.Sequential(
        nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=False),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        ),
        nn.Flatten(start_dim=1),
        nn.Sequential(nn.Linear(64, 2), nn.ReLU(inplace=True)),
    ).cuda()

    # Track a fake forward / backward
    tracker = LayerwiseMemoryTracker()
    tracker.monitor(model)
    x = torch.randn(size=(2, 3, 224, 224)).cuda()
    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    criterion(model(x), target).backward()

    # Verify that only leaf modules are tracked and that the order
    # of the traces is consistent with forward then backward
    tracked_names = [t.module_name for t in tracker.memory_traces]
    expected_names = ["0.0", "0.1", "0.2", "0.3", "1", "2.0", "2.1"]
    assert set(expected_names) == set(tracked_names)
    assert tracked_names == (expected_names + expected_names[::-1])

    # Verify that memory tracking for ReLU is sound
    assert (
        2 * 64 * 224 * 224 * 4 == tracker.forward_traces[2].event.memory_activations
    ), "ReLU(inplace=False) should allocate activations"
    assert 0 == tracker.forward_traces[6].event.memory_activations, "ReLU(inplace=True) should NOT allocate activations"

    # Verify that overall memory tracking is sound
    summary = tracker.summary
    assert summary.total_forward_allocations >= summary.total_activation_allocations

    # Verify that the identification of top memory activation producers works:
    # these are the first layers, all allocating (2, 64, 224, 224) feature maps
    top_act_producers = summary.top_forward_activation_producers[:3]
    assert "0.0" == top_act_producers[0].module_name
    assert "0.1" == top_act_producers[1].module_name
    assert "0.2" == top_act_producers[2].module_name
    assert 3 * 3 * 64 * 3 * 4 == top_act_producers[0].module_params
    assert 64 * 2 * 4 == top_act_producers[1].module_params
    assert 0 == top_act_producers[2].module_params
    for trace in top_act_producers:
        assert 2 * 64 * 224 * 224 * 4 == trace.event.memory_activations

@skip_if_no_cuda
def test_memory_tracking_nlp_model():
    """
    Check that we can collect memory traces of a realistic model
    outside of the context of distributed training (DDP or FSDP)
    """
    BATCH_SIZE = 10
    INPUT_DIM = 16
    model = GPT2(
        embed_dim=256, num_heads=2, num_layers=6, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
    ).cuda()

    tracker = LayerwiseMemoryTracker()
    tracker.monitor(model)

    input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).cuda()
    output = model(input_tensor)
    output.sum().backward()

    assert len(tracker.memory_traces) > 0, "failed to collect memory traces"
    assert len(tracker.forward_traces) > 0, "failed to collect forward memory traces"
    assert len(tracker.backward_traces) > 0, "failed to collect backward memory traces"
    assert tracker.summary.total_activation_allocations == 12462080

@skip_if_single_gpu
def test_memory_tracking_ddp():
    """
    Check that we can collect memory traces of a simplistic model
    in the context of DDP distributed training
    """
    with temp_files_ctx(num=2) as sync_files:
        world_size = 2
        mp.spawn(
            _layer_memory_tracking_ddp_worker, (sync_files, world_size), nprocs=world_size,
        )

def _layer_memory_tracking_ddp_worker(gpu_id: int, sync_files: Tuple[str, str], world_size: int):
    dist_init(world_size=world_size, rank=gpu_id, filename=sync_files[0], filename_rpc=sync_files[1])
    torch.backends.cudnn.deterministic = True

    # Create different inputs on each GPU
    batch_size = 16
    torch.manual_seed(gpu_id)
    fake_inputs = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
    fake_targets = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
    fake_criterion = nn.MSELoss()

    # Create a simple model
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 10),)
    model = model.cuda(gpu_id)
    ddp_model = DistributedDataParallel(model, device_ids=[gpu_id])

    # Track the model on a forward / backward pass
    tracker = LayerwiseMemoryTracker()
    tracker.monitor(ddp_model)
    fake_criterion(ddp_model(fake_inputs), fake_targets).backward()
    tracker.stop()

    # Check the overall structure of the collected traces
    forward_names = [f"module.{i}" for i in range(5)]
    backward_names = [f"module.{i}" for i in reversed(range(5))]
    trace_names = [t.module_name for t in tracker.memory_traces]
    assert trace_names == (forward_names + backward_names)

@skip_if_single_gpu
def test_memory_tracking_fsdp():
    """
    Check that we can collect memory traces of a simplistic model
    in the context of FSDP distributed training
    """
    with temp_files_ctx(num=2) as sync_files:
        world_size = 2
        mp.spawn(
            _layer_memory_tracking_fsdp_worker, (sync_files, world_size), nprocs=world_size,
        )

def _layer_memory_tracking_fsdp_worker(gpu_id: int, sync_files: Tuple[str, str], world_size: int):
    dist_init(world_size=world_size, rank=gpu_id, filename=sync_files[0], filename_rpc=sync_files[1])
    torch.backends.cudnn.deterministic = True

    # Create different inputs on each GPU
    batch_size = 16
    torch.manual_seed(gpu_id)
    fake_inputs = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
    fake_targets = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
    fake_criterion = nn.MSELoss()

    # Create a global group and a tracker around it
    group = dist.new_group()
    group = ProcessGroupTracker(group)

    # Create a simple model
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    model = nn.Sequential(
        nn.Linear(10, 10).cuda(gpu_id),
        nn.ReLU(),
        FullyShardedDataParallel(nn.Linear(10, 10).cuda(gpu_id), flatten_parameters=False, process_group=group,),
        nn.ReLU(),
        FullyShardedDataParallel(nn.Linear(10, 10).cuda(gpu_id), flatten_parameters=True, process_group=group,),
    )
    model = model.cuda(gpu_id)
    dist_model = FullyShardedDataParallel(model, flatten_parameters=False, process_group=group)

    # Track the model on a forward / backward pass
    tracker = LayerwiseMemoryTracker()
    tracker.monitor(dist_model)
    fake_criterion(dist_model(fake_inputs), fake_targets).backward()
    tracker.stop()

    # Check results of all-gather tracking (feature specific to FSDP)
    all_gathered_traces = [
        (t.module_name, t.all_gathered, t.cumul_all_gathered) for t in tracker.memory_traces if t.all_gathered > 0
    ]
    assert all_gathered_traces == [
        ("_fsdp_wrapped_module._fpw_module.0", 440, 440),
        ("_fsdp_wrapped_module._fpw_module.2._fsdp_wrapped_module._fpw_module", 440, 880),
        ("_fsdp_wrapped_module._fpw_module.4._fsdp_wrapped_module._fpw_module", 440, 880),
        ("_fsdp_wrapped_module._fpw_module.4._fsdp_wrapped_module._fpw_module", 440, 0),
        ("_fsdp_wrapped_module._fpw_module.2._fsdp_wrapped_module._fpw_module", 440, 0),
    ], all_gathered_traces

def test_find_best_reset_points():
    """
    Verify that the reset points are correctly computed
    """
    activations = [10, 8, 8, 9, 7, 7, 5, 4, 4]

    # Check boundary condition: no checkpoints
    memory, split_points = find_best_reset_points(activations, num_checkpoints=0)
    assert memory == sum(activations)

    # Check boundary condition: checkpoints everywhere
    memory, split_points = find_best_reset_points(activations, num_checkpoints=len(activations))
    assert memory == max(activations)

    # Check one checkpoint allocation
    memory, split_points = find_best_reset_points(activations, num_checkpoints=1)
    assert memory == 35
    assert split_points == [4]
    assert sum(activations[: split_points[0]]) == 35
    assert sum(activations[split_points[0] :]) == 27

    # Check multiple checkpoint allocation
    memory, split_points = find_best_reset_points(activations, num_checkpoints=2)
    assert memory == 24
    delimiters = [0] + split_points + [len(activations)]
    splits_memory = [sum(activations[i:j]) for i, j in zip(delimiters[:-1], delimiters[1:])]
    assert max(splits_memory) == memory