Unverified commit 33a3a667 authored by Chaitanya Sri Krishna Lolla, committed by GitHub

Merge pull request #26 from lcskrishna/cl/ifu_07072020

IFU-07072020
parents 7e099371 eba809d7
......@@ -41,3 +41,7 @@ def scalar_python_val(x):
return x.data[0]
else:
return x[0]
# Accounts for the possibility that some ops may be removed from a namespace.
def filter_attrs(module, attrs):
return list(attrname for attrname in attrs if hasattr(module, attrname))
......@@ -11,24 +11,24 @@ MODULE = torch.Tensor
# MODULE = torch.autograd.Variable
FP16_FUNCS = [
FP16_FUNCS = compat.filter_attrs(MODULE, [
'__matmul__',
]
])
BFLOAT16_FUNCS = [
'__matmul__',
]
FP32_FUNCS = [
FP32_FUNCS = compat.filter_attrs(MODULE, [
'__ipow__',
'__pow__',
'__rpow__',
# Cast to fp32 before transfer to CPU
'cpu',
]
])
CASTS = [
CASTS = compat.filter_attrs(MODULE, [
'__add__',
'__div__',
'__eq__',
......@@ -50,7 +50,7 @@ CASTS = [
'__rtruediv__',
'__sub__',
'__truediv__',
]
])
# None of these, but here to make code cleaner.
SEQUENCE_CASTS = []
......
......@@ -465,12 +465,44 @@ bool dispatch_additive_masked_softmax(output_t *dst, const input_t *src, const i
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
return true;
}
return false;
}
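// Stream-aware variant of dispatch_additive_masked_softmax: same launch logic,
// but the kernel is launched on the caller-provided CUDA stream instead of the current stream.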
template<typename input_t, typename output_t, typename acc_t>
bool dispatch_additive_masked_softmax_stream(output_t *dst, const input_t *src, const input_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride, cudaStream_t streamid)
{
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
additive_masked_softmax_forward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_additive_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
// compute launch size
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, streamid>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
return true;
}
return false;
}
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
......@@ -1110,7 +1142,80 @@ void dispatch_masked_scale_softmax_backward_masked_out(output_t *grad_input, con
}
}
}
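// Stream-aware variant of dispatch_masked_scale_softmax_backward_masked_out:
// every kernel launch goes onto the caller-provided cudaStream_t.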
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_masked_out_stream(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count, int heads, cudaStream_t streamid)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil_native(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 0, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 1: // 2
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 1, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 2: // 4
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 2, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 3: // 8
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 3, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 4: // 16
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 4, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 5: // 32
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 5, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 6: // 64
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 6, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 7: // 128
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 7, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 8: // 256
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 8, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 9: // 512
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 9, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 10: // 1024
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 10, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
default:
break;
}
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int batch_size, int stride, int element_count)
{
......@@ -1266,6 +1371,77 @@ void dispatch_masked_scale_softmax_backward(output_t *grad_input, const input_t
}
}
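// Stream-aware variant of dispatch_masked_scale_softmax_backward; launches on the given stream.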
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_stream(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count, cudaStream_t streamid)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil_native(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 0, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 1, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 2, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 3, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 4, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 5, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 6, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 7, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 8, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 9, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 10, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
// elementwise multiplication called in at::softmax_backward_data is fused inside softmax dgrad kernel
// as a result of fusion, intermediate multiplication result is stored in fp32 in registers, instead of fp16
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
......@@ -1608,6 +1784,35 @@ bool dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const
return false;
}
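// Stream-aware variant of dispatch_softmax_backward; launches on the given stream.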
template<typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax_backward_stream(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count, cudaStream_t streamid)
{
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
softmax_backward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
// compute launch size
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
return true;
}
return false;
}
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE=32, int ELEMENTS_PER_LDG_STG=1>
__global__ void masked_softmax_warp_backward(__half *gradInput, const __half *grad, const __half *output, const uint8_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride)
{
......
# Introduction to ASP
This page documents the API for ASP (Automatic Sparsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python.
## Importing ASP
```
from apex.contrib.sparsity import ASP
```
## Initializing ASP
Apart from the import statement, it is sufficient to add just the following line of code before the training phase to augment the model and the optimizer for sparse training/inference:
```
ASP.prune_trained_model(model, optimizer)
```
In a typical PyTorch training loop, it might look like this:
```
ASP.prune_trained_model(model, optimizer)

x, y = DataLoader(args)
for epoch in range(epochs):
    y_pred = model(x)
    loss = loss_function(y_pred, y)
    loss.backward()
    optimizer.step()

torch.save(...)
```
The `prune_trained_model` call calculates the sparse mask and applies it to the weights. This is done once, i.e., the sparse locations in the weight matrices remain fixed after this step. To recompute the sparse mask during training, say after an epoch, use the following method:
```
ASP.compute_sparse_masks()
```
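For instance, per-epoch mask recomputation might look like the sketch below. It assumes the model was set up with `ASP.init_model_for_pruning(..., allow_recompute_mask=True)` rather than `prune_trained_model`, since recomputing a mask requires the previously pruned values to be stowed away; `train_one_epoch` stands in for the usual training loop:
```
ASP.init_model_for_pruning(model, mask_calculator="m4n2_1d",
                           whitelist=[torch.nn.Linear, torch.nn.Conv2d],
                           allow_recompute_mask=True)
ASP.init_optimizer_for_pruning(optimizer)
ASP.compute_sparse_masks()              # enable sparsity
for epoch in range(epochs):
    train_one_epoch(model, optimizer)   # usual forward/backward/step on the masked weights
    ASP.compute_sparse_masks()          # recompute the 2:4 masks on the updated weights
```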
A more thorough example can be found in `./test/toy_problem.py`.
\ No newline at end of file
from .sparse_masklib import create_mask
from .asp import ASP
import types
import torch
from .sparse_masklib import create_mask
torchvision_imported=True
try:
import torchvision
except ImportError:
print("[ASP][Warning] torchvision cannot be imported.")
torchvision_imported=False
def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names):
eligible_modules_list = []
for name, mod in model.named_modules():
if isinstance(mod, whitelist_layer_types) and name not in disallowed_layer_names:
if allowed_layer_names is not None and name not in allowed_layer_names:
continue
eligible_modules_list.append((name, mod))
return eligible_modules_list
class ASP:
__model = None
__verbosity = 0
__optimizer = None
__sparse_parameters = []
__calculate_mask = None
@classmethod
def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
verbosity=3,
whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d],
allowed_layer_names=None, disallowed_layer_names=[],
allow_recompute_mask=False):
"""Call this method to modify your model to take advantage of sparse matrix multiplication.
Note that this call alone only augments the model with the additional buffers needed for sparse MMA;
it does not enable the use of sparse MMA.
If you are starting with a fresh model:
model = ...
ASP.init_model_for_pruning(model, mask_calculator, ...)
if (training) ASP.init_optimizer_for_pruning(optimizer)
ASP.compute_sparse_masks() # sparsity is off by default; call this when you want to enable it.
If you are starting from a checkpoint:
model = ...
ASP.init_model_for_pruning(model, mask_calculator, ...)
torch.load(...)
if (training) ASP.init_optimizer_for_pruning(optimizer)
Arguments:
model The model
mask_calculator Either callable that computes mask given a tensor OR pattern string for sparse mask lib.
verbosity Integer controlling the verbosity level.
0 -> Only errors.
1 -> Errors and warnings.
2 -> Errors, warnings and info.
3 -> Errors, warnings, info and debug.
whitelist Module types approved for sparsity.
allowed_layer_names If not None, only layer names that appear in this list are considered for sparsity.
disallowed_layer_names Layer names that appear in this list are excluded from sparsity.
allow_recompute_mask If True, stores pruned values so that dense weights can be restored.
Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.
Support for allow_recompute_mask can be removed; it is not part of our recipe -- AKM.
"""
assert (cls.__model is None), "ASP has been initialized already."
cls.__model = model
cls.__verbosity = verbosity
if isinstance(mask_calculator, str):
def create_mask_from_pattern(param):
return create_mask(param, mask_calculator).bool()
cls.__calculate_mask = create_mask_from_pattern
else:
cls.__calculate_mask = mask_calculator #user defined function
# function to extract variables that will be sparsified.
# idea is that you will add one of these functions for each module type that can be sparsified.
if torchvision_imported:
print("[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.")
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torchvision.ops.misc.Conv2d: ['weight']}
else:
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight']}
for module_type in whitelist:
assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % module_type.__name__
# find all sparse modules, extract sparse parameters and decorate
def add_sparse_attributes(module_name, module):
sparse_parameters = sparse_parameter_list[type(module)]
for p_name, p in module.named_parameters():
if p_name in sparse_parameters and p.requires_grad:
# check for NVIDIA's TC compatibility: we check along the horizontal direction
if p.dtype == torch.float32 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #User defines FP32 and APEX internally uses FP16 math
print("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
continue
if p.dtype == torch.float16 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #For Conv2d dim= K x CRS; we prune along C
print("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
continue
if cls.__verbosity >= 3:
print("[ASP] Sparsifying %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
mask = torch.ones_like(p).bool()
buffname = module_name.split(".")[-1] # buffer names cannot contain "."
module.register_buffer('__%s_mma_mask' % buffname, mask)
if allow_recompute_mask:
pruned = torch.zeros_like(p).cpu()
module.register_buffer('__%s_mma_pruned_p' % buffname, pruned)
else:
pruned = None
cls.__sparse_parameters.append((module_name, module, p_name, p, mask, pruned))
for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names):
add_sparse_attributes(name, sparse_module)
@classmethod
def init_optimizer_for_pruning(cls, optimizer):
"""Call this method to monkey patch optimizer step function so that masks can be applied to
gradients and weights during training.
You must call init_model_for_pruning(...) before calling init_optimizer_for_pruning(...)
"""
assert (cls.__optimizer is None), "ASP has initialized optimizer already."
assert (cls.__calculate_mask is not None), "Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning."
# store pointer to original optimizer step method
cls.__optimizer = optimizer
cls.__optimizer.__step = optimizer.step
def __step(opt_self, *args, **kwargs):
# prune gradients before step method
with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
p.grad.mul_(mask)
# call original optimizer step method
rval = opt_self.__step(*args, **kwargs)
# prune parameters after step method
with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
p.mul_(mask)
return rval
cls.__optimizer.step = types.MethodType(__step, cls.__optimizer)
@classmethod
def compute_sparse_masks(cls):
"""Call this method to enable sparsity.
If init(...) was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None.
"""
with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
if mask.sum() < mask.numel(): # when recalculating masks
# restore dense parameter if allow_recompute_mask is enabled
assert (pruned is not None), "Unable to restore dense parameter because allow_recompute_mask == False"
p.add_(pruned.cuda())
mask.set_(cls.__calculate_mask(p))
if pruned is not None: # stow away pruned weights to cpu
pruned.set_((p * (~mask)).cpu())
p.mul_(mask) # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights
if cls.__verbosity >= 2:
print("[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s" % (100.0*mask.sum()/mask.numel(), module_name, p_name, str(p.size()), str(p.dtype)))
@classmethod
def restore_pruned_weights(cls):
"""Call this method to disable sparsity and restore all weights.
This will only work if init(...) was called with allow_recompute_mask=True.
"""
with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
if mask.sum() < mask.numel():
assert (pruned is not None), "Unable to restore dense parameter because allow_recompute_mask == False"
p.add_(pruned.cuda())
mask.fill_(1)
pruned.zero_()
if cls.__verbosity >= 2:
print("[ASP] Disabled sparsity for %s::%s (dense weights restored)" % (module_name, p_name))
@classmethod
def is_sparsity_enabled(cls):
"""Call this method to determine if sparsity is enabled in the model.
The typical use case is right after checkpoint has been loaded.
"""
total,sp100,sp50 = 0,0,0
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
total += 1
mask_sum = mask.sum()
mask_numel = mask.numel()
if mask_sum == mask_numel:
sp100 += 1
elif mask_sum*2 == mask_numel:
sp50 += 1
assert (total == sp100 or total == sp50), "Inconsistent model sparsity"
if total == sp100:
return False
elif total == sp50:
return True
@classmethod
def prune_trained_model(cls, model, optimizer):
# add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks)
cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False)
cls.init_optimizer_for_pruning(optimizer)
cls.compute_sparse_masks()
import sys
import torch
import numpy as np
import collections
from itertools import permutations
""" compute density (helper fn to compute % NNZs in a tensor) """
def fill(x):
return float(x.nonzero().size(0))/torch.numel(x)
""" reshape matrix into m-dimensional vectors: (h,w) -> (hw/m, m) """
def reshape_1d(matrix, m):
# If not a nice multiple of m, fill with zeroes.
if matrix.shape[1] % m > 0:
mat = torch.cuda.FloatTensor(matrix.shape[0], matrix.shape[1] + (m-matrix.shape[1]%m)).fill_(0)
mat[:, :matrix.shape[1]] = matrix
shape = mat.shape
return mat.view(-1,m),shape
else:
return matrix.view(-1,m), matrix.shape
""" return all possible m:n patterns in a 1d vector """
valid_m4n2_1d_patterns = None
def compute_valid_1d_patterns(m,n):
# Early exit if patterns was already created.
global valid_m4n2_1d_patterns
if m==4 and n==2 and valid_m4n2_1d_patterns is not None: return valid_m4n2_1d_patterns
patterns = torch.zeros(m)
patterns[:n] = 1
valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))
if m == 4 and n == 2: valid_m4n2_1d_patterns = valid_patterns
return valid_patterns
""" m:n 1d structured best """
def mn_1d_best(matrix, m, n):
# Find all possible patterns.
patterns = compute_valid_1d_patterns(m,n).cuda()
# Find the best m:n pattern (sum of non-masked weights).
mask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1,m)
mat,shape = reshape_1d(matrix,m)
pmax = torch.argmax(torch.matmul(mat.abs(),patterns.t()), dim=1)
mask[:] = patterns[pmax[:]]
mask = mask.view(matrix.shape)
return mask
def m4n2_1d(mat, density):
return mn_1d_best(mat, 4, 2)
"""
Below 2d-masking related code is targeted more for training (from scratch).
2d-pruning of a weight tensor is done to accelerate DGRAD step during backprop
phase of training algorithm. Acceleration comes from using SpMMA instructions in
Tensor Cores of NVIDIA Ampere GPU Architecture
(note: this code does not do the acceleration, GPU kernels are required for this).
1d pruning of weight tensor helps speed up FPROP step by pruning in 2:4 pattern
along the horizontal (logical) direction.
During DGRAD step, weight tensor is transposed. 2d pruning functions below, mask
weight tensor such that their transposed versions are also 2:4 sparse along the
horizontal (logical) direction. Thus, with 2d pruning, weight tensors are
2:4 sparse along row and column directions.
"""
""" m:n 2d structured pruning: greedy method to select mask """
def mn_2d_greedy(matrix, m, n):
# Convert to numpy
mat = matrix.cpu().detach().numpy()
mask = np.ones(mat.shape, dtype=int)
rowCount = int(mat.shape[0]/m) * m
colCount = int(mat.shape[1]/m) * m
for rowStartIdx in range(0, rowCount, m):
rowEndIdx = rowStartIdx + m
for colStartIdx in range(0, colCount, m):
colEndIdx = colStartIdx + m
matrixSub = np.absolute(np.squeeze(mat[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx]))
maskSub = np.squeeze(mask[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx])
maskSub.fill(0.0)
matrixVecView = matrixSub.reshape(-1)
maskVecView = maskSub.reshape(-1)
linearIdx = np.argsort(matrixVecView)
matrixIdx = [(int(x/m), x % m) for x in linearIdx]
rowCounter = collections.Counter()
colCounter = collections.Counter()
for currIdx in range(len(linearIdx) - 1, -1, -1):
currMatrixEntry = matrixIdx[currIdx]
if (rowCounter[currMatrixEntry[0]] == n) or (colCounter[currMatrixEntry[1]] == n):
continue
#end if
maskSub[currMatrixEntry[0], currMatrixEntry[1]] = 1.0
rowCounter[currMatrixEntry[0]] += 1
colCounter[currMatrixEntry[1]] += 1
return torch.tensor(mask).cuda()
def m4n2_2d_greedy(mat, density):
return mn_2d_greedy(mat, 4, 2)
""" return all possible m:n patterns in a mxn block. """
valid_m4n2_2d_patterns = None
def compute_valid_2d_patterns(m,n):
# Early exit if patterns was already created.
global valid_m4n2_2d_patterns
if valid_m4n2_2d_patterns is not None: return valid_m4n2_2d_patterns
patterns = torch.zeros(m)
patterns[:n] = 1
patterns = list(set(permutations(patterns.tolist())))
patterns = patterns + patterns
patterns = torch.Tensor(list(set(permutations(patterns,m))))
valid = ((patterns.sum(dim=1) <= n).sum(dim=1) == m).nonzero().view(-1)
valid_patterns = torch.Tensor(valid.shape[0],m,m)
valid_patterns[:] = patterns[valid[:]]
if m == 4 and n == 2: valid_m4n2_2d_patterns = valid_patterns
return valid_patterns
""" m:n 2d structured pruning: exhaustive method to select best mask """
def mn_2d_best(matrix, m, n):
# Find all possible patterns.
patterns = compute_valid_2d_patterns(m,n).cuda()
# Find the best m:n pattern (sum of non-masked weights).
mask = torch.cuda.IntTensor(matrix.shape).fill_(1)
mat = reshape_2d(matrix,m,m).abs()
pmax = torch.argmax(torch.matmul(mat,patterns.view(patterns.shape[0],m*m).t()), dim=2)
# Copy best m:n patterns into mask.
mat = mat.view(mat.shape[0]*mat.shape[1],-1)
pmax = pmax.view(pmax.shape[0]*pmax.shape[1]).unsqueeze(1).expand(-1,mat.shape[1])
patterns = patterns.view(patterns.shape[0],patterns.shape[1]*patterns.shape[2])
mat = torch.gather(patterns,0,pmax)
mat = reshape_2d_inv(mat.view(matrix.shape[0]//m,matrix.shape[1]//m,m,m))
mask.copy_(mat.type(mask.type()))
return mask
def m4n2_2d_best(mat, density):
return mn_2d_best(mat, 4, 2)
""" returns a sparse mask """
def create_mask(tensor, pattern="m4n2_1d", density=0.5):
# Reshape tensor and mask.
shape = tensor.shape
ttype = tensor.type()
t = tensor.float().contiguous()
# 1d-tensor
if len(shape) == 1:
t = t.view(1, shape[0])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
return mask.view(shape).type(ttype)
# 2d-tensor (in, out)
elif len(shape) == 2:
t = t.view(shape[0], shape[1])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
return mask.view(shape).type(ttype)
# 3d-tensor (batch, in, out)
elif len(shape) == 3:
t = t.view(shape[0]*shape[1], shape[2])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
return mask.view(shape).type(ttype)
# 4d-tensor (in, out, h, w)
elif len(shape) == 4:
"""
# transformers (bmm)
t = t.view(shape[0]*shape[1]*shape[2], shape[3])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
return mask.view(shape).type(ttype)
"""
# convs
t = t.permute(2,3,0,1).contiguous().view(shape[2]*shape[3]*shape[0], shape[1])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
mask = mask.view(shape[2], shape[3], shape[0], shape[1]).permute(2,3,0,1).contiguous()
return mask.view(shape).type(ttype)
from collections import OrderedDict
import torch
from apex.optimizers import FusedAdam
from apex.contrib.sparsity import ASP
def build_model(args):
od = OrderedDict()
for i in range(args.num_layers):
if i == 0:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
elif i == args.num_layers-1:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])
else:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
return torch.nn.Sequential(od)
def train_step(args, model, optimizer, input_batch, target_batch, step):
predicted_target = model(input_batch)
loss = ((predicted_target-target_batch)**2).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
step = step + 1
#print("Step %d :: loss=%e" % (step, loss.item()))
return step
def train_loop(args, model, optimizer, step, num_steps):
for i in range(num_steps):
input_batch = torch.randn([args.batch_size, args.input_features]).cuda()
target_batch = torch.randn([args.batch_size, args.output_features]).cuda()
step = train_step(args, model, optimizer, input_batch, target_batch, step)
return step
def main(args):
#
# PART1
#
torch.manual_seed(args.seed)
model = build_model(args).cuda()
one_ll = next(model.children()).weight
optimizer = FusedAdam(model.parameters())
ASP.init_model_for_pruning(model, args.pattern, verbosity=args.verbosity, whitelist=args.whitelist, allow_recompute_mask=args.allow_recompute_mask)
ASP.init_optimizer_for_pruning(optimizer)
step = 0
# train for a few steps with dense weights
print("DENSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_dense_steps)
# simulate sparsity by inserting zeros into existing dense weights
ASP.compute_sparse_masks()
# train for a few steps with sparse weights
print("SPARSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_sparse_steps)
torch.save({
'step': step,
'verbosity': args.verbosity,
'seed2': args.seed2,
'pattern': args.pattern,
'whitelist': args.whitelist,
'allow_recompute_mask': args.allow_recompute_mask,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, args.checkpoint_path)
if __name__ == '__main__':
class Args:
verbosity=3
seed = 4873
seed2 = 99875
pattern = "m4n2_2d_best"
whitelist = [torch.nn.Linear]
allow_recompute_mask = True
batch_size = 32
input_features = 8
output_features = 8
hidden_features = 32
num_layers = 4
num_dense_steps = 2000
num_sparse_steps = 3000
num_sparse_steps_2 = 1000
checkpoint_path = "part1.chkp"
args = Args()
main(args)
from collections import OrderedDict
import torch
from apex.optimizers import FusedAdam
from apex.contrib.sparsity import ASP
def build_model(args):
od = OrderedDict()
for i in range(args.num_layers):
if i == 0:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
elif i == args.num_layers-1:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])
else:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
return torch.nn.Sequential(od)
def train_step(args, model, optimizer, input_batch, target_batch, step):
predicted_target = model(input_batch)
loss = ((predicted_target-target_batch)**2).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
step = step + 1
#print("Step %d :: loss=%e" % (step, loss.item()))
return step
def train_loop(args, model, optimizer, step, num_steps):
for i in range(num_steps):
input_batch = torch.randn([args.batch_size, args.input_features]).cuda()
target_batch = torch.randn([args.batch_size, args.output_features]).cuda()
step = train_step(args, model, optimizer, input_batch, target_batch, step)
return step
def main(step, args, model_state_dict, optimizer_state_dict):
#
# PART2
#
model = build_model(args).cuda()
one_ll = next(model.children()).weight
optimizer = FusedAdam(model.parameters())
ASP.init_model_for_pruning(model, args.pattern, verbosity=args.verbosity, whitelist=args.whitelist, allow_recompute_mask=args.allow_recompute_mask)
ASP.init_optimizer_for_pruning(optimizer)
torch.manual_seed(args.seed2)
model.load_state_dict(model_state_dict)
optimizer.load_state_dict(optimizer_state_dict)
print("Model sparsity is %s" % ("enabled" if ASP.sparsity_is_enabled() else "disabled"))
# train for a few steps with sparse weights
print("SPARSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)
if __name__ == '__main__':
checkpoint = torch.load("part1.chkp")
class Args:
verbosity = checkpoint['verbosity']
seed = 4873
seed2 = checkpoint['seed2']
pattern = checkpoint['pattern']
whitelist = checkpoint['whitelist']
allow_recompute_mask = checkpoint['allow_recompute_mask']
batch_size = 32
input_features = 8
output_features = 8
hidden_features = 32
num_layers = 4
num_dense_steps = 2000
num_sparse_steps = 3000
num_sparse_steps_2 = 1000
checkpoint_path = "part1.chkp"
args = Args()
main(checkpoint['step'], args, checkpoint['model_state_dict'], checkpoint['optimizer_state_dict'])
from collections import OrderedDict
import torch
from apex.optimizers import FusedAdam
from apex.contrib.sparsity import ASP
#
# Reference run for checkpointing test (part1 + part2)
#
def build_model(args):
od = OrderedDict()
for i in range(args.num_layers):
if i == 0:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
elif i == args.num_layers-1:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])
else:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
return torch.nn.Sequential(od)
def train_step(args, model, optimizer, input_batch, target_batch, step):
predicted_target = model(input_batch)
loss = ((predicted_target-target_batch)**2).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
step = step + 1
#print("Step %d :: loss=%e" % (step, loss.item()))
return step
def train_loop(args, model, optimizer, step, num_steps):
for i in range(num_steps):
input_batch = torch.randn([args.batch_size, args.input_features]).cuda()
target_batch = torch.randn([args.batch_size, args.output_features]).cuda()
step = train_step(args, model, optimizer, input_batch, target_batch, step)
return step
def main(args):
#
# PART1
#
torch.manual_seed(args.seed)
model = build_model(args).cuda()
one_ll = next(model.children()).weight
optimizer = FusedAdam(model.parameters())
ASP.init_model_for_pruning(model, args.pattern, whitelist=args.whitelist, allow_recompute_mask=args.allow_recompute_mask)
ASP.init_optimizer_for_pruning(optimizer)
step = 0
# train for a few steps with dense weights
print("DENSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_dense_steps)
# simulate sparsity by inserting zeros into existing dense weights
ASP.compute_sparse_masks()
# train for a few steps with sparse weights
print("SPARSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_sparse_steps)
#
# PART 2
#
torch.manual_seed(args.seed2)
# train for a few steps with sparse weights
print("SPARSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)
if __name__ == '__main__':
class Args:
seed = 4873
seed2 = 99875
pattern = "m4n2_2d_best"
whitelist = [torch.nn.Linear]
allow_recompute_mask = True
batch_size = 32
input_features = 8
output_features = 8
hidden_features = 32
num_layers = 4
num_dense_steps = 2000
num_sparse_steps = 3000
num_sparse_steps_2 = 1000
checkpoint_path = "part1.chkp"
args = Args()
main(args)
from collections import OrderedDict
import torch
from apex.optimizers import FusedAdam
from apex.contrib.sparsity import ASP
def build_model(args):
od = OrderedDict()
for i in range(args.num_layers):
if i == 0:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
elif i == args.num_layers-1:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])
else:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
return torch.nn.Sequential(od)
def train_step(args, model, optimizer, input_batch, target_batch, step):
predicted_target = model(input_batch)
loss = ((predicted_target-target_batch)**2).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
step = step + 1
#print("Step %d :: loss=%e" % (step, loss.item()))
return step
def train_loop(args, model, optimizer, step, num_steps):
for i in range(num_steps):
input_batch = torch.randn([args.batch_size, args.input_features]).cuda()
target_batch = torch.randn([args.batch_size, args.output_features]).cuda()
step = train_step(args, model, optimizer, input_batch, target_batch, step)
return step
def main(args):
model = build_model(args).cuda()
one_ll = next(model.children()).weight
optimizer = FusedAdam(model.parameters())
# only prune linear layers, even though we also support conv1d, conv2d and conv3d
ASP.init_model_for_pruning(model, "m4n2_1d", whitelist=[torch.nn.Linear], allow_recompute_mask=True)
ASP.init_optimizer_for_pruning(optimizer)
step = 0
# train for a few steps with dense weights
print("DENSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_dense_steps)
# simulate sparsity by inserting zeros into existing dense weights
ASP.compute_sparse_masks()
# train for a few steps with sparse weights
print("SPARSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_sparse_steps)
# recompute sparse masks
ASP.compute_sparse_masks()
# train for a few steps with sparse weights
print("SPARSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)
# turn off sparsity
print("SPARSE :: ",one_ll)
ASP.restore_pruned_weights()
# train for a few steps with dense weights
print("DENSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_dense_steps_2)
if __name__ == '__main__':
class Args:
batch_size = 32
input_features = 16
output_features = 8
hidden_features = 40
num_layers = 4
num_dense_steps = 2000
num_sparse_steps = 3000
num_sparse_steps_2 = 1000
num_dense_steps_2 = 1500
args = Args()
main(args)
......@@ -28,16 +28,24 @@ class SyncBatchnormFunction(Function):
if torch.distributed.is_initialized():
if not process_group:
process_group = torch.distributed.group.WORLD
device = mean.device
world_size = torch.distributed.get_world_size(process_group)
mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=var_biased.device)
mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=device)
var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=device)
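# per-rank element counts, gathered below so the Welford merge can weight each rank's statistics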
count_all = torch.cuda.IntTensor(world_size, device=device)
mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
count_l = [count_all.narrow(0, i, 1) for i in range(world_size)]
torch.distributed.all_gather(mean_l, mean, process_group)
torch.distributed.all_gather(var_l, var_biased, process_group)
mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count, eps)
# TODO(Jie): should do fp32 math instead!
torch.distributed.all_gather(
count_l,
torch.cuda.IntTensor([count], device=device),
process_group)
mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count_all, eps)
else:
device = mean.device
count_all = torch.cuda.IntTensor([count], device=device)
inv_std = 1.0 / torch.sqrt(var_biased + eps)
var = var_biased * (count) / (count-1)
......@@ -52,7 +60,7 @@ class SyncBatchnormFunction(Function):
mean = running_mean.data
inv_std = 1.0 / torch.sqrt(running_variance.data + eps)
ctx.save_for_backward(input, weight, mean, inv_std, z, bias)
ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all)
ctx.process_group = process_group
ctx.channel_last = channel_last
ctx.world_size = world_size
......@@ -71,7 +79,7 @@ class SyncBatchnormFunction(Function):
# mini batch mean & var are calculated by forward path.
# mu = 1./N*np.sum(h, axis = 0)
# var = 1./N*np.sum((h-mu)**2, axis = 0)
saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
saved_input, weight, mean, inv_std, z, bias, count = ctx.saved_tensors
process_group = ctx.process_group
channel_last = ctx.channel_last
world_size = ctx.world_size
......@@ -83,26 +91,24 @@ class SyncBatchnormFunction(Function):
if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
grad_z = grad_output.clone()
# TODO(jie): why do I have to clone here? life time of grad_output?
# TODO: update kernel to not pre_divide by item_num
if channel_last:
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
else:
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)
# calculate grad_input
if ctx.needs_input_grad[0]:
if torch.distributed.is_initialized():
torch.distributed.all_reduce(
mean_dy, ReduceOp.SUM, process_group)
mean_dy = mean_dy / world_size
sum_dy, ReduceOp.SUM, process_group)
torch.distributed.all_reduce(
mean_dy_xmu, ReduceOp.SUM, process_group)
mean_dy_xmu = mean_dy_xmu / world_size
sum_dy_xmu, ReduceOp.SUM, process_group)
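# normalization by the global element count now happens inside the backward kernels via `count`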
if channel_last:
grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)
else:
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)
if weight is None or not ctx.needs_input_grad[2]:
grad_weight = None
......
......@@ -12,7 +12,7 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);
// implemented using welford
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
const at::Tensor var_biased_feature_nodes,
int numel,
const at::Tensor numel,
const float eps);
// elementwise BN operation, returns output
......@@ -24,7 +24,7 @@ at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// implemented using kahan summation
......@@ -36,14 +36,15 @@ std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
const at::Tensor sum_dy,
const at::Tensor sum_dy_xmu,
const at::Tensor count);
// returns {mean, biased_var}
// implemented using welford
......@@ -62,7 +63,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
const at::optional<at::Tensor> shift,
const bool fuse_relu);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
......@@ -74,15 +75,16 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
const at::Tensor sum_dy,
const at::Tensor sum_dy_xmu,
const at::Tensor count);
at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
......
......@@ -327,15 +327,15 @@ __global__ void reduce_bn_kernel(
const scalar_t* __restrict__ grad_output,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
accscalar_t* __restrict__ mean_dy,
accscalar_t* __restrict__ mean_dy_xmu,
accscalar_t* __restrict__ sum_dy_o,
accscalar_t* __restrict__ sum_dy_xmu_o,
layerscalar_t* __restrict__ grad_weight,
layerscalar_t* __restrict__ grad_bias,
const int bs,
const int fs,
const int ss) {
static __shared__ int s_mem[64];
int total_item_num = bs * ss;
//int total_item_num = bs * ss;
int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
......@@ -377,8 +377,10 @@ __global__ void reduce_bn_kernel(
if (grad_weight != NULL) {
grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
}
mean_dy[blockIdx.x] = sum_dy / total_item_num;
mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
//mean_dy[blockIdx.x] = sum_dy / total_item_num;
//mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
sum_dy_o[blockIdx.x] = sum_dy;
sum_dy_xmu_o[blockIdx.x] = sum_dy_xmu;
}
}
......@@ -390,16 +392,24 @@ __global__ void batchnorm_backward_kernel(
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
const accscalar_t* __restrict__ mean_dy,
const accscalar_t* __restrict__ mean_dy_xmu,
const accscalar_t* __restrict__ sum_dy,
const accscalar_t* __restrict__ sum_dy_xmu,
const int* __restrict__ numel,
scalar_t* __restrict__ grad_input,
const int64_t world_size,
const int ss,
const int bs) {
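// sum the per-rank element counts so sum_dy / sum_dy_xmu can be normalized by the global count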
int64_t div = 0;
for (int i = 0; i < world_size; i++) {
div += numel[i];
}
auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
//auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
auto m_dy_c = static_cast<accscalar_t>(sum_dy[blockIdx.x]) / div;
auto factor_1_c = inv_std[blockIdx.x];
auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x])) * factor_1_c;
factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
//factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[blockIdx.x] / div;
for (int batch_offset = blockIdx.y*blockDim.y+threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
......@@ -559,13 +569,13 @@ template <typename scalar_t>
__global__ void welford_kernel_parallel(
const scalar_t* __restrict__ mean,
const scalar_t* __restrict__ var_biased,
const int* __restrict__ numel,
scalar_t* __restrict__ out_mean,
scalar_t* __restrict__ out_var,
scalar_t* __restrict__ inv_std,
const int world_size,
const int feature_size,
const float eps,
const int numel) {
const float eps) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < feature_size; i += gridDim.x * blockDim.x) {
// load data;
......@@ -574,7 +584,7 @@ __global__ void welford_kernel_parallel(
scalar_t m_2_n = 0;
int count = 0;
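// merge statistics across ranks; numel[j] is the element count contributed by rank j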
for (int j = 0; j < world_size; j++) {
welford_merge_element(count, x_mean, m_2_n, numel, mean[address], var_biased[address]*numel);
welford_merge_element(count, x_mean, m_2_n, numel[j], mean[address], var_biased[address]*numel[j]);
address += feature_size;
}
out_mean[i] = x_mean;
......@@ -694,8 +704,8 @@ __global__ void reduce_bn_c_last_kernel(
const scalar_t* __restrict__ grad_output,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
accscalar_t* __restrict__ mean_dy,
accscalar_t* __restrict__ mean_dy_xmu,
accscalar_t* __restrict__ sum_dy_o,
accscalar_t* __restrict__ sum_dy_xmu_o,
layerscalar_t* __restrict__ grad_weight,
layerscalar_t* __restrict__ grad_bias,
volatile accscalar_t* staging_data,
......@@ -814,8 +824,10 @@ __global__ void reduce_bn_c_last_kernel(
if (grad_weight != NULL) {
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
}
mean_dy[c_offset] = sum_dy_th / reduction_size;
mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
//mean_dy[c_offset] = sum_dy_th / reduction_size;
//mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
sum_dy_o[c_offset] = sum_dy_th;
sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
}
}
} else {
......@@ -826,8 +838,10 @@ __global__ void reduce_bn_c_last_kernel(
if (grad_weight != NULL) {
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
}
mean_dy[c_offset] = sum_dy_th / reduction_size;
mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
//mean_dy[c_offset] = sum_dy_th / reduction_size;
//mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
sum_dy_o[c_offset] = sum_dy_th;
sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
}
}
}
......@@ -844,11 +858,17 @@ __global__ void batchnorm_backward_c_last_kernel(
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
const accscalar_t* __restrict__ mean_dy,
const accscalar_t* __restrict__ mean_dy_xmu,
const accscalar_t* __restrict__ sum_dy,
const accscalar_t* __restrict__ sum_dy_xmu,
const int* __restrict__ numel,
scalar_t* __restrict__ grad_input,
const int64_t world_size,
const int reduction_size,
const int stride) {
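// sum the per-rank element counts so sum_dy / sum_dy_xmu can be normalized by the global count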
int64_t div = 0;
for (int i = 0; i < world_size; i++) {
div += numel[i];
}
// tensor dimension (m,c)
// loop along m dimension
int inner_loop_stride = blockDim.y * gridDim.y;
......@@ -858,10 +878,10 @@ __global__ void batchnorm_backward_c_last_kernel(
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
auto m_c = mean[c_offset];
auto m_dy_c = mean_dy[c_offset];
auto m_dy_c = sum_dy[c_offset] / div;
auto factor_1_c = inv_std[c_offset];
auto factor_2_c = (weight == NULL? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[c_offset];
factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] / div;
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
int address_base = m_offset * stride + c_offset;
......@@ -986,8 +1006,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
auto scalar_type = promote_scalartype(input);
at::Tensor mean_dy = at::empty({feature_size}, mean.options());
at::Tensor mean_dy_xmu = at::empty({feature_size}, mean.options());
at::Tensor sum_dy = at::empty({feature_size}, mean.options());
at::Tensor sum_dy_xmu = at::empty({feature_size}, mean.options());
at::Tensor grad_weight;
at::Tensor grad_bias;
......@@ -1018,8 +1038,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
grad_output.DATA_PTR<scalar_t_0>(),
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
batch_size,
......@@ -1039,8 +1059,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
grad_output.DATA_PTR<scalar_t_0>(),
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
batch_size,
......@@ -1049,7 +1069,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
);
}
return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
return {sum_dy, sum_dy_xmu, grad_weight, grad_bias};
}
at::Tensor batchnorm_backward_CUDA(
......@@ -1058,8 +1078,9 @@ at::Tensor batchnorm_backward_CUDA(
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu) {
const at::Tensor sum_dy,
const at::Tensor sum_dy_xmu,
const at::Tensor count) {
const auto batch_size = input.size(0);
const auto feature_size = input.size(1);
......@@ -1088,9 +1109,11 @@ at::Tensor batchnorm_backward_CUDA(
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
count.DATA_PTR<int>(),
grad_input.DATA_PTR<scalar_t_0>(),
count.numel(),
space_size,
batch_size);
);
......@@ -1108,9 +1131,11 @@ at::Tensor batchnorm_backward_CUDA(
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
count.DATA_PTR<int>(),
grad_input.DATA_PTR<scalar_t_0>(),
count.numel(),
space_size,
batch_size);
);
......@@ -1121,7 +1146,7 @@ at::Tensor batchnorm_backward_CUDA(
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
const at::Tensor var_biased,
int numel,
const at::Tensor numel,
const float eps) {
const auto world_size = mean_feature_nodes.size(0);
const auto feature_size = mean_feature_nodes.size(1);
......@@ -1142,13 +1167,13 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
welford_kernel_parallel<scalar_t_0><<<grid, block, 0, stream>>>(
mean_feature_nodes.DATA_PTR<scalar_t_0>(),
var_biased.DATA_PTR<scalar_t_0>(),
numel.DATA_PTR<int>(),
out_mean.DATA_PTR<scalar_t_0>(),
out_var.DATA_PTR<scalar_t_0>(),
inv_std.DATA_PTR<scalar_t_0>(),
world_size,
feature_size,
eps,
numel);
eps);
);
}
......@@ -1270,8 +1295,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
at::Tensor mean_dy = at::empty({stride}, mean.options());
at::Tensor mean_dy_xmu = at::empty({stride}, mean.options());
at::Tensor sumn_dy = at::empty({stride}, mean.options());
at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());
at::Tensor grad_weight;
at::Tensor grad_bias;
......@@ -1310,8 +1335,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
grad_output.DATA_PTR<scalar_t_0>(),
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sumn_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
staging_data_ptr,
......@@ -1335,8 +1360,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
grad_output.DATA_PTR<scalar_t_0>(),
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sumn_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
staging_data_ptr,
......@@ -1346,7 +1371,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
);
}
return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
return {sumn_dy, sum_dy_xmu, grad_weight, grad_bias};
}
at::Tensor batchnorm_backward_c_last_CUDA(
......@@ -1355,8 +1380,9 @@ at::Tensor batchnorm_backward_c_last_CUDA(
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu) {
const at::Tensor sum_dy,
const at::Tensor sum_dy_xmu,
const at::Tensor count) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
......@@ -1380,9 +1406,11 @@ at::Tensor batchnorm_backward_c_last_CUDA(
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
count.DATA_PTR<int>(),
grad_input.DATA_PTR<scalar_t_0>(),
count.numel(),
reduction_size,
stride);
);
......@@ -1401,9 +1429,11 @@ at::Tensor batchnorm_backward_c_last_CUDA(
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
count.DATA_PTR<int>(),
grad_input.DATA_PTR<scalar_t_0>(),
count.numel(),
reduction_size,
stride);
);
......
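For orientation, the hunks above change the syncbn backward interface from per-feature means to per-feature sums plus an explicit per-rank element count. A minimal sketch of the resulting call pattern from Python follows, assuming the syncbn extension is built and importable as in the unit tests further down; the toy statistics are computed with plain torch for the sketch rather than with the extension's own welford kernels.
import torch
import syncbn  # apex C++/CUDA extension, assumed built as in the tests below

inp = torch.randn(4, 8, 16, 16, device='cuda')   # toy NCHW input
grad_output = torch.randn_like(inp)
weight = torch.ones(8, device='cuda')

# per-feature statistics, computed with plain torch for this sketch
x2d = inp.transpose(0, 1).contiguous().view(inp.size(1), -1)
mean = x2d.mean(1)
inv_std = torch.rsqrt(x2d.var(1, unbiased=False) + 1e-5)

# reduce_bn now returns sums of dy and dy*(x - mean) rather than means
sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(
    grad_output, inp, mean, inv_std, weight)

# batchnorm_backward additionally takes the per-rank element counts
count = torch.cuda.IntTensor([inp.numel() // inp.size(1)])
grad_input = syncbn.batchnorm_backward(
    grad_output, inp, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)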
import unittest
import os
import torch
from torch.optim import Optimizer
import apex
from apex.multi_tensor_apply import multi_tensor_applier
from apex.testing.common_utils import skipIfRocm
class RefLAMB(Optimizer):
r"""Implements Lamb algorithm.
It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-6)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)
.. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(RefLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available:
import amp_C
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_lamb = amp_C.multi_tensor_lamb
else:
raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
# create separate grad lists for fp32 and fp16 params
g_all_32, g_all_16 = [], []
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
if p.dtype == torch.float32:
g_all_32.append(p.grad.data)
elif p.dtype == torch.float16:
g_all_16.append(p.grad.data)
else:
raise RuntimeError('FusedLAMB only supports fp16 and fp32.')
g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda')
# compute grad norm for two lists
if len(g_all_32) > 0:
g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_32], False)[0]
if len(g_all_16) > 0:
g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_16], False)[0]
# blend two grad norms to get global grad norm
global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[[g_norm_32, g_norm_16]],
False)[0]
max_grad_norm = 1.0
clipped_ratio = max_grad_norm / max(global_grad_norm, max_grad_norm)
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
p.grad.data *= clipped_ratio
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['m'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['v'] = torch.zeros_like(p.data)
m_t, v_t = state['m'], state['v']
beta1, beta2 = group['betas']
state['step'] += 1
# m_t = beta1 * m + (1 - beta1) * g_t
m_t.mul_(beta1).add_(grad, alpha=1-beta1)
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
v_t.mul_(beta2).addcmul_(grad, grad, value=1-beta2)
# Debiasing
m_t_hat = m_t / (1.0 - beta1 ** state['step'])
v_t_hat = v_t / (1.0 - beta2 ** state['step'])
update = m_t_hat / v_t_hat.sqrt().add(group['eps'])
if group['weight_decay'] != 0:
update.add_(p.data, alpha=group['weight_decay'])
trust_ratio = 1.0
w_norm = p.data.pow(2).sum().sqrt()
g_norm = update.pow(2).sum().sqrt()
if w_norm > 0 and g_norm > 0:
trust_ratio = w_norm / g_norm
state['w_norm'] = w_norm
state['g_norm'] = g_norm
state['trust_ratio'] = trust_ratio
step_size = group['lr']
p.data.add_(update, alpha=-step_size*trust_ratio)
return loss
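Written out, the step above applies the LAMB rule that the fused optimizer is later checked against: with per-parameter state m, v and gradient g (after the global gradient-norm clipping at the top of step()),
m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
m_hat = m_t / (1 - beta1^t),  v_hat = v_t / (1 - beta2^t)
update = m_hat / (sqrt(v_hat) + eps) + weight_decay * w_{t-1}
w_t = w_{t-1} - lr * trust_ratio * update,  where trust_ratio = ||w_{t-1}|| / ||update|| (falling back to 1 when either norm is zero).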
class TestFusedLAMB(unittest.TestCase):
def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
self.max_abs_diff = max_abs_diff
self.max_rel_diff = max_rel_diff
self.iters = iters
torch.cuda.manual_seed(9876)
def tearDown(self):
pass
def gen_param_optim(self, tensors, lamb_option):
ref_param = []
tst_param = []
for tensor in tensors:
ref_param.append(torch.nn.Parameter(tensor.clone()))
tst_param.append(torch.nn.Parameter(tensor.clone()))
ref_optim = RefLAMB(ref_param, **lamb_option)
tst_optim = apex.optimizers.FusedLAMB(tst_param, use_nvlamb=True, **lamb_option)
return (ref_param, tst_param, ref_optim, tst_optim)
def gen_grad(self, ref_param, tst_param):
for p_ref, p_tst in zip(ref_param, tst_param):
p_ref.grad = torch.rand_like(p_ref)
p_tst.grad = p_ref.grad
def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
half_grads = []
for p_ref, _ in zip(ref_param, tst_param):
half_grads.append(torch.rand_like(p_ref).half())
p_ref.grad = half_grads[-1].float() / scale
return half_grads
def get_max_diff(self, ref_param, tst_param):
max_abs_diff = max_rel_diff = 0
for p_ref, p_tst in zip(ref_param, tst_param):
max_abs_diff_p = (p_ref - p_tst).abs().max().item()
max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
if max_abs_diff_p > max_abs_diff: max_abs_diff = max_abs_diff_p
if max_rel_diff_p > max_rel_diff: max_rel_diff = max_rel_diff_p
return max_abs_diff, max_rel_diff
def gen_single_type_test(self, param_type=torch.float):
nelem = 278011
tensor = torch.rand(nelem, dtype=param_type, device='cuda')
weight_decay = [0, 0.01]
for wd in weight_decay:
lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
ref_param, tst_param, ref_optim, tst_optim = \
self.gen_param_optim([tensor], lamb_option)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
@skipIfRocm
def test_float(self):
self.gen_single_type_test(param_type=torch.float)
@unittest.skip("PyTorch optimizer is not numerically correct for fp16")
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
@skipIfRocm
def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
weight_decay = [0, 0.01]
for wd in weight_decay:
lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
tensors = []
for size in sizes:
tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
ref_param, tst_param, ref_optim, tst_optim = \
self.gen_param_optim(tensors, lamb_option)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
@skipIfRocm
def test_lamb_option(self):
nelem = 1
tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
weight_decay = [0, 0.01]
for wd in weight_decay:
lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':wd}
ref_param, tst_param, ref_optim, tst_optim = \
self.gen_param_optim([tensor], lamb_option)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
if __name__ == '__main__':
script_path = os.path.dirname(os.path.realpath(__file__))
unittest.main()
......@@ -35,6 +35,7 @@ inp = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype
grad = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype)
weight = (np.random.randn(feature_size)).astype(dtype)
bias = (np.random.randn(feature_size)).astype(dtype)
count = torch.cuda.IntTensor([batch_size*space_size**2])
type_tensor = torch.cuda.FloatTensor
ref_tensor = torch.cuda.DoubleTensor
......@@ -110,17 +111,19 @@ grad_output2_r = ref_tensor(grad)
grad_bias_r = grad_output_r.sum(1)
grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
sum_dy_r = grad_output_r.sum(1)
mean_dy_r = grad_output_r.mean(1)
sum_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1)
grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, sum_dy, sum_dy_xmu, count)
sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result
sbn_result = compare("comparing sum_dy grad: ", sum_dy, sum_dy_r, error) and sbn_result
sbn_result = compare("comparing sum_dy_xmu grad: ", sum_dy_xmu, sum_dy_xmu_r, error) and sbn_result
sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
sbn_result = compare("comparing sbn input grad: ", inp_sbn.grad, grad_input_r, error) and sbn_result
......
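For reference, the input-gradient check above follows the usual batch-norm backward identity. With N = count.sum() (the total number of elements reduced per feature), mean_dy = sum_dy / N and mean_dy_xmu = sum_dy_xmu / N, so
grad_input = (dy - sum_dy / N - (x - mean) * (sum_dy_xmu / N) / (var + eps)) * rsqrt(var + eps) * weight
which is why handing the kernels the raw sums plus an explicit count is equivalent to the previous mean-based interface.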
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from apex.parallel import SyncBatchNorm as ApexSyncBatchNorm
import argparse
import os
import numpy as np
var_batch = 16
def compare(desc, inp1, inp2, error=1e-5):
a = inp1.clone().detach().cpu().numpy()
b = inp2.clone().detach().cpu().numpy()
close = np.allclose(a,b, error, error)
if not close:
print(desc, close)
z = a - b
index = (np.abs(z) >= error + error * np.abs(b)).nonzero()
print("dif : ", z[index])
print("inp1 : ", a[index])
print("inp2 : ", b[index])
return close
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--apex', action='store_true')
args = parser.parse_args()
torch.manual_seed(2809)
# Setup DDP
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda:{}'.format(args.local_rank))
torch.distributed.init_process_group(
'nccl',
init_method='env://',
rank=args.local_rank,
)
# Setup model
if args.apex:
model = nn.Sequential(
nn.Conv2d(3, 6, 3, 1, 1),
ApexSyncBatchNorm(6)
)
else:
model = nn.Sequential(
nn.Conv2d(3, 6, 3, 1, 1),
nn.SyncBatchNorm(6)
)
# Setup reference model
model_reference = nn.Sequential(
nn.Conv2d(3, 6, 3, 1, 1),
nn.BatchNorm2d(6)
)
with torch.no_grad():
model_reference[0].weight.copy_(model[0].weight)
model_reference[0].bias.copy_(model[0].bias)
model_reference.to(device)
model = model.to(device)
model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
global_batch_size = var_batch + 8
# Create random data
if args.local_rank == 0:
data = torch.randn(var_batch, 3, 8, 8, device=device, dtype=torch.float) * 50.0
grad = torch.randint(0, 10, (var_batch, 6, 8, 8), device=device, dtype=torch.float) / 10.0
else:
data = torch.randn(8, 3, 8, 8, device=device)
grad = torch.randint(0, 10, (8, 6, 8, 8), device=device, dtype=torch.float) / 10.0
data.requires_grad_()
data.retain_grad()
weighted_gradient = True
# DDP forward/backward
output = model(data)
if weighted_gradient:
output.backward(grad * 2 / global_batch_size)
else:
output.backward(grad / output.size(0))
d_list = [torch.randn(8, 3, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]
y_list = [torch.randn(8, 6, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]
dgrad_list = [torch.randn(8, 3, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]
grad_list = [torch.randn(8, 6, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]
if args.local_rank == 0:
# placeholders: rank 0 contributes random tensors to the all_gather; these entries are discarded later when rank 0 re-inserts its local data below.
torch.distributed.all_gather(d_list, torch.randn(8, 3, 8, 8, device=device))
torch.distributed.all_gather(y_list, torch.randn(8, 6, 8, 8, device=device))
torch.distributed.all_gather(dgrad_list, torch.randn(8, 3, 8, 8, device=device))
torch.distributed.all_gather(grad_list, torch.randn(8, 6, 8, 8, device=device))
else:
torch.distributed.all_gather(d_list, data)
torch.distributed.all_gather(y_list, output)
torch.distributed.all_gather(dgrad_list, data.grad)
torch.distributed.all_gather(grad_list, grad)
torch.distributed.barrier()
if args.local_rank == 0:
ref_tensor = d_list[1:]
ref_tensor.insert(0, data)
assert(ref_tensor[0].equal(data))
ref_tensor = torch.cat(ref_tensor, 0)
ref_tensor = ref_tensor.detach()
ref_tensor.requires_grad_()
ref_tensor.retain_grad()
# Reference forward/backward
output_reference = model_reference(ref_tensor)
grad_tensor = grad_list[1:]
grad_tensor.insert(0, grad)
assert(grad_tensor[0].equal(grad))
grad_tensor = torch.cat(grad_tensor, 0)
if weighted_gradient:
output_reference.backward(grad_tensor / output_reference.size(0))
else:
output_reference.backward(grad_tensor / output_reference.size(0))
dgrad_tensor = dgrad_list[1:]
dgrad_tensor.insert(0, data.grad)
dgrad_tensor = torch.cat(dgrad_tensor, 0)
# check output
output_tensor = y_list[1:]
output_tensor.insert(0, output)
output_tensor = torch.cat(output_tensor, 0)
passed = True
passed = passed and compare("check output",
output_tensor,
output_reference)
# check stats
passed = passed and compare("check running mean failed",
model_reference[1].running_mean,
model.module[1].running_mean)
passed = passed and compare("check running var failed",
model_reference[1].running_var,
model.module[1].running_var)
passed = passed and compare("bn wgrad check failed!",
model_reference[1].weight.grad,
model.module[1].weight.grad, 1e-6)
passed = passed and compare("conv wgrad check failed!",
model_reference[0].weight.grad,
model.module[0].weight.grad)
# can't really compare dgrad directly, as we need to scale it to account for
# DDP
# passed = passed and compare("dgrad check failed!", ref_tensor.grad, dgrad_tensor)
if passed:
print("====SBN two gpu with different batches test passed")
else:
assert("*failed two gpu with different batches tests*")
......@@ -114,6 +114,11 @@ grad_sbn = grad_output_t.clone().detach()
out_sbn = sbn(inp_sbn[start:finish])
out_sbn.backward(grad_sbn[start:finish])
count = [ space_size**2 * ( (i+1) * batch_size // args.world_size - i * batch_size // args.world_size ) for i in range(0, args.world_size)]
count = torch.cuda.IntTensor(count)
print("--- count : " , count)
sbn_result = True
bn_result = True
......@@ -136,18 +141,20 @@ grad_output2_r = ref_tensor(grad)
grad_bias_r = grad_output_r.sum(1)
grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
sum_dy_r = grad_output_r.sum(1)
mean_dy_r = grad_output_r.mean(1)
mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1)
sum_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, sum_dy, sum_dy_xmu, count)
if args.local_rank == 0:
sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result
sbn_result = compare("comparing sum_dy grad: ", sum_dy, sum_dy_r, error) and sbn_result
sbn_result = compare("comparing sum_dy_xmu grad: ", sum_dy_xmu, sum_dy_xmu_r, error) and sbn_result
sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
......
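To make the per-rank bookkeeping in the hunk above concrete, here is a small worked example with hypothetical sizes (batch_size, space_size and world_size are defined earlier in that test and are not shown in this diff):
# hypothetical sizes, only to illustrate what `count` holds per rank
batch_size, space_size, world_size = 32, 16, 2
count = [space_size**2 * ((i + 1) * batch_size // world_size
                          - i * batch_size // world_size)
         for i in range(world_size)]
print(count)       # [4096, 4096]; an uneven split would give different per-rank entries
print(sum(count))  # 8192, the total N that converts the gathered sums into means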
......@@ -3,5 +3,6 @@ python single_gpu_unit_test.py
python test_batchnorm1d.py
python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py
python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16
python -m torch.distributed.launch --nproc_per_node=2 two_gpu_test_different_batch_size.py --apex
#beware, you need a system with at least 4 gpus to test group_size<world_size
#python -m torch.distributed.launch --nproc_per_node=4 test_groups.py --group_size=2