Unverified commit 8a42a8e3 authored by Min Xu, committed by GitHub

[fix] FSDP forward pass overlap between compute and all-gather (#671)



* [fix] FSDP forward pass overlap between compute and all-gather

- many thanks to @cyanguwa for the report and to @QuentinDuval for debugging it
- a new unit test is added to check for this and to ensure we detect
  issues with overlapping and with blocking CPU/GPU wait calls

* fix

* fix

* fix

* better assertion outputs

* fix format and tune all_gather mb for CI

* more tuning with non_flatten

* undo an accidental change

* tuning all gather mb and del model

* Update + fix overlapping test to use patched all_gather w/ delay (#672)

* fixing get_cycles_per_ms

* add get_smi_memory

* update the docstring
Co-authored-by: Min Xu <min.xu@acm.org>
Co-authored-by: Myle Ott <myleott@fb.com>
parent c8d32c30
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
+- FSDP: fix forward pass not overlapping compute and all-gather
 - FSDP: improved frozen weight support
 - FSDP: workaround AMP autocast cache issue with clear\_autocast\_cache flag
 - setup.py: hide CUDA extensions behind BUILD_CUDA_EXTENSIONS envvar
......
@@ -1448,13 +1448,15 @@ class FullyShardedDataParallel(nn.Module):
         if params is None:
             params = self.params
         self.has_full_params = False
-        self._streams["all_gather"].wait_stream(torch.cuda.current_stream())
+        current_stream = torch.cuda.current_stream()
         with torch.cuda.stream(self._streams["all_gather"]):
             for p in params:
                 if not p._is_sharded:  # e.g., world_size == 1
                     if self.mixed_precision:
                         self._free_fp16_param_shard([p])
                     continue
+                # Don't let PyTorch reuse this memory until all work in the current
+                # stream is complete.
+                p._full_param_padded.record_stream(current_stream)
                 # There may be external references to the Tensor Storage that we
                 # can't modify, such as references that are created by
                 # ctx.save_for_backward in the forward pass. Thus when we
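The change above appears to be the heart of the fix: instead of making the all_gather stream wait on all compute already queued on the current stream (which serializes the two), the gathered buffer is protected with record_stream() so the caching allocator cannot recycle its memory while the compute stream still uses it. A minimal, self-contained sketch of this pattern (illustrative names only, not FSDP's code):

    # Sketch of the record_stream() pattern, assuming a CUDA device is available.
    import torch

    if torch.cuda.is_available():
        compute_stream = torch.cuda.current_stream()
        side_stream = torch.cuda.Stream()  # stands in for the "all_gather" stream

        with torch.cuda.stream(side_stream):
            # Allocated and written on the side stream; this work can overlap
            # with compute already queued on the default stream.
            full_param = torch.randn(1 << 20, device="cuda")
            # Tell the caching allocator not to reuse this memory until the
            # compute stream (a different stream) is done reading it.
            full_param.record_stream(compute_stream)

        # The consumer still orders itself after the producer before reading.
        compute_stream.wait_stream(side_stream)
        out = full_param.sum()
        torch.cuda.synchronize()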
@@ -1692,16 +1694,23 @@ def _get_default_cuda_device(module: nn.Module) -> torch.device:
 def cast_floats_to_right_precision(to_fp16: bool, no_grad: bool, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
     """
     Cast floating point Tensors in *args or **kwargs to FP16 or FP32 if they are not.
+    We also retain the requires_grad flag so that casting doesn't affect the autograd graph.
     """

     def fn_fp16(x: torch.Tensor) -> torch.Tensor:
         if x.dtype is torch.float32:
-            return x.half()
+            y = x.half()
+            if x.is_leaf:
+                y.requires_grad = x.requires_grad
+            return y
         return x

     def fn_fp32(x: torch.Tensor) -> torch.Tensor:
         if x.dtype is torch.float16:
-            return x.float()
+            y = x.float()
+            if x.is_leaf:
+                y.requires_grad = x.requires_grad
+            return y
         return x

     fn = fn_fp16 if to_fp16 else fn_fp32
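The is_leaf guard is needed because requires_grad can only be assigned on leaf tensors; a cast performed with autograd enabled already returns a non-leaf that inherits the flag. A small illustration (an assumption based on the function's no_grad argument, not fairscale code) of the failure mode the added lines guard against:

    # Under torch.no_grad(), a dtype cast returns a detached leaf with
    # requires_grad=False, silently dropping the input from the autograd graph.
    import torch

    x = torch.rand(2, requires_grad=True)

    with torch.no_grad():
        plain = x.half()                # requires_grad is lost here
        fixed = x.half()
        if x.is_leaf:                   # the flag can only be set on leaf tensors
            fixed.requires_grad = x.requires_grad

    assert plain.requires_grad is False
    assert fixed.requires_grad is True  # later ops on `fixed` are tracked again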
......
@@ -34,6 +34,7 @@ import logging
 import multiprocessing
 import os
 import random
+from statistics import mean
 import subprocess
 import sys
 import tempfile
@@ -577,22 +578,12 @@ class DeviceAndTypeCheckModule(Base):
 @functools.lru_cache()
 def get_cycles_per_ms() -> float:
-    """Approximate number of cycles per millisecond for torch.cuda._sleep
+    """Measure and return approximate number of cycles per millisecond for torch.cuda._sleep

     Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py
-
-    ..note::
-        This doesn't seems to return consistent cycles on desktop GPUs likely
-        due to frequency scaling.
-
-        >>> get_cycles_per_ms()
-        227.6441091140009
-        # new python process
-        >>> get_cycles_per_ms()
-        564.652154766248
-        # new python process
-        >>> get_cycles_per_ms()
-        245.56459442962856
     """
+
+    def measure() -> float:
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
         start.record()
@@ -602,6 +593,21 @@ def get_cycles_per_ms() -> float:
         cycles_per_ms = 1000000 / start.elapsed_time(end)
         return cycles_per_ms

+    # Get 10 values, remove the 2 max and the 2 min, and return the average.
+    # This is to avoid system disturbances that skew the results, e.g. the
+    # very first cuda call likely does a bunch of init, which takes much
+    # longer than subsequent calls.
+    #
+    # Tested on Tesla V100, Quadro GP100, Titan RTX and RTX 3090 GPUs; it
+    # seems to return stable values. Therefore, we enable caching using the
+    # lru_cache decorator above.
+    num = 10
+    vals = []
+    for _ in range(num):
+        vals.append(measure())
+    vals = sorted(vals)
+    return mean(vals[2 : num - 2])


 class DummyProcessGroup:
     def __init__(self, rank: int, size: int):
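get_cycles_per_ms() exists so tests can turn a wall-clock delay into a cycle count for torch.cuda._sleep, which busy-spins the GPU without doing real work. A short usage sketch (torch.cuda._sleep is a private PyTorch helper; the delay value is illustrative):

    # Inject roughly 100ms of fake GPU work on the current stream.
    import torch
    from fairscale.utils.testing import get_cycles_per_ms

    if torch.cuda.is_available():
        delay_ms = 100
        torch.cuda._sleep(int(delay_ms * get_cycles_per_ms()))
        torch.cuda.synchronize()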
@@ -681,3 +687,15 @@ def dump_all_tensors(rank: int) -> None:
         except Exception as e:
             pass
     print(torch.cuda.memory_summary())
+
+
+def get_smi_memory() -> float:
+    """Return this process's GPU memory usage in MB."""
+    pid = os.getpid()
+    info_string = torch.cuda.list_gpu_processes()
+    for line in info_string.splitlines():
+        if str(pid) in line:
+            toks = line.split()
+            return float(toks[3])
+    # If the process is not in the list, we are not using the GPU.
+    return 0.0
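get_smi_memory() parses the human-readable output of torch.cuda.list_gpu_processes() (backed by pynvml, hence the new test requirement below) and assumes the memory figure is the fourth whitespace-separated token on the line mentioning this PID. A brief usage sketch (numbers depend on the allocator and driver):

    # Compare this process's NVML-reported GPU memory before and after allocating.
    import torch
    from fairscale.utils.testing import get_smi_memory

    if torch.cuda.is_available():
        before = get_smi_memory()
        buf = torch.empty(64 * 1024 * 1024, device="cuda")  # ~256 MB of float32
        after = get_smi_memory()
        print(f"GPU memory (MB): {before} -> {after}")
        del buf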
@@ -13,3 +13,6 @@ pytest-cov == 2.10.0
 pytest-timeout == 1.4.2
 remote-pdb >= 2.1.0
 parameterized >= 0.8.1
+
+# For torch.cuda.list_gpu_processes()
+pynvml == 8.0.4
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+import torch
 from typing import Optional, Tuple, Union, Dict, Any
 import ctypes
 from . import amp
@@ -48,6 +49,7 @@ def reset_max_memory_cached(device: Optional[_device_t]=...) -> None: ...
 def memory_summary() -> str: ...
 def cudart() -> ctypes.CDLL: ...
 def find_cuda_windows_lib() -> Optional[ctypes.CDLL]: ...
+def list_gpu_processes(device: Union[torch.device, str, None, int] = None) -> str: ...

 #MODIFIED BY TORCHGPIPE
 from .. import ByteTensor
 def set_rng_state(new_state: ByteTensor, device: _device_t = ...) -> None: ...
......
 tests/nn/misc/test_checkpoint_activations.py
 tests/nn/misc/test_checkpoint_activations_norm.py
+tests/nn/data_parallel/test_fsdp_overlap.py
 tests/nn/data_parallel/test_fsdp_multiple_forward.py
 tests/nn/data_parallel/test_fsdp_apply.py
 tests/nn/data_parallel/test_fsdp_state_dict.py
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test FSDP and ensure expected overlapping between all_gather and forward. """
from statistics import mean
import time
from unittest.mock import patch
import pytest
import torch
from torch.cuda import Event
import torch.multiprocessing as mp
import torch.nn as nn
from fairscale.nn import enable_wrap, wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.testing import (
dist_init,
get_cycles_per_ms,
skip_if_single_gpu,
teardown,
temp_files_ctx,
torch_version,
)
class Layer(nn.Module):
def __init__(self, compute_cycles, has_params: bool):
super().__init__()
self.sleep_cycles = compute_cycles
self.optional_param = None
if has_params:
self.optional_param = nn.Parameter(torch.rand(1))
def forward(self, x):
# Get 2 events.
self.e1 = Event(enable_timing=True)
self.e2 = Event(enable_timing=True)
# Record the fake forward compute time.
self.e1.record()
if self.sleep_cycles > 0:
torch.cuda._sleep(self.sleep_cycles)
if self.optional_param is not None:
x = x + self.optional_param # force the param to be part of the graph
self.e2.record()
return x
def get_time(self):
# return the recorded duration.
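        # Note: elapsed_time() is only valid once both events have completed;
        # the caller synchronizes on the output (out.item()) before reading this.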
return self.e1.elapsed_time(self.e2)
def _create_model(fsdp_config, compute_cycles, has_params: bool):
with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
model = wrap(
nn.Sequential(
wrap(Layer(compute_cycles, has_params)),
wrap(Layer(compute_cycles, has_params)),
wrap(Layer(compute_cycles, has_params)),
wrap(Layer(compute_cycles, has_params)),
)
).cuda()
return model
class Min10:
def __init__(self):
self.data = []
def add(self, new_data):
if len(self.data) < 10:
self.data.append(new_data)
else:
self.data = sorted(self.data)
if new_data < self.data[-1]:
self.data[-1] = new_data
def avg(self):
return mean(self.data)
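# Note: Min10 keeps only the 10 smallest of the samples it is given, which
# filters out slow outliers (first-iteration warm-up, scheduler jitter)
# before averaging.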
def _distributed_worker(
gpu_id, world_size, fsdp_config, tempfile, tempfile_rpc,
):
torch.cuda.set_device(gpu_id)
rank = gpu_id
result = dist_init(rank, world_size, tempfile, tempfile_rpc)
assert result, "Dist init failed"
# Save the original torch.distributed.all_gather function since we will
# patch it to include an artificial delay.
orig_all_gather = torch.distributed.all_gather
def run(compute_cycles, all_gather_cycles):
has_params = all_gather_cycles > 0
model = _create_model(fsdp_config, compute_cycles, has_params)
        # Get the input and set its requires_grad to True because
        # we have a fake compute in the forward pass.
batch = torch.rand(1).cuda()
batch.requires_grad = True
        # We run 20 iterations but only keep timing data from the 10 smallest
        # data points, because nondeterministic system events can disturb the timing.
cpu_iter = Min10()
cpu_wait = Min10()
gpu_compute = Min10()
gpu_total = Min10()
for _ in range(20):
# Get two events for measuring the overall time.
e1 = Event(enable_timing=True)
e2 = Event(enable_timing=True)
cpu_start = time.process_time()
all_gather_called = False
def _delayed_all_gather(*args, **kwargs):
nonlocal all_gather_called
all_gather_called = True
torch.cuda._sleep(all_gather_cycles)
return orig_all_gather(*args, **kwargs)
# forward pass
#
# Even though both e1 & e2 are on the compute stream, since
# compute depends on all_gather, e2-e1 includes all_gather time.
e1.record()
with patch("torch.distributed.all_gather", _delayed_all_gather):
out = model(batch)
if has_params and world_size > 1:
assert all_gather_called
else:
assert not all_gather_called
e2.record()
# backward pass
out.backward()
if torch_version() >= (1, 7, 0):
model.zero_grad(set_to_none=True)
else:
for p in model.parameters():
p.grad = None
cpu_iter_time = time.process_time() - cpu_start
# wait for gpu
out.item()
cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time
# get sum of the compute time
times = []
for mod in model.modules():
if not isinstance(mod, Layer):
continue
times.append(mod.get_time())
# get gpu compute + all_gather time
overall_gpu_time = e1.elapsed_time(e2)
cpu_iter.add(cpu_iter_time)
cpu_wait.add(cpu_wait_for_gpu_time)
gpu_compute.add(sum(times))
gpu_total.add(overall_gpu_time)
del model
return {
"cpu_iter": cpu_iter.avg(),
"cpu_wait": cpu_wait.avg(),
"gpu_compute": gpu_compute.avg(),
"gpu_total": gpu_total.avg(),
}
sleep_cycles = int(100 * get_cycles_per_ms())
e1 = run(0, 0) # no compute, no all-gather
e2 = run(0, sleep_cycles) # no compute, only all-gather
e3 = run(sleep_cycles, 0) # only compute, no all-gather
e4 = run(sleep_cycles, sleep_cycles) # both compute and all-gather
debug_string = f"\nrank{rank}:\n e1: {e1}\n e2: {e2}\n e3: {e3}\n e4: {e4}"
print(debug_string)
# Check the cpu/gpu timing. CPU should run ahead of GPU. Therefore, cpu-gpu
# wait should be long, except when there is no real work on GPU.
#
# If the assertions fail below, we likely have a cpu-gpu wait in the forward/backward pass.
short = [e1["cpu_iter"], e2["cpu_iter"], e3["cpu_iter"], e4["cpu_iter"], e1["cpu_wait"]]
long = [e3["cpu_wait"], e4["cpu_wait"]]
if world_size == 1:
short.append(e2["cpu_wait"]) # all gather should not be happening.
else:
long.append(e2["cpu_wait"]) # all gather should happen and prolong the cpu-gpu wait.
for s in short:
for l in long:
            # 10X longer is a safe margin, since the GPU work takes around 100X
            # longer than the CPU-side work.
assert s * 10 < l, f"{s} * 10 < {l} in " + debug_string
# Check the GPU timing.
short = [e1["gpu_compute"], e1["gpu_total"], e2["gpu_compute"]]
long = [e3["gpu_compute"], e3["gpu_total"], e4["gpu_compute"], e4["gpu_total"]]
if world_size == 1:
short.append(e2["gpu_total"]) # all gather should not be happening.
else:
long.append(e2["gpu_total"]) # all gather should happen and prolong the cpu-gpu wait.
for s in short:
for l in long:
# 10X longer is a safe margin, since the time is around 100X longer
# when there is work on GPU vs. no work.
assert s * 10 < l, f"{s} * 10 < {l} in " + debug_string
# Check the GPU overlapping when there is all-gather.
if world_size > 1:
compute_only = e3["gpu_compute"]
all_gather_only = e2["gpu_total"]
both = e4["gpu_total"]
assert compute_only + all_gather_only > 1.1 * both, (
f"{compute_only} + {all_gather_only} > 1.1 * {both} in " + debug_string
)
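        # Illustrative numbers: with four layers each sleeping ~100ms, the
        # compute-only and all-gather-only runs each take roughly 400ms on the
        # GPU. Fully serial execution would push `both` toward ~800ms and fail
        # 800 > 1.1 * 800, while good overlap keeps `both` well below ~727ms,
        # so the check passes only if compute and all-gather truly overlapped.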
teardown()
@skip_if_single_gpu
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
@pytest.mark.parametrize("mixed", ["mixed", "full"])
def test_forward_overlap(world_size, flatten, mixed):
fsdp_config = {
"flatten_parameters": flatten == "flatten",
"mixed_precision": mixed == "mixed",
}
with temp_files_ctx(2) as temp_files:
mp.spawn(
_distributed_worker, (world_size, fsdp_config, temp_files[0], temp_files[1]), nprocs=world_size,
)