Unverified Commit 8347c1a2 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[feat] experimental MEVO layer (#840)



* [feat] MEVO kernel

- initial import from min/softmax and min/testing branches
- need to rename and further cleanup

* only test with newer pytorch

* renamed and added comments and code cleanup

* rename and reduce test memory

* testing

* minor fixing

* fixing

* more fix

* changelog

* more 1.7 and 1.8 paper cuts

* remove dead code

* addressed Benjamin's comments

* addressed more comments
Co-authored-by: default avatarMin Xu <min.xu.public@gmail.com>
parent f327eb4a
...@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ...@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- [FSDP]: limited support of shared weights between FSDP wrappers. This allows large parameter - [FSDP]: limited support of shared weights between FSDP wrappers. This allows large parameter
and gradient memory to be sharded despite being needed from different layers due to and gradient memory to be sharded despite being needed from different layers due to
weight sharing. [#836] weight sharing. [#836]
- [MEVO]: a custom layer to help big vocab trainings. Experimental. Docs is still TBD. [#840]
## [0.4.1] - 2021-09-17 ## [0.4.1] - 2021-09-17
### Fixed ### Fixed
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import contextlib
from pprint import pprint
from statistics import mean
import time
import torch
from torch import nn
from torch.cuda import Event
from fairscale.experimental.nn import MEVO, BaselineSoftmaxNllLoss
from fairscale.experimental.nn.mevo import get_data
"""Benchmarking the MEVO kernel and its Baseline."""
SHAPES = [
# name, activation, FC weights
("1k_128h_256k", (1024, 128), (128, 256 * 1024)),
# ("4k_128h_256k", (4096, 128), (128, 256 * 1024)),
# ("8k_4k_32k", (4 * 2048, 4 * 1024), (4 * 1024, 32 * 1024)),
# ("24k_4k_50k", (12 * 2048, 4 * 1024), (4 * 1024, 50 * 1024)),
# ("8k_4k_256k", (4 * 2048, 4 * 1024), (4 * 1024, 256 * 1024)),
# ("8k_4k_256008", (4 * 2048, 4 * 1024), (4 * 1024, 256008)), # max seq len for base is 2100, 2300 for top-k
# ("xk_4k_256008", (1 * 2048, 4 * 1024), (4 * 1024, 256008)),
]
KERNELS = [
BaselineSoftmaxNllLoss,
MEVO,
]
def run_on_gpu(kernel, data, repeats, no_grad, fwd_bwd):
""" Measure both GPU runtime and peak memory usage of a kernel. """
tokens = data[0].shape[0]
def get_cuda_data():
"""Move the data from CPU to GPU. We make a new weight parameter with this call."""
with torch.no_grad():
i, w, t = data # i, t are tensors, w is a param
w = nn.Linear(w.shape[1], w.shape[0], bias=False, dtype=w.dtype, device="cuda").weight
assert w.requires_grad
return i.cuda().requires_grad_(True), w, t.cuda()
def _test(kernel_obj, event):
"""Forward and backward passes."""
context = contextlib.suppress()
if no_grad:
context = torch.no_grad()
with context:
if event is not None:
event.record()
out = kernel_obj(input, target)
if fwd_bwd:
assert not no_grad
out.backward()
del out
if fwd_bwd:
assert input.grad is not None, input
assert weight.grad is not None, weight
assert target.grad is None, target
input.grad = None
weight.grad = None
def _get_kernel():
"""Get a kernel instance."""
return kernel(weight, tile_factor=16)
#
# Run the test once to measure memory.
#
# Ensure GPU memory is clean, empty, 0.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
cur_mem_before = round(torch.cuda.memory_allocated() / 1024 / 1024)
assert cur_mem_before == 0, cur_mem_before
# Move tensors to GPU.
input, weight, target = get_cuda_data()
# Create the kernel
k = _get_kernel()
_test(k, None)
# Might wait for gpu here
torch.cuda.synchronize()
# Free memory, ensure everything is clean, no leak.
del k
del input
del weight
del target
cur_mem_after = round(torch.cuda.memory_allocated() / 1024 / 1024)
assert cur_mem_after == 0, cur_mem_after
# Get peak mem
peak_mem_after = round(torch.cuda.max_memory_allocated() / 1024 / 1024)
peak_mem = peak_mem_after - cur_mem_before
#
# Run multiple times to get both CPU timing and average GPU timing.
#
# Move tensors to GPU and get k, again.
input, weight, target = get_cuda_data()
k = _get_kernel()
# Get the events
events = [Event(enable_timing=True) for _ in range(repeats + 1)]
# Queue the ops to GPU
cpu_start_time = time.time()
for i in range(repeats):
_test(k, events[i])
events[i + 1].record() # end time of the last run
# CPU could be done much sooner than the GPU here.
cpu_time = time.time() - cpu_start_time
# Might wait for gpu here
torch.cuda.synchronize()
# Get the durations
durations = [cpu_time * 1000] # convert CPU time, from seconds to ms.
for x, y in zip(events, events[1:]):
durations.append(x.elapsed_time(y))
assert len(durations) == repeats + 1
# Free memory
del k
input, weight, target = None, None, None
cur_mem_after = round(torch.cuda.memory_allocated() / 1024 / 1024)
assert cur_mem_after == 0, cur_mem_after
# Skip 2 for cpu time and first warm up time to compute the average.
time_per_call = mean(durations[2:]) # ms
time_per_token = time_per_call * 1000 / tokens # us
return peak_mem, durations[:2] + [time_per_call, time_per_token]
def main():
parser = argparse.ArgumentParser("Benchmarking MEVO")
parser.add_argument("--dtype", type=str, choices=["fp16", "fp32"], default="fp16")
parser.add_argument("--grad", type=str, choices=["grad", "no_grad"], default="grad")
parser.add_argument("--fwd_bwd", action="store_true", default=False)
args = parser.parse_args()
repeats = 9
results = {}
results["peak cached"] = {}
results["durations"] = {}
for shape in SHAPES:
name = shape[0]
results["peak cached"][name] = {}
results["durations"][name] = {}
dtype = torch.float32 if args.dtype == "fp32" else torch.float16
# Use cpu memory to ensure we always start with an empty GPU
data = get_data(shape[1:], dtype, "cpu")
for kernel in KERNELS:
k_name = kernel.__name__
no_grad = args.grad
print(f"Running {k_name} with {name} {dtype} {no_grad} data")
peak_mem, durations = run_on_gpu(kernel, data, repeats, no_grad == "no_grad", args.fwd_bwd)
results["peak cached"][name][k_name] = peak_mem
results["durations"][name][k_name] = durations
pprint(results)
if __name__ == "__main__":
main()
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
from typing import List from typing import List
from .mevo import BaselineSoftmaxNllLoss
from .mevo import MemoryEfficientVocabOutput as MEVO
from .offload import OffloadModel from .offload import OffloadModel
from .sync_batchnorm import SyncBatchNorm from .sync_batchnorm import SyncBatchNorm
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Tuple
import torch
from torch import nn
import torch.distributed as dist
import torch.nn.functional as F
# Debugging flag to enable some prints. Useful to debug with FSDP.
DEBUG = False
def _next_power_of_2_or_max(n: int, max_n: int) -> int:
""" Return the smallest power of 2 greater than or equal to n, with a limit.
Useful when used in splitting a tensor into chunks with power-of-2 sizes.
"""
# special case, just split to 1 element chunks.
if n == 0:
return 1
orig_n = n
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n += 1
assert n >= orig_n, f"{n} vs. {orig_n}"
assert bin(n).count("1") == 1, bin(n) # Catch the case n is too large for this function.
if n > max_n:
return max_n
return n
def _reshape_inputs(input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Convert 3D inputs to 2D for this kernel"""
if len(input.shape) == 3:
input = input.reshape(-1, input.shape[2])
if len(target.shape) == 2:
target = target.reshape(-1)
return input, target
def get_data(
shape: Tuple[Tuple[int, int], Tuple[int, int]], dtype: torch.dtype = torch.float16, device: str = "cuda"
) -> Tuple[torch.Tensor, nn.Parameter, torch.Tensor]:
""" Utility function for getting some tensors for testing and benchmarking."""
(tokens, d1), (d2, vocabs) = shape
assert d1 == d2
input = torch.rand(tokens, d1, device=device, dtype=dtype).requires_grad_(True)
# Before pytorch 1.9, nn.Linear does not support device and dtype init option. So we use to()
# and an if condition.
layer = nn.Linear(d2, vocabs, bias=False).to(device)
assert dtype in [torch.float16, torch.float32]
if dtype == torch.float16:
layer = layer.half()
weight = layer.weight
target = (torch.rand(tokens, device=device) * vocabs).long()
return input, weight, target
class BaselineSoftmax(nn.Module):
""" Baseline softmax that does an output linear projection and a softmax.
This is intended to be used with an embedding layer with shared weights.
Args:
proj_weight (nn.Parameter):
The shared weight.
tile_factor (int):
Unused. It is here to make kernel init easier with MEVO.
log_softmax (bool):
If True, use log_softmax instead of softmax.
"""
def __init__(self, proj_weight: nn.Parameter, tile_factor: int = 0, log_softmax: bool = True):
super().__init__()
out_dim, in_dim = proj_weight.shape
assert "cuda" in str(proj_weight.device), "weight should be on GPU"
self.fc = nn.Linear(in_dim, out_dim, bias=False).to("cuda")
assert proj_weight.dtype in [torch.float16, torch.float32]
if proj_weight.dtype == torch.float16:
self.fc = self.fc.half()
self.fc.weight = proj_weight
assert self.fc.weight.dtype in [torch.float16, torch.float32], self.fc.weight.dtype
self.fp16 = self.fc.weight.dtype == torch.float16
self.log_softmax = log_softmax
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # type: ignore
""" Forward function that computes softmax output with the input and target."""
assert isinstance(input, torch.Tensor)
assert isinstance(target, torch.Tensor)
input, target = _reshape_inputs(input, target)
if self.fp16:
assert input.dtype == torch.float16
x = self.fc(input)
# Note that we do softmax in FP32, which is important for numerical stability.
if self.log_softmax:
x = F.log_softmax(x, dim=-1, dtype=torch.float32)
else:
x = F.softmax(x, dim=-1, dtype=torch.float32)
assert x.dtype == torch.float32
return x
class BaselineSoftmaxNllLoss(BaselineSoftmax):
""" Baseline that does an output projection, a softmax & a NLL loss (cross-entropy).
See BaselineSoftmax above. Constructor is the same. Only difference is in the
forward function.
This class is used for testing and benchmarking.
"""
def __init__(self, proj_weight: nn.Parameter, tile_factor: int = 0, log_softmax: bool = True):
super().__init__(proj_weight, tile_factor, log_softmax)
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # type: ignore
"""Forward that directly compute the loss."""
assert isinstance(input, torch.Tensor)
assert isinstance(target, torch.Tensor)
input, target = _reshape_inputs(input, target)
x = super().forward(input, target)
return F.nll_loss(x, target, reduction="sum")
class GetMaxFunction(torch.autograd.Function):
"""Custom checkpointed function to get max-per-token from an input and a weight"""
@staticmethod
def get_max(i: torch.Tensor, w: torch.Tensor, full_precision: bool) -> torch.Tensor:
"""
Throughout this code:
i: input data with shape = (split-of-tokens, d_model)
w: weight data with shape = (split-of-vocabs, d_model)
"""
_m = torch.matmul(i, w.T)
if full_precision:
_m = _m.float()
_m = _m.max(dim=1)[0]
return _m
@staticmethod
def forward( # type: ignore
ctx: Any,
i: torch.Tensor,
w: torch.Tensor,
kernel_obj: "MemoryEfficientVocabOutput",
w_idx: int,
w_split_size: int,
split_dim: int,
) -> torch.Tensor:
"""Forward function that computes the max, without saving activations."""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG max fwd")
ctx.save_for_backward(i, w)
ctx.kernel_obj = kernel_obj
ctx.w_idx = w_idx
ctx.w_split_size = w_split_size
ctx.args = {}
assert split_dim == 0
# During forward, we use ``no_grad'' to avoid saving the activations.
# The activations will be recomputed in backward below and freed
# immediately after use. This saves the overall GPU peak memory of this layer.
with torch.no_grad():
return GetMaxFunction.get_max(i, w, kernel_obj.fp_max)
@staticmethod
def backward(ctx: Any, *args: Any) -> Any:
"""Recompute the forward max and backward grad.
Accumulate the grad to the right split of the full grad.
"""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG max bwd")
assert len(args) == 1
# Gradients should already exist due to TargetScoreFunction's backward.
assert ctx.kernel_obj.proj_weight.grad is not None
# Get saved i and w.
i, w = ctx.saved_tensors
assert i.requires_grad
assert w.requires_grad
# We use ``detach()'' to ensure the backward call below does not
# trigger backward computation that produced i and w here. Otherwise,
# the backward call below would trigger backward all the way to
# the batch input.
i = i.detach().requires_grad_(True)
w = w.detach().requires_grad_(True)
# Forward + backward again.
with torch.enable_grad():
# This saves the activations.
maxs = GetMaxFunction.get_max(i, w, ctx.kernel_obj.fp_max)
# This will use the activations and free them immediately.
torch.autograd.backward(maxs, *args)
# Accumulate the computed gradients into the bigger weight tensor's gradient tensor.
assert w.grad is not None
with torch.no_grad():
grads = torch.split(ctx.kernel_obj.proj_weight.grad, ctx.w_split_size)
grads[ctx.w_idx].add_(w.grad)
return i.grad, None, None, None, None, None
class GetSumFunction(torch.autograd.Function):
"""Custom checkpointed function to get sum-per-token from an input and a weight."""
@staticmethod
def get_sum(i: torch.Tensor, w: torch.Tensor, maxs: torch.Tensor, full_precision: bool) -> torch.Tensor:
_s = torch.matmul(i, w.T)
if full_precision:
_s = _s.float()
_s = (_s - maxs.reshape(-1, 1)).exp().sum(dim=1)
return _s
@staticmethod
def forward( # type: ignore
ctx: Any,
i: torch.Tensor,
w: torch.Tensor,
maxs: torch.Tensor,
kernel_obj: "MemoryEfficientVocabOutput",
w_idx: int,
w_split_size: int,
split_dim: int,
) -> torch.Tensor:
"""Forward function that computes the sum, without saving activations."""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG sum fwd")
ctx.save_for_backward(i, w, maxs)
ctx.kernel_obj = kernel_obj
ctx.w_idx = w_idx
ctx.w_split_size = w_split_size
assert split_dim == 0
with torch.no_grad():
return GetSumFunction.get_sum(i, w, maxs, kernel_obj.fp_sum)
@staticmethod
def backward(ctx: Any, *args: Any) -> Any:
"""Recompute the forward sum and backward grad.
Accumulate the grad to the right split of the full grad.
"""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG sum bwd")
assert len(args) == 1
# Gradients should already exist due to TargetScoreFunction's backward.
assert ctx.kernel_obj.proj_weight.grad is not None
# Get saved i, w, and maxs.
i, w, maxs = ctx.saved_tensors
assert i.requires_grad
assert w.requires_grad
assert maxs.requires_grad
i = i.detach().requires_grad_(True)
w = w.detach().requires_grad_(True)
maxs = maxs.detach().requires_grad_(True)
# Forward + backward again.
with torch.enable_grad():
sums = GetSumFunction.get_sum(i, w, maxs, ctx.kernel_obj.fp_sum)
torch.autograd.backward(sums, *args)
# Accumulate the grads.
assert w.grad is not None
with torch.no_grad():
grads = torch.split(ctx.kernel_obj.proj_weight.grad, ctx.w_split_size)
grads[ctx.w_idx].add_(w.grad)
return i.grad, None, maxs.grad, None, None, None, None
class TargetScoreFunction(torch.autograd.Function):
"""Custom checkpointed function to compute the target score."""
@staticmethod
def get_target_score(i: torch.Tensor, w: torch.Tensor, target: torch.Tensor, full_precision: bool) -> torch.Tensor:
tokens, d_model = i.shape
assert d_model == w.shape[1]
tw = w.gather(dim=0, index=target.reshape(target.shape[0], 1).expand(target.shape[0], d_model))
assert tw.shape == (tokens, d_model)
target_score = i * tw
if full_precision:
target_score = target_score.float()
target_score = target_score.sum(dim=1) # sum into target scores with shape (tokens,)
return target_score
@staticmethod
def forward( # type: ignore
ctx: Any, i: torch.Tensor, w: torch.Tensor, target: torch.Tensor, kernel_obj: "MemoryEfficientVocabOutput"
) -> torch.Tensor:
"""Forward, without activations."""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG target fwd")
ctx.save_for_backward(i, w, target)
ctx.kernel_obj = kernel_obj
with torch.no_grad():
x = TargetScoreFunction.get_target_score(i, w, target, kernel_obj.fp_target)
return x
@staticmethod
def backward(ctx: Any, *args: Any) -> Any:
"""Forward and backward again, assign or accumulate the gradients."""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG target bwd")
assert len(args) == 1
i, w, target = ctx.saved_tensors
assert i.requires_grad
assert w.requires_grad
assert not target.requires_grad
i = i.detach().requires_grad_(True)
w = w.detach().requires_grad_(True)
with torch.enable_grad():
scores = TargetScoreFunction.get_target_score(i, w, target, ctx.kernel_obj.fp_target)
torch.autograd.backward(scores, *args)
if ctx.kernel_obj.proj_weight.grad is not None:
# This means we accumulate full grad between iters. Not memory efficient.
ctx.kernel_obj.proj_weight.grad.add_(w.grad)
else:
ctx.kernel_obj.proj_weight.grad = w.grad
return i.grad, None, None, None
class BackwardTriggerFn(torch.autograd.Function):
"""A backward trigger function."""
@staticmethod
def forward( # type: ignore
ctx: Any, w: torch.Tensor, trigger_tensor: torch.Tensor
) -> torch.Tensor:
"""We take a weight tensor and the trigger as inputs and output the weight directly."""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG trigger fwd")
ctx.save_for_backward(w, trigger_tensor)
return w
@staticmethod
def backward(ctx: Any, *args: Any) -> Any:
"""We return zero grad for the trigger only."""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG trigger bwd")
assert len(args) == 1
w, trigger = ctx.saved_tensors
assert w.requires_grad
assert trigger.requires_grad
return None, torch.zeros_like(trigger)
class BackwardTrigger(nn.Module):
"""A backward trigger module.
This module takes a parameter as an input and create a linked parameter
from a newly created trigger parameter.
The way to use it in a module's ``__init__'' and ``forward'' functions:
```
def __init__():
...
self.trigger = BackwardTrigger(some_layer.weight)
...
def forward():
w = self.trigger()
... continue to use w ...
```
As a resule, the trigger's backward hook will be called at the end of
the backward for the module that uses this trigger.
"""
def __init__(self, linked_param: torch.Tensor):
super().__init__()
assert isinstance(linked_param, nn.Parameter)
self.trigger = nn.Parameter(torch.rand(1, dtype=linked_param.dtype))
self.trigger._linked_param = linked_param
def forward(self) -> torch.Tensor: # type: ignore
return BackwardTriggerFn.apply(self.trigger._linked_param, self.trigger)
class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
""" Fused fc + softmax + nll_loss in a tiled fashion.
This uses much less memory but is quite a bit slower.
Args:
proj_weight (nn.Parameter):
Sharing this weight with an embedding layer.
tile_factor (int):
Number of splits to use on the input sequence and vocab dimensions.
reduction (str):
Reduction OP (sum or mean).
"""
def __init__(self, proj_weight: nn.Parameter, tile_factor: int = 16, reduction: str = "sum"):
super().__init__()
self.proj_weight = proj_weight
# TODO (Min): these two factors doesn't have to be the same. More tuning can be done.
self.tf_in, self.tf_w = tile_factor, tile_factor
self.fp_max = True
self.fp_sum = True # This is esp. important when tensors are large. Otherwise, you get inf.
self.fp_target = True
self.log_softmax = True
self.reduction = reduction
assert self.reduction in ["sum", "mean"]
self.trigger = BackwardTrigger(self.proj_weight)
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print(
f"DEBUG cfg tf_in={self.tf_in} tf_w={self.tf_w} fp_max={self.fp_max} "
f"fp_sum={self.fp_sum} fp_target={self.fp_target} log_softmax={self.log_softmax} "
f"reduction={self.reduction}"
)
def get_target_nlprob(
self, i: torch.Tensor, w: torch.Tensor, target: torch.Tensor, debase_max: torch.Tensor, exp_sums: torch.Tensor
) -> torch.Tensor:
"""Get target's negative log probability."""
target_score = TargetScoreFunction.apply(i, w, target, self)
prob = (target_score - debase_max).exp() / exp_sums
if self.log_softmax:
# lprob
prob = prob.log()
# nlprob, then sum over all tokens.
return -prob.sum()
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # type: ignore
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
cur_mem = round(torch.cuda.memory_allocated() / 1024 / 1024)
mem = round(torch.cuda.max_memory_allocated() / 1024 / 1024)
print("DEBUG cur, peak", cur_mem, mem)
assert isinstance(input, torch.Tensor)
assert isinstance(target, torch.Tensor)
assert input.requires_grad
input, target = _reshape_inputs(input, target)
tokens, d_model = input.shape
vocab, d2 = self.proj_weight.shape
assert d_model == d2
split_dim = 0
input_split_size = _next_power_of_2_or_max(tokens // self.tf_in, tokens)
weight_split_size = _next_power_of_2_or_max(vocab // self.tf_w, vocab)
inputs = torch.split(input, input_split_size, split_dim)
weight = self.trigger()
weights = torch.split(weight, weight_split_size, split_dim)
# Get maxs
maxs = []
for i in inputs:
m = None # max with (tokens_tile,) shape
for w_idx, w in enumerate(weights):
_m = GetMaxFunction.apply(i, w, self, w_idx, weight_split_size, split_dim)
if m is None:
m = _m
else:
m = torch.max(m, _m)
assert m is not None
maxs.append(m) # (tokens_tile,)
maxs_tensor = torch.cat(maxs) # (tokens,)
assert maxs_tensor.shape == (tokens,)
# Get sums.
sums = []
for idx, i in enumerate(inputs):
s = None # sum with (tokens_tile,) shape
for w_idx, w in enumerate(weights):
_s = GetSumFunction.apply(i, w, maxs[idx], self, w_idx, weight_split_size, split_dim)
if s is None:
s = _s
else:
s += _s
assert s is not None
sums.append(s) # (tokens_tile,)
sums_tensor = torch.cat(sums) # (tokens,)
assert sums_tensor.shape == (tokens,)
# select weights for targets
result = self.get_target_nlprob(input, self.proj_weight, target, maxs_tensor, sums_tensor)
if self.reduction == "mean":
result /= tokens
return result
...@@ -1412,7 +1412,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1412,7 +1412,7 @@ class FullyShardedDataParallel(nn.Module):
if hasattr(param._linked_param, "_is_shared") and param._linked_param._is_shared: if hasattr(param._linked_param, "_is_shared") and param._linked_param._is_shared:
param = param._linked_param param = param._linked_param
assert param.grad is not None assert param.grad is not None, param.shape
if param.grad.requires_grad: if param.grad.requires_grad:
raise RuntimeError("FSDP only works with gradients that don't require gradients") raise RuntimeError("FSDP only works with gradients that don't require gradients")
......
...@@ -161,7 +161,7 @@ def gumbel_softmax(logits: Tensor, tau: float = ..., hard: bool = ..., eps: floa ...@@ -161,7 +161,7 @@ def gumbel_softmax(logits: Tensor, tau: float = ..., hard: bool = ..., eps: floa
def log_softmax(input: Tensor, dim: Optional[int] = ..., _stacklevel: int = ..., def log_softmax(input: Tensor, dim: Optional[int] = ..., _stacklevel: int = ...,
dtype: Optional[int] = ...) -> Tensor: ... dtype: Optional[_dtype] = ...) -> Tensor: ...
def tanh(input: Any): ... def tanh(input: Any): ...
......
...@@ -4,6 +4,9 @@ from .module import Module ...@@ -4,6 +4,9 @@ from .module import Module
from .. import Parameter from .. import Parameter
from ... import Tensor from ... import Tensor
import torch
from typing import Union
class Identity(Module): class Identity(Module):
...@@ -20,7 +23,7 @@ class Linear(Module): ...@@ -20,7 +23,7 @@ class Linear(Module):
weight: Parameter = ... weight: Parameter = ...
bias: Parameter = ... bias: Parameter = ...
def __init__(self, in_features: int, out_features: int, bias: bool = ...) -> None: ... def __init__(self, in_features: int, out_features: int, bias: bool = ..., device:str = ..., dtype:Union[str, torch.dtype] = ...) -> None: ...
def reset_parameters(self) -> None: ... def reset_parameters(self) -> None: ...
......
...@@ -21,9 +21,9 @@ T_co = TypeVar('T_co', covariant=True) ...@@ -21,9 +21,9 @@ T_co = TypeVar('T_co', covariant=True)
class Module(Generic[T_co]): class Module(Generic[T_co]):
def __init__(self) -> None: ... def __init__(self) -> None: ...
def forward(self, *input: Any, **kwargs: Any) -> T_co: ... # type: ignore def forward(self, *input: Any, **kwargs: Any) -> T_co: ...
def __call__(self, *input: Any, **kwargs: Any) -> T_co: ... # type: ignore def __call__(self, *input: Any, **kwargs: Any) -> T_co: ...
def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None: ... def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None: ...
......
tests/nn/data_parallel/test_fsdp_shared_weights_mevo.py
tests/nn/data_parallel/test_fsdp_shared_weights.py tests/nn/data_parallel/test_fsdp_shared_weights.py
tests/nn/data_parallel/test_fsdp_pre_backward_hook.py tests/nn/data_parallel/test_fsdp_pre_backward_hook.py
tests/nn/data_parallel/test_fsdp_overlap.py tests/nn/data_parallel/test_fsdp_overlap.py
...@@ -42,6 +43,7 @@ tests/nn/pipe/test_dependency.py ...@@ -42,6 +43,7 @@ tests/nn/pipe/test_dependency.py
tests/nn/pipe/test_stream.py tests/nn/pipe/test_stream.py
tests/nn/moe/test_moe_layer.py tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_top2gating.py tests/nn/moe/test_top2gating.py
tests/experimental/nn/test_mevo.py
tests/experimental/nn/test_multiprocess_pipe.py tests/experimental/nn/test_multiprocess_pipe.py
tests/experimental/nn/test_sync_batchnorm.py tests/experimental/nn/test_sync_batchnorm.py
tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
import os
import pytest
import torch
from fairscale.experimental.nn import MEVO
from fairscale.experimental.nn.mevo import BaselineSoftmaxNllLoss, get_data
from fairscale.utils.testing import skip_if_no_cuda
@pytest.fixture(scope="session", params=[torch.float16, torch.float32])
def input_data(request):
shape = ((2, 3), (3, 4))
return get_data(shape, dtype=request.param)
_dense_out = {} # type: ignore
_dense_grad = {} # type: ignore
@skip_if_no_cuda
def test_mevo():
"""Test the MEVO kernel by itself."""
torch.random.manual_seed(os.getpid())
shape = ((5, 3), (3, 7))
# Turn on large data for local testing.
large = False
if large:
shape = ((1 * 2048, 4096), (4096, 256008))
print("\nshapes are", shape)
input, weight, target = get_data(shape, dtype=torch.float16)
k = MEVO(weight, tile_factor=16)
o = k(input, target)
o.backward()
print(o, o.shape)
del o
cur_mem = round(torch.cuda.memory_allocated() / 1024 / 1024)
mem = round(torch.cuda.max_memory_allocated() / 1024 / 1024)
print("cur and peak mem for tiled fwd+bwd =", cur_mem, mem)
assert input.shape == input.grad.shape
input_data = input.data.cpu()
input_grad1 = input.grad.cpu()
del input
cur_mem = round(torch.cuda.memory_allocated() / 1024 / 1024)
mem = round(torch.cuda.max_memory_allocated() / 1024 / 1024)
print("after moving input and its grad, cur and peak mem for tiled fwd+bwd =", cur_mem, mem)
print(weight.grad.norm(), weight.grad)
g1 = weight.grad.clone()
weight.grad = None
input = input_data.cuda().requires_grad_(True)
refk = BaselineSoftmaxNllLoss(weight)
o = refk(input, target)
o.backward()
print(o, o.shape)
del o
print(weight.grad.norm(), weight.grad)
g2 = weight.grad.clone()
input_grad2 = input.grad.cpu()
# Print the diff. We use .cuda() since in 1.7 and 1.8, min() and max() are not
# implemented for cpu float16.
diff = g1 - g2
print("weight grad diff", diff.cuda().min(), diff.cuda().max())
diff = input_grad1 - input_grad2
print("input grad diff", diff.cuda().min(), diff.cuda().max())
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test FSDP with shared weights between wrappers using a model with mevo kernel. """
from copy import deepcopy
import pytest
import torch
from torch import nn
import torch.multiprocessing as mp
from torch.optim import SGD
from fairscale.experimental.nn import MEVO
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.testing import dist_init, objects_are_equal, skip_if_single_gpu, teardown, temp_files_ctx
VOCAB = 4
D_MODEL = 2
BS = 2
SEQ = 3
TILE = 2
_large = True
if _large:
VOCAB = 1024 * 50
D_MODEL = 1024
BS = 2
SEQ = 16
TILE = 16
class Model(nn.Module):
def __init__(self, with_fsdp=False, wrap_middle="none"):
super().__init__()
self.l0 = nn.Embedding(VOCAB, D_MODEL).cuda().half()
nn.init.uniform_(self.l0.weight, -1.0e-1, 1.0e-1)
self.l1 = MEVO(self.l0.weight, tile_factor=TILE, reduction="sum")
self.middle = nn.Linear(D_MODEL, D_MODEL).cuda().half()
# LNs are not strictly needed for this test, but they help reduce the loss quickly
# and improves the numerical stability.
self.ln1 = nn.LayerNorm(D_MODEL).cuda().half()
self.ln2 = nn.LayerNorm(D_MODEL).cuda().half()
if with_fsdp:
# Shared layers much be un-flatten.
self.l0 = FSDP(self.l0, flatten_parameters=False, mixed_precision=False, compute_dtype=torch.float16)
self.l1 = FSDP(self.l1, flatten_parameters=False, mixed_precision=False, compute_dtype=torch.float16)
self.l1.append_shared_param(self.l0.module.weight)
# These are for debugging.
# print(id(self.l0), "is emb")
# print(id(self.l1), "is out")
assert wrap_middle in ["none", "flat", "nonflat"]
if wrap_middle != "none":
self.middle = FSDP(
self.middle,
flatten_parameters=wrap_middle == "flat",
mixed_precision=False,
compute_dtype=torch.float16,
)
# print(id(self.middle), "is middle")
def forward(self, x):
target = x + 1
x = self.l0(x)
x = self.ln1(x)
x = self.middle(x)
x = self.ln2(x)
x = self.l1(x, target)
print("LOSS", x.item())
assert x.item() not in [float("-inf"), float("inf")]
return x
# A fixture to get tempfiles and ensure they are cleaned up.
@pytest.fixture()
def temp_files():
# dist_init needs 2 files + 3 files for before state, after state, in_data.
with temp_files_ctx(5) as files:
yield files
@skip_if_single_gpu
@pytest.mark.parametrize("wrap_middle", ["none", "flat", "nonflat"])
def test_shared_weight_mevo(temp_files, wrap_middle):
"""Test FSDP with a model with shared weights."""
world_size = 2
# Get ref.
model = Model()
sd_before = deepcopy(model.state_dict())
in_data = (torch.rand(BS, SEQ) * (VOCAB - 1)).cuda().long()
_train(model, in_data, world_size)
sd_after = deepcopy(model.state_dict())
# Before and after state should not be equal.
assert not objects_are_equal(sd_before, sd_after)
# Save data
torch.save(sd_before, temp_files[2])
torch.save(sd_after, temp_files[3])
torch.save(in_data, temp_files[4])
# Run FSDP
mp.spawn(
_dist_worker, (world_size, temp_files, wrap_middle), nprocs=world_size,
)
def _dist_worker(rank, world_size, files, wrap_middle):
# Get data from files.
file1, file2, sd_before, sd_after, in_data = files
sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))
result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
assert result, "Dist init failed"
fsdp_model = FSDP(
# To debug: first make with_fsdp=False (no inner wrapping) work, then enable inner wrapping
# and make that work.
Model(with_fsdp=True, wrap_middle=wrap_middle),
flatten_parameters=False,
mixed_precision=False,
compute_dtype=torch.float16,
)
fsdp_model.load_state_dict(sd_before)
_train(fsdp_model, in_data)
objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)
teardown()
def _train(model, in_data, steps_per_iter=1):
optim = SGD(model.parameters(), lr=0.1)
for _ in range(3):
# Simulate multiple ranks.
for _ in range(steps_per_iter):
out = model(in_data)
out.backward()
# Simulate gradient means between ranks.
if steps_per_iter > 1:
with torch.no_grad():
for p in model.parameters():
p.grad /= steps_per_iter
with torch.no_grad():
for p in model.parameters():
assert not torch.isinf(p.grad).any() and not torch.isnan(p.grad).any()
optim.step()
model.zero_grad(set_to_none=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment