Commit 4d6ed501 authored by Deyu Fu

Merge branch 'multi_tensor_sgd' into deyuf/fused_optimizer_v2

parents 690b1f71 9f64bf27
@@ -124,29 +124,13 @@ def check_optimizers(optimizers):
         raise RuntimeError("An incoming optimizer is an instance of {}. ".format(bad_optim_type) +
                            "The optimizer(s) passed to amp.initialize() must be bare \n"
                            "instances of either ordinary Pytorch optimizers, or Apex fused \n"
-                           "optimizers (currently just FusedAdam, but FusedSGD will be added \n"
-                           "soon). You should not manually wrap your optimizer in either \n"
+                           "optimizers (FusedAdam or FusedSGD). \n"
+                           "You should not manually wrap your optimizer in either \n"
                            "apex.fp16_utils.FP16_Optimizer or apex.optimizers.FP16_Optimizer. \n"
                            "amp.initialize will take care of that for you (if necessary) based \n"
                            "on the specified opt_level (and optional overridden properties).")
 
-def wrap_fused_adam(optimizer, properties):
-    msg = 'Currently, the usage of FusedAdam is restricted to '\
-          'amp.initialize(..., opt_level="O2", keep_batchnorm_fp32=False, '\
-          'loss_scale=float or "dynamic"). We are working on enabling more general usage.'
-
-    assert properties.master_weights is True, msg
-    assert properties.cast_model_type is torch.float16, msg
-    assert (properties.keep_batchnorm_fp32 is False or
-            properties.keep_batchnorm_fp32 is None), msg
-
-    if properties.loss_scale == "dynamic":
-        return FP16_Optimizer_for_fused(optimizer, dynamic_loss_scale=True)
-    else:
-        return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)
 
 def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
     from apex.parallel import DistributedDataParallel as apex_DDP
     from .amp import init as amp_init
@@ -176,7 +160,6 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     if not _amp_state.allow_incoming_model_not_fp32:
         check_params_fp32(models)
 
     # In the future, when FP16_Optimizer can be deprecated and master weights can
     # become an attribute, remember to stash master weights before casting the model.
@@ -223,10 +206,6 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
         model.forward = patch_forward(model.forward)
 
     for i, optimizer in enumerate(optimizers):
-        # Still need to special case this for the first pass
-        if isinstance(optimizer, FusedAdam):
-            optimizers[i] = wrap_fused_adam(optimizer, properties)
-        else:
-            optimizers[i] = _process_optimizer(optimizer, properties)
+        optimizers[i] = _process_optimizer(optimizer, properties)
 
     _amp_state.loss_scalers = []
...
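The net effect of this first hunk is that fused optimizers are no longer wrapped in apex.optimizers.FP16_Optimizer by wrap_fused_adam(); they now go through _process_optimizer() like any ordinary torch optimizer. A minimal usage sketch of the resulting flow (module, shapes, and hyperparameters are made up for illustration):

import torch
from apex import amp
from apex.optimizers import FusedAdam

model = torch.nn.Linear(10, 10).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3)

# FusedAdam (and now FusedSGD) is passed in bare; amp patches it in _process_optimizer().
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic")

loss = model(torch.randn(4, 10, device="cuda")).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()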
@@ -3,6 +3,7 @@ from ..fp16_utils import master_params_to_model_params
 from ..multi_tensor_apply import multi_tensor_applier
 from ._amp_state import maybe_print
 import torch
+from ..optimizers import FusedAdam, FusedSGD
 
 class AmpOptimizerState(object):
@@ -10,6 +11,20 @@ class AmpOptimizerState(object):
         pass
 
+def _master_params_to_model_params(self):
+    stash = self._amp_stash
+    if multi_tensor_applier.available:
+        if len(stash.all_fp16_params) > 0:
+            multi_tensor_applier(
+                stash.multi_tensor_scale,
+                stash.dummy_overflow_buf,
+                [stash.all_fp32_from_fp16_params, stash.all_fp16_params],
+                1.0)
+    else:
+        for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
+            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
+
 def lazy_init_with_master_weights(self):
     stash = self._amp_stash
     stash.fp16_groups = []
@@ -60,6 +75,8 @@ def lazy_init_with_master_weights(self):
     for group in stash.fp32_from_fp32_groups:
         stash.all_fp32_from_fp32_params += group
 
+    # all_fp16_grad_stash is only needed for fused optimizers.
+    stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
     # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
     stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]
@@ -73,15 +90,55 @@ def lazy_init_with_master_weights(self):
     self.load_state_dict(self.state_dict())
 
+def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):
+    grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0
+    if scale_override is not None:
+        grads_have_scale, stashed_have_scale, out_scale = scale_override
+
+    # This is a lot of python overhead...
+    grads_needing_unscale = []
+    grads_needing_unscale_with_stash = []
+    stashed = []
+    for param, stashed_grad in zip(params, stashed_grads):
+        if param.grad is None and stashed_grad is not None:
+            param.grad = stashed_grad
+        elif param.grad is not None and stashed_grad is None:
+            grads_needing_unscale.append(param.grad)
+        elif param.grad is not None and stashed_grad is not None:
+            grads_needing_unscale_with_stash.append(param.grad)
+            stashed.append(stashed_grad)
+        else: # param.grad is None and stashed_grad is None
+            continue
+
+    # unscale() implements grads*(1/scale), so "scale" should be grads_have_scale/out_scale.
+    if len(grads_needing_unscale) > 0:
+        scaler.unscale(
+            grads_needing_unscale,
+            grads_needing_unscale,
+            None, # unused_scale, currently present to avoid API breakage elsewhere
+            models_are_masters=True,
+            scale_override=grads_have_scale/out_scale)
+
+    if len(grads_needing_unscale_with_stash) > 0:
+        scaler.unscale_with_stashed(
+            grads_needing_unscale_with_stash,
+            stashed,
+            grads_needing_unscale_with_stash,
+            scale_override=(grads_have_scale, stashed_have_scale, out_scale))
+
+    # Clear the stash.
+    for i in range(len(stashed_grads)):
+        stashed_grads[i] = None
+
 def prepare_backward_with_master_weights(self):
     stash = self._amp_stash
-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()
 
     for i, param in enumerate(stash.all_fp16_params):
-        # Set up to leverage grad copy elision:
+        # Set up to leverage grad copy elision.
+        # This may behave differently from an unpatched optimizer if zero_grad is used and the param is unused.
         param.grad = None
 
     # for i, param in enumerate(stash.all_fp32_from_fp16_params):
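For readers skimming the new helper above: the scale_override triple (grads_have_scale, stashed_have_scale, out_scale) states what scale the incoming and stashed grads currently carry and what scale the combined result should carry. A toy arithmetic check of that convention (plain Python floats; numbers are arbitrary):

# Toy illustration of the (grads_have_scale, stashed_have_scale, out_scale) convention.
grads_have_scale, stashed_have_scale, out_scale = 128.0, 64.0, 64.0
incoming_grad = 0.5 * grads_have_scale      # gradient produced under the current loss scale
stashed_grad  = 0.25 * stashed_have_scale   # gradient stashed under an older loss scale

# unscale_with_stashed() computes a*incoming + b*stashed with
# a = out_scale/grads_have_scale and b = out_scale/stashed_have_scale:
combined = (out_scale / grads_have_scale) * incoming_grad + (out_scale / stashed_have_scale) * stashed_grad
assert combined == (0.5 + 0.25) * out_scale   # result carries exactly out_scale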
@@ -96,6 +153,8 @@ def prepare_backward_with_master_weights(self):
 def post_backward_with_master_weights(self, scaler):
     stash = self._amp_stash
+    self._amp_lazy_init()
+
     # This is a lot of python overhead...
     fp16_grads_needing_unscale = []
     new_fp32_grads = []
@@ -129,37 +188,10 @@ def post_backward_with_master_weights(self, scaler):
                                        preexisting_fp32_grads)
 
     # fp32 params can be treated as they would be in the "no_master_weights" case.
-    grads_needing_unscale = []
-    grads_needing_unscale_with_stash = []
-    stashed = []
-    for param, stashed_grad in zip(stash.all_fp32_from_fp32_params,
-                                   stash.all_fp32_from_fp32_grad_stash):
-        if param.grad is None and stashed_grad is not None:
-            param.grad = stashed_grad
-        elif param.grad is not None and stashed_grad is None:
-            grads_needing_unscale.append(param.grad)
-        elif param.grad is not None and stashed_grad is not None:
-            grads_needing_unscale_with_stash.append(param.grad)
-            stashed.append(stashed_grad)
-        else: # param.grad is None and stashed_grad is None:
-            continue
-
-    if len(grads_needing_unscale) > 0:
-        scaler.unscale(
-            grads_needing_unscale,
-            grads_needing_unscale,
-            scaler.loss_scale(),
-            models_are_masters=True)
-
-    if len(grads_needing_unscale_with_stash) > 0:
-        scaler.unscale_with_stashed(
-            grads_needing_unscale_with_stash,
-            stashed,
-            grads_needing_unscale_with_stash)
-
-    # Clear the stash.
-    for i in range(len(stash.all_fp32_from_fp32_grad_stash)):
-        stash.all_fp32_from_fp32_grad_stash[i] = None
+    post_backward_models_are_masters(
+        scaler,
+        stash.all_fp32_from_fp32_params,
+        stash.all_fp32_from_fp32_grad_stash)
 
 def lazy_init_no_master_weights(self):
@@ -184,9 +216,7 @@ def lazy_init_no_master_weights(self):
 def prepare_backward_no_master_weights(self):
     stash = self._amp_stash
-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()
 
     for i, param in enumerate(stash.all_fp16_params):
         stash.all_fp16_grad_stash[i] = param.grad
@@ -202,55 +232,141 @@ def prepare_backward_no_master_weights(self):
 def post_backward_no_master_weights(self, scaler):
     stash = self._amp_stash
+    self._amp_lazy_init()
+
     split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
                    (stash.all_fp32_params, stash.all_fp32_grad_stash))
 
     for params, stashed_grads in split_types:
-        # This is a lot of python overhead...
-        grads_needing_unscale = []
-        grads_needing_unscale_with_stash = []
-        stashed = []
-        for param, stashed_grad in zip(params, stashed_grads):
-            if param.grad is None and stashed_grad is not None:
-                param.grad = stashed_grad
-            elif param.grad is not None and stashed_grad is None:
-                grads_needing_unscale.append(param.grad)
-            elif param.grad is not None and stashed_grad is not None:
-                grads_needing_unscale_with_stash.append(param.grad)
-                stashed.append(stashed_grad)
-            else: # param.grad is None and stashed_grad is None
-                continue
-
-        if len(grads_needing_unscale) > 0:
-            scaler.unscale(
-                grads_needing_unscale,
-                grads_needing_unscale,
-                scaler.loss_scale(),
-                models_are_masters=True)
-
-        if len(grads_needing_unscale_with_stash) > 0:
-            scaler.unscale_with_stashed(
-                grads_needing_unscale_with_stash,
-                stashed,
-                grads_needing_unscale_with_stash)
-
-        # Clear the stash.
-        for i in range(len(stashed_grads)):
-            stashed_grads[i] = None
-
-def _master_params_to_model_params(self):
-    stash = self._amp_stash
-    if multi_tensor_applier.available:
-        if len(stash.all_fp16_params) > 0:
-            multi_tensor_applier(
-                stash.multi_tensor_scale,
-                stash.dummy_overflow_buf,
-                [stash.all_fp32_from_fp16_params, stash.all_fp16_params],
-                1.0)
-    else:
-        for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
-            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
+        post_backward_models_are_masters(scaler, params, stashed_grads)
+
+
+#####################################################################################
+# FusedAdam versions
+#####################################################################################
+
+def prepare_backward_with_master_weights_FusedAdam(self):
+    stash = self._amp_stash
+    self._amp_lazy_init()
+
+def post_backward_with_master_weights_FusedAdam(self, scaler):
+    stash = self._amp_stash
+    self._amp_lazy_init()
+
+    stash.scale = scaler.loss_scale()
+    stash.grads = [[param.grad.data for param in group] for group in stash.fp16_groups]
+    stash.output_params = [[param for param in group] for group in stash.fp16_groups]
+
+    norm_groups = []
+    skip = False
+    for grad_group in stash.grads:
+        norm, _ = multi_tensor_applier(
+            stash.multi_tensor_l2norm,
+            stash.dummy_overflow_buf,
+            [grad_group],
+            False)
+        # Still syncing here for now.
+        norm = float(norm)
+        norm_groups.append(norm)
+        if norm == float('inf') or norm == -float('inf') or norm != norm:
+            skip = True
+
+    if skip:
+        scaler._overflow_buf.fill_(1.)
+        scaler._has_overflow = True
+
+    stash.grad_norms = norm_groups
+
+def prepare_backward_no_master_weights_FusedAdam(self):
+    stash = self._amp_stash
+    self._amp_lazy_init()
+
+def post_backward_no_master_weights_FusedAdam(self, scaler):
+    stash = self._amp_stash
+    self._amp_lazy_init()
+
+    stash.scale = scaler.loss_scale()
+    stash.grads = None
+    stash.output_params = None
+    stash.grad_norms = None
+
+#####################################################################################
+# FusedSGD versions
+# Eat this ugly code duplication for now. First make it work, then make it clean.
+# It's difficult to anticipate what can be unified between the FusedAdam and FusedSGD
+# implementations until I have them both working.
+#####################################################################################
+
+# FusedSGD never explicitly materializes the fp32 gradients for "fp32 from fp16" master params
+# outside the kernel, so we must accumulate directly into the model grads.
+def prepare_backward_with_master_weights_FusedSGD(self):
+    if self.materialize_master_grads:
+        prepare_backward_with_master_weights(self)
+    else:
+        stash = self._amp_stash
+        self._amp_lazy_init()
+
+        for i, param in enumerate(stash.all_fp16_params):
+            stash.all_fp16_grad_stash[i] = param.grad
+            # Set up to leverage grad copy elision:
+            param.grad = None
+
+        for i, param in enumerate(stash.all_fp32_from_fp32_params):
+            stash.all_fp32_from_fp32_grad_stash[i] = param.grad
+            # Set up to leverage grad copy elision:
+            param.grad = None
+
+def post_backward_with_master_weights_FusedSGD(self, scaler):
+    if self.materialize_master_grads:
+        post_backward_with_master_weights(self, scaler)
+    else:
+        stash = self._amp_stash
+        self._amp_lazy_init()
+
+        grads_have_scale = scaler.loss_scale()
+        stashed_have_scale = self.most_recent_scale
+        out_scale = grads_have_scale
+        if self.scale_set_by_backward:
+            out_scale = min(grads_have_scale, self.most_recent_scale)
+
+        split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
+                       (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
+
+        # unscale_with_stashed() implements grads*1/scale + stashed_grads*1.
+        # stashed_grads are scaled by self.most_recent_scale.
+        for params, stashed_grads in split_types:
+            post_backward_models_are_masters(scaler, params, stashed_grads,
+                                             (grads_have_scale, stashed_have_scale, out_scale))
+
+        self.most_recent_scale = out_scale
+        self.scale_set_by_backward = True
+
+def prepare_backward_no_master_weights_FusedSGD(self):
+    prepare_backward_no_master_weights(self)
+
+def post_backward_no_master_weights_FusedSGD(self, scaler):
+    post_backward_no_master_weights(self, scaler)
+
+def _amp_lazy_init(self):
+    stash = self._amp_stash
+
+    if not stash.lazy_init_called:
+        self._lazy_init_maybe_master_weights()
+        stash.lazy_init_called = True
 
 def _process_optimizer(optimizer, properties):
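The FusedAdam post-backward above uses amp_C.multi_tensor_l2norm to get one gradient norm per parameter group and flags an overflow when any norm comes back inf/NaN. A plain-PyTorch sketch of the same check, for reference (the function name is made up, and this version is much slower than the fused kernel):

import math
import torch

def group_grad_norm_and_overflow(grad_group):
    # L2 norm of all grads in one group, plus an inf/NaN flag, mirroring the
    # multi_tensor_l2norm + float(norm) check performed above.
    norm = float(torch.norm(torch.stack([g.float().norm() for g in grad_group])))
    return norm, math.isinf(norm) or math.isnan(norm)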
@@ -266,7 +382,8 @@ def _process_optimizer(optimizer, properties):
     for name in ("_lazy_init_maybe_master_weights",
                  "_master_params_to_model_params",
                  "_prepare_amp_backward",
-                 "_post_amp_backward"):
+                 "_post_amp_backward",
+                 "_amp_lazy_init"):
         if hasattr(optimizer, name):
             raise RuntimeError("Incoming optimizer already has {} defined.".format(name))
@@ -274,6 +391,7 @@ def _process_optimizer(optimizer, properties):
     if multi_tensor_applier.available:
         import amp_C
         optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
+        optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
         optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0]);
 
     if properties.master_weights:
@@ -288,6 +406,7 @@ def _process_optimizer(optimizer, properties):
                 if closure is not None:
                     raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
                 retval = old_step()
-                self._master_params_to_model_params()
+                if not (isinstance(self, FusedAdam) or isinstance(self, FusedSGD)):
+                    self._master_params_to_model_params()
                 # Clear the master grads that wouldn't be zeroed by model.zero_grad()
                 for param in self._amp_stash.all_fp32_from_fp16_params:
@@ -298,9 +417,7 @@ def _process_optimizer(optimizer, properties):
         old_zero_grad = optimizer.zero_grad
         def new_zero_grad(self):
             stash = self._amp_stash
-            if not stash.lazy_init_called:
-                self._lazy_init_maybe_master_weights()
-                stash.lazy_init_called = True
+            self._amp_lazy_init()
             # Zero the model grads.
             for param in stash.all_fp16_params:
                 if param.grad is not None:
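The step() patch above skips _master_params_to_model_params() for FusedAdam/FusedSGD because their fused kernels write the updated fp16 model params themselves. For other optimizers it performs, in effect, the copy below (a simplified sketch; the real path uses multi_tensor_scale when the extension is available):

import torch

def copy_master_to_model(fp16_params, fp32_master_params):
    # After step() updates the fp32 master weights, write them back into the
    # fp16 params the model actually computes with.
    with torch.no_grad():
        for model_p, master_p in zip(fp16_params, fp32_master_params):
            model_p.copy_(master_p)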
@@ -315,21 +432,43 @@ def _process_optimizer(optimizer, properties):
                     param.grad = None
         optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)
 
-        optimizer._prepare_amp_backward = types.MethodType(
-            prepare_backward_with_master_weights, optimizer)
-
-        optimizer._post_amp_backward = types.MethodType(
-            post_backward_with_master_weights, optimizer)
+        if isinstance(optimizer, FusedAdam):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_with_master_weights_FusedAdam, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_with_master_weights_FusedAdam, optimizer)
+        elif isinstance(optimizer, FusedSGD):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_with_master_weights_FusedSGD, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_with_master_weights_FusedSGD, optimizer)
+        else:
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_with_master_weights, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_with_master_weights, optimizer)
     else:
         optimizer._lazy_init_maybe_master_weights = types.MethodType(
             lazy_init_no_master_weights, optimizer)
 
-        optimizer._prepare_amp_backward = types.MethodType(
-            prepare_backward_no_master_weights, optimizer)
-
-        optimizer._post_amp_backward = types.MethodType(
-            post_backward_no_master_weights, optimizer)
+        if isinstance(optimizer, FusedAdam):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_no_master_weights_FusedAdam, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_no_master_weights_FusedAdam, optimizer)
+        elif isinstance(optimizer, FusedSGD):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_no_master_weights_FusedSGD, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_no_master_weights_FusedSGD, optimizer)
+        else:
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_no_master_weights, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_no_master_weights, optimizer)
+
+    optimizer._amp_lazy_init = types.MethodType(_amp_lazy_init, optimizer)
 
     old_add_param_group = optimizer.add_param_group
 
     def new_add_param_group(self, new_group):
...
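The dispatch above binds free functions onto each optimizer instance. The pattern in isolation, with a throwaway hook (names here are illustrative, not part of amp):

import types
import torch

def _demo_post_amp_backward(self, scaler):
    # stand-in for the real _post_amp_backward hooks patched in above
    print("post-backward hook running on", type(self).__name__)

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
opt._demo_post_amp_backward = types.MethodType(_demo_post_amp_backward, opt)
opt._demo_post_amp_backward(scaler=None)   # prints: post-backward hook running on SGD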
@@ -6,8 +6,6 @@ from . import utils
 from .opt import OptimWrapper
 from .scaler import LossScaler
 from ._amp_state import _amp_state, master_params, maybe_print
-from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
-from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
 
 from ..parallel.LARC import LARC
@@ -89,11 +87,6 @@ def scale_loss(loss,
     if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
         optimizers = [optimizers]
 
-    # this is what happens when i have to support tools from different sources under the same API...
-    # TODO: Rewrite FusedAdam to use multi-tensor apply and the same loss scaler.
-    if isinstance(optimizers, FP16_Optimizer_for_fused):
-        loss_scale = optimizers.cur_scale
-    else:
-        loss_scaler = _amp_state.loss_scalers[loss_id]
-        loss_scale = loss_scaler.loss_scale()
+    loss_scaler = _amp_state.loss_scalers[loss_id]
+    loss_scale = loss_scaler.loss_scale()
@@ -120,8 +113,8 @@ def scale_loss(loss,
             for optimizer in optimizers:
                 optimizer._amp_stash.params_have_scaled_gradients = True
     else:
-        # FusedAdam and FusedSGD will take care of unscaling as part of their step() methods.
-        if not isinstance(optimizers, FP16_Optimizer_for_fused):
+        # FusedAdam and FusedSGD may take care of unscaling as part of their step() methods.
+        # if not isinstance(optimizers, FP16_Optimizer_for_fused):
         loss_scaler.clear_overflow_state()
         for optimizer in optimizers:
             optimizer._post_amp_backward(loss_scaler)
@@ -142,10 +135,15 @@ def scale_loss(loss,
                     maybe_print(("Gradient overflow. Skipping step, loss scaler " +
                                  "{} reducing loss scale to {}").format(loss_id,
                                  loss_scaler.loss_scale()))
+                    # TODO: I don't like the special casing for different optimizer implementations.
+                    # Maybe skip should delegate to a method owned by the optimizers themselves.
                     if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
                         # Clear the master grads that wouldn't be zeroed by model.zero_grad()
                         for param in opt._amp_stash.all_fp32_from_fp16_params:
                             param.grad = None
+                    if hasattr(opt, "most_recent_scale"):
+                        opt.most_recent_scale = 1.0
+                        opt.scale_set_by_backward = False
                 opt.step = opt_step
                 opt._amp_stash.already_patched = False
             return skip_step
...
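The "Gradient overflow. Skipping step, loss scaler ... reducing loss scale" message above is the visible half of dynamic loss scaling. A toy sketch of that policy (the class and constants are illustrative, not LossScaler's actual implementation):

class ToyDynamicScaler:
    def __init__(self, init_scale=2.**16, growth_interval=2000):
        self.scale = init_scale
        self.good_steps = 0
        self.growth_interval = growth_interval

    def update(self, found_overflow):
        if found_overflow:
            self.scale /= 2.0          # back off and skip the step
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps == self.growth_interval:
                self.scale *= 2.0      # ramp back up after a clean run
                self.good_steps = 0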
@@ -16,7 +16,7 @@ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=F
         master_grad.mul_(scale)
     return False
 
-def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, scale, check_overflow=False):
+def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):
     # Exception handling for 18.04 compatibility
     if check_overflow:
         cpu_sum = float(model_grad.float().sum())
@@ -26,9 +26,8 @@ def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, scale, ch
     # if master_grad is not model_grad: # copy_ probably internally short-circuits this
     #     master_grad.copy_(model_grad)
     assert stashed_grad.dtype == master_grad.dtype
-    converted_model_grad = model_grad.to(master_grad.dtype)
-    stashed_grad.add_(scale, converted_model_grad)
-    master_grad.data = stashed_grad.data
+    converted_model_grad = model_grad.data.to(master_grad.dtype)
+    master_grad.data = a*converted_model_grad.data + b*stashed_grad.data
     return False
 
 class LossScaler(object):
@@ -92,11 +91,13 @@ class LossScaler(object):
                     break
 
     # unused_scale keeps some of the old API alive for hopefully a short time.
-    def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False):
+    def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False, scale_override=None):
        if self._has_overflow:
            return
 
        scale = self._loss_scale
+       if scale_override is not None:
+           scale = scale_override
 
        if scale == 1.0 and models_are_masters and not self.dynamic:
            return
@@ -126,7 +127,8 @@ class LossScaler(object):
                                     model_grads,
                                     stashed_master_grads,
                                     master_grads,
-                                    scale):
+                                    a,
+                                    b):
        for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
            if model is None and stashed is None:
                continue
@@ -141,7 +143,8 @@ class LossScaler(object):
                self._has_overflow = axpby_check_overflow_python(model,
                                                                 stashed,
                                                                 master,
-                                                                1./scale,
+                                                                a,
+                                                                b,
                                                                 self.dynamic)
                if self._has_overflow and self.dynamic:
                    break
@@ -149,11 +152,14 @@ class LossScaler(object):
    def unscale_with_stashed(self,
                             model_grads,
                             stashed_master_grads,
-                            master_grads):
+                            master_grads,
+                            scale_override=None):
        if self._has_overflow:
            return
 
-       scale = self._loss_scale
+       grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0
+       if scale_override is not None:
+           grads_have_scale, stashed_have_scale, out_scale = scale_override
 
        if LossScaler.has_fused_kernel:
            if (not LossScaler.warned_unscaling_non_fp32_grad
@@ -167,14 +173,15 @@ class LossScaler(object):
                multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,
                                     self._overflow_buf,
                                     [model_grads, stashed_master_grads, master_grads],
-                                    1./scale,
-                                    1.0,
+                                    out_scale/grads_have_scale, # 1./scale,
+                                    out_scale/stashed_have_scale, # 1.0,
                                     0) # check only arg 0, aka the incoming model grads, for infs
        else:
            self.unscale_with_stashed_python(model_grads,
                                             stashed_master_grads,
                                             master_grads,
-                                            scale)
+                                            out_scale/grads_have_scale,
+                                            out_scale/stashed_have_scale)
 
        # Defer to update_scale
        # If the fused kernel is available, we only need one D2H memcopy and sync.
...
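The scaler change above generalizes the old "stashed + (1/scale)*model" update into "a*model + b*stashed". A quick check with illustrative tensors that the new form reduces to the old one when a = 1/scale and b = 1.0:

import torch

scale = 128.0
model_grad   = torch.randn(4) * scale      # fp32 grad carrying the loss scale
stashed_grad = torch.randn(4)              # previously accumulated, unscaled grad

old = stashed_grad + (1.0 / scale) * model_grad          # old add_/copy path
new = (1.0 / scale) * model_grad + 1.0 * stashed_grad    # new axpby path with a=1/scale, b=1.0
assert torch.allclose(old, new)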
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include "batch_norm.h"
#include <cuda.h>
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
static size_t round_up_to_multiple(size_t x, int multiple) {
return ((x + multiple - 1) / multiple) * multiple;
}
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
Workspace(size_t size) : size(size), data(NULL) {
data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);
}
Workspace(const Workspace&) = delete;
Workspace(Workspace&&) = default;
Workspace& operator=(Workspace&&) = default;
~Workspace() {
if (data) {
THCudaFree(at::globalContext().lazyInitCUDA(), data);
}
}
size_t size;
void* data;
};
// Return {y}
at::Tensor nhwc_bn_fwd_train(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
const bool fuse_relu,
void * my_data,
void * pair_data,
void * pair_data2,
void * pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop) {
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// generating new magic number and use that for sync
int* magic = magic_tensor.data<int>();
*magic = (*magic + 1) & 0xff;
// Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options());
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
nullptr,
y.data<at::Half>(),
nullptr);
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
// an allocated workspace for the others
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 3; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.data<float>());
workspace.push_back(minibatch_inv_var.data<float>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = ret_cta.data<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 3; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
// Don't fuse in ReLU for now at least
bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
return y;
}
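As a rough reference for what the wrapper above launches (single-GPU case, ignoring the bn_group/pair_data exchange and the inverse-variance bookkeeping, and assuming PyTorch's running-stat momentum convention), an eager-mode sketch of NHWC batch-norm training:

import torch

def nhwc_bn_fwd_train_reference(x, scale, bias, running_mean, running_var,
                                momentum, eps, fuse_relu):
    # x is NHWC fp16; channel statistics reduce over N, H, W (dims 0, 1, 2).
    xf = x.float()
    mean = xf.mean(dim=(0, 1, 2))
    var = xf.var(dim=(0, 1, 2), unbiased=False)
    running_mean.mul_(1 - momentum).add_(momentum * mean)
    running_var.mul_(1 - momentum).add_(momentum * xf.var(dim=(0, 1, 2), unbiased=True))
    y = (xf - mean) * torch.rsqrt(var + eps) * scale + bias
    return torch.relu(y).to(x.dtype) if fuse_relu else y.to(x.dtype)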
at::Tensor nhwc_bn_fwd_eval(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& ret_cta,
const int bn_group,
const float momentum,
const float epsilon,
const bool fuse_relu) {
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options());
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
nullptr,
y.data<at::Half>(),
nullptr);
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
// an allocated workspace for the others
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 3; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(nullptr);
workspace.push_back(nullptr);
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = ret_cta.data<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 3; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
// Don't fuse in ReLU for now at least
bn->fwdInference(stream, fuse_relu);
return y;
}
std::vector<at::Tensor> nhwc_bn_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
const bool fuse_relu,
void * my_data,
void * pair_data,
void * pair_data2,
void * pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop) {
// shape
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// generating new magic number and use that for sync
int* magic = magic_tensor.data<int>();
*magic = (*magic + 1) & 0xff;
// outputs
at::Tensor x_grad, scale_grad, bias_grad;
// Allocate outputs
x_grad = at::empty_like(x);
scale_grad = at::empty_like(scale);
bias_grad = at::empty_like(bias);
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
x_grad.data<at::Half>(),
nullptr,
dy.data<at::Half>());
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {scale_grad.data<float>(), bias_grad.data<float>()});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
// an allocated workspace for the others
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 3; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.data<float>());
workspace.push_back(minibatch_inv_var.data<float>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = ret_cta.data<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 3; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};
}
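Similarly, a rough autograd-based reference for the dgrad/wgrad returned by nhwc_bn_bwd above (it recomputes the minibatch statistics instead of reusing the saved mean/inverse-variance, and ignores bn_group and the fused ReLU path):

import torch

def nhwc_bn_bwd_reference(x, dy, scale, bias, eps):
    x = x.detach().float().requires_grad_(True)
    scale = scale.detach().clone().requires_grad_(True)
    bias = bias.detach().clone().requires_grad_(True)
    mean = x.mean(dim=(0, 1, 2))                      # NHWC: per-channel stats over N, H, W
    var = x.var(dim=(0, 1, 2), unbiased=False)
    y = (x - mean) * torch.rsqrt(var + eps) * scale + bias
    y.backward(dy.float())
    return x.grad.to(dy.dtype), scale.grad, bias.grad  # x_grad, scale_grad, bias_grad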
int nhwc_bn_fwd_occupancy() {
int device_id=-1;
cudaGetDevice(&device_id);
//max occupancy supported by the code is 2
return NhwcBatchNorm::smem_driven_fwd_occupancy(device_id, 2);
}
int nhwc_bn_bwd_occupancy() {
int device_id=-1;
cudaGetDevice(&device_id);
//max occupancy supported by the code is 2
return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2);
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm.h
* \brief CUDA NHWC Batch Normalization code
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#include <cudnn.h>
#include <algorithm>
#include <vector>
#include <string>
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#define VERBOSE_DEFAULT false
class NhwcBatchNorm {
public:
NhwcBatchNorm() {
name_ = "nhwc_batchnorm";
createTensorDescriptor(&X_tensor_desc_);
createTensorDescriptor(&Y_tensor_desc_);
}
~NhwcBatchNorm() {
destroyTensorDescriptor(X_tensor_desc_);
destroyTensorDescriptor(Y_tensor_desc_);
}
void die() {
std::cerr << "batchnorm not initialized" << std::endl;
exit(-1);
}
void fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
void dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
void fwdInference(cudaStream_t stream, bool use_relu);
dim3 calc_fwd_grid(int *loop, const int grid_dim_x);
dim3 calc_bwd_grid(int *loop, const int grid_dim_x);
void setInputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
int n, int c, int h, int w, int bn_group) {
m_ = n * h * w;
int m_bn_adjusted = m_ * bn_group;
c_ = c;
// factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
svar_inv_count_ = 1.f / m_bn_adjusted;
// factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).
int divisor = m_bn_adjusted - 1;
// nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.
rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor;
setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
}
void setOutputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
int n, int c, int h, int w) {
setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
}
const std::vector<size_t> numWorkspaceBytes() const;
void setWorkspacePointers(
const std::vector<void*>& workspace,
const std::vector<size_t>& num_workspace_bytes);
void setInputOutputPointers(void* X, void* dX, void* Y, void *dY) {
X_ = X;
dX_ = dX;
Y_ = Y;
dY_ = dY;
}
// Sets the pointers for the scale and weight (in that order) data and derivative buffers.
void setWeightPointers(const std::vector<void*>& weight_pointers,
const std::vector<void*>& deriv_pointers) {
assert(weight_pointers.size() == 2);
assert(deriv_pointers.size() == 2);
scale_ = static_cast<float*>(weight_pointers[0]);
bias_ = static_cast<float*>(weight_pointers[1]);
dscale_ = static_cast<float*>(deriv_pointers[0]);
dbias_ = static_cast<float*>(deriv_pointers[1]);
}
// Sets the pointers for the population mean and variance buffers, in that order.
void setParameterPointers(const std::vector<void*>& param_pointers) {
assert(param_pointers.size() == 2);
population_mean_ = static_cast<float*>(param_pointers[0]);
population_variance_ = static_cast<float*>(param_pointers[1]);
}
void setConstants(const double exp_avg_factor, const double eps) {
exp_avg_factor_ = exp_avg_factor;
eps_ = eps;
}
void processCudnnStatus(const cudnnStatus_t& status,
const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
if (status != CUDNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << cudnnGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudnnGetErrorString(status);
}
void checkCudaStatus(const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
cudaError_t status = cudaGetLastError();
if (status != cudaSuccess)
LOG(FATAL) << string << " " << cudaGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudaGetErrorString(status);
}
size_t size_retired_ctas(int grid_y) const {
// Note that the value of max_grid_y to handle known GPUs is about 160.
const int max_grid_y = 1024;
if (grid_y > max_grid_y)
LOG(INFO) << "GPU capabilities exceeds assumptions.";
const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);
// Since the region will be initialized once and used for many kernels,
// the idea is to return an ample size that will cover all uses.
return retired_cta_bytes;
}
cudnnTensorDescriptor_t X_tensor_desc_ = nullptr;
cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr;
void* X_ = nullptr;
void* dX_ = nullptr;
void* Y_ = nullptr;
void* dY_ = nullptr;
// Learned scale and bias weights.
float* scale_ = nullptr;
float* dscale_ = nullptr;
float* bias_ = nullptr;
float* dbias_ = nullptr;
// Computed population mean and variance parameters.
float* population_mean_ = nullptr;
float* population_variance_ = nullptr;
// Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).
float* minibatch_mean_ = nullptr;
float* minibatch_variance_ = nullptr;
int m_ = 0; // Number of values per channel that BN is normalizing.
int c_ = 0; // Number of channels over which BN is normalizing.
float svar_inv_count_ = 0.f; // factor to scale sum of squared errors to get saved variance
float rvar_inv_count_ = 0.f; // factor to scale sum of squared errors to get running variance
double exp_avg_factor_ = 0.;
double eps_ = 0.;
std::string name_;
private:
void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,
cudnnTensorFormat_t format,
cudnnDataType_t data_type,
int n, int c, int h, int w) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);
processCudnnStatus(status, "set tensor descriptor");
}
void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
status = cudnnCreateTensorDescriptor(descriptor);
processCudnnStatus(status, "create tensor_descriptor");
}
void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
status = cudnnDestroyTensorDescriptor(descriptor);
processCudnnStatus(status, "destroy tensor_descriptor");
}
protected:
float *partial_sums_ = nullptr;
int *partial_counts_ = nullptr;
int *retired_ctas_ = nullptr;
void _setFwdParams(NhwcBatchNormFwdParams *params) const;
void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;
void _setBwdParams(NhwcBatchNormBwdParams *params) const;
// @todo: ability to configure these?
// Kernel params
static const int USE_ONLINE_APPROACH = 1;
static const int THREADS_PER_CTA = 512;
static const int THREADS_PER_PIXEL = 16;
static const int C_ELEMENTS_PER_CTA = 64;
static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;
typedef uint16_t StorageType;
//typedef float StorageType;
// increasing this to 6 causes spills in fwd kernel!
static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;
static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;
static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;
static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;
static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \
PIXELS_PER_THREAD_IN_SMEM_FWD;
static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \
PIXELS_PER_THREAD_IN_SMEM_BWD;
static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;
// Derived params
static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD*THREADS_PER_CTA*\
ELEMENTS_PER_LDG*sizeof(StorageType);
static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD*THREADS_PER_CTA*\
ELEMENTS_PER_LDG*2*sizeof(StorageType);
static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_FWD;
static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_BWD;
static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_FWD_INFERENCE;
// max grid.y in case of group bn is limited by exchange buffer size
static const int MAX_GBN_BLOCK_Y = 256;
// Helper function to launch the forward kernel.
// We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel
// version that was compiled with that occupancy in its launch bounds. This way, we avoid
// needless register spills.
void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
cudaLaunchKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1 && use_relu) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(1, true, false, 2, coop);
else
LAUNCH_FWD_KERNEL(1, true, false, 1, coop);
} else if (outer_loops == 1 && !use_relu) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(1, false, false, 2, coop);
else
LAUNCH_FWD_KERNEL(1, false, false, 1, coop);
} else if (use_relu) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(0, true, false, 2, coop);
else
LAUNCH_FWD_KERNEL(0, true, false, 1, coop);
} else {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(0, false, false, 2, coop);
else
LAUNCH_FWD_KERNEL(0, false, false, 1, coop);
}
#undef LAUNCH_FWD_KERNEL
}
// Helper function to launch the backward kernel.
void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {
#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_func = nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd coop serial kernel"); \
} while (0)
#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1 && use_relu) {
if (occupancy >= 2)
LAUNCH_BWD_RELU_KERNEL(1, 2, coop);
else
LAUNCH_BWD_RELU_KERNEL(1, 1, coop);
} else if (outer_loops == 1 && !use_relu) {
if (occupancy >= 2)
LAUNCH_BWD_KERNEL(1, 2, coop);
else
LAUNCH_BWD_KERNEL(1, 1, coop);
} else if (use_relu) {
if (occupancy >= 2)
LAUNCH_BWD_RELU_KERNEL(0, 2, coop);
else
LAUNCH_BWD_RELU_KERNEL(0, 1, coop);
} else {
if (occupancy >= 2)
LAUNCH_BWD_KERNEL(0, 2, coop);
else
LAUNCH_BWD_KERNEL(0, 1, coop);
}
#undef LAUNCH_BWD_KERNEL
}
public:
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
}
// Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
}
};
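To make the smem_driven_*_occupancy arithmetic above concrete, here is the forward-path calculation with this header's constants, for a hypothetical SM with 96 KB of shared memory (the real code queries MaxSharedMemoryPerMultiprocessor for the actual device):

# Back-of-the-envelope version of smem_driven_fwd_occupancy(), constants taken from the header.
THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG = 512, 16, 4   # 4 = 64 / 16
PIXELS_PER_THREAD_IN_SMEM_FWD, STORAGE_BYTES = 10, 2                # StorageType is uint16_t

smem_size_fwd = PIXELS_PER_THREAD_IN_SMEM_FWD * THREADS_PER_CTA * ELEMENTS_PER_LDG * STORAGE_BYTES
fwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA // 32) * ELEMENTS_PER_LDG * 4  # floats
smem_per_sm = 96 * 1024                                             # hypothetical GPU
occupancy = min(2, smem_per_sm // (smem_size_fwd + fwd_reduction_bytes))
print(occupancy)   # -> 2 on this hypothetical part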
const std::vector<size_t> NhwcBatchNorm::numWorkspaceBytes() const {
assert(c_ > 0);
// choose the max memory required between fwd/bwd passes
int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);
int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);
int grid_x = max(grid_x_fwd, grid_x_bwd);
int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);
const size_t num_mean_bytes = c_ * sizeof(float);
const size_t num_variance_bytes = num_mean_bytes;
const size_t size_sums = grid_y*grid_x*THREADS_PER_PIXEL*\
ELEMENTS_PER_LDG*2*sizeof(float);
const size_t size_counts = grid_y*grid_x*sizeof(int);
return {num_mean_bytes, num_variance_bytes,
size_retired_ctas(grid_y), size_sums, size_counts};
}
void NhwcBatchNorm::setWorkspacePointers(
const std::vector<void*>& workspace,
const std::vector<size_t>& num_workspace_bytes) {
assert(workspace.size() == 5);
assert(num_workspace_bytes.size() == 5);
minibatch_mean_ = static_cast<float*>(workspace[0]);
minibatch_variance_ = static_cast<float*>(workspace[1]);
retired_ctas_ = static_cast<int*>(workspace[2]);
partial_sums_ = static_cast<float*>(workspace[3]);
partial_counts_ = static_cast<int*>(workspace[4]);
}
void NhwcBatchNorm::_setFwdParams(NhwcBatchNormFwdParams *params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dst = static_cast<uint16_t*>(Y_);
params->gmem_src1 = nullptr;
params->gmem_bias = bias_;
params->gmem_scale = scale_;
params->gmem_running_mean = population_mean_;
params->gmem_running_var = population_variance_;
params->gmem_saved_mean = minibatch_mean_;
params->gmem_saved_var = minibatch_variance_;
params->gmem_relu_bitmask = nullptr;
params->nhw = m_;
params->c = c_;
params->svar_inv_count = svar_inv_count_;
params->rvar_inv_count = rvar_inv_count_;
params->gmem_sums = partial_sums_;
params->gmem_counts = partial_counts_;
params->gmem_retired_ctas = retired_ctas_;
params->var_eps = eps_;
params->outer_loops = 0;
params->exp_avg_factor = static_cast<float>(exp_avg_factor_);
params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNorm::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams
*params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dst = static_cast<uint16_t*>(Y_);
params->gmem_src1 = nullptr;
params->gmem_bias = bias_;
params->gmem_scale = scale_;
params->gmem_mean = population_mean_;
params->gmem_var = population_variance_;
params->nhw = m_;
params->c = c_;
params->var_eps = eps_;
}
void NhwcBatchNorm::_setBwdParams(NhwcBatchNormBwdParams *params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dy = static_cast<uint16_t*>(dY_);
params->gmem_dst = static_cast<uint16_t*>(dX_);
params->gmem_dst1 = nullptr;
params->gmem_relu_bitmask = nullptr;
params->gmem_dscale = dscale_;
params->gmem_dbias = dbias_;
params->gmem_scale = scale_;
params->gmem_bias = bias_;
params->gmem_saved_mean = minibatch_mean_;
params->gmem_saved_var = minibatch_variance_;
params->nhw = m_;
params->c = c_;
params->svar_inv_count = svar_inv_count_;
params->gmem_sums = partial_sums_;
params->gmem_retired_ctas = retired_ctas_;
params->outer_loops = 0;
params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& bias_ != nullptr
// && minibatch_mean_ != nullptr
// && minibatch_variance_ != nullptr
&& population_mean_ != nullptr
&& population_variance_ != nullptr
&& X_ != nullptr
// && dX_ != nullptr
&& Y_ != nullptr
// && dY_ != nullptr
// && dscale_ != nullptr
// && dbias_ != nullptr
&& partial_sums_ != nullptr
&& partial_counts_ != nullptr;
if (!ptrs_are_set)
die();
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE);
grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);
// @todo: maybe just move this inside initialize routine?
NhwcBatchNormFwdInferenceParams params;
_setFwdInferenceParams(&params);
if (use_relu) {
nhwc_batch_norm_fwd_inference
<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, true, false>
<<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
checkCudaStatus(name_ + " fwd_inference-relu kernel");
} else {
nhwc_batch_norm_fwd_inference
<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, false>
<<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
checkCudaStatus(name_ + " fwd_inference kernel");
}
}
dim3 NhwcBatchNorm::calc_fwd_grid(int *loop, const int grid_dim_x) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = grid_dim_x;
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop
} else {
grid_dim.y = 1;
}
} else {
grid_dim.x = max_grid_x;
grid_dim.y = 1;
int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD*PIXELS_PER_LDG*grid_dim.x;
int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD*PIXELS_PER_LDG*grid_dim.x;
*loop = div_up(nhw_in_regs, pixels_per_iteration);
}
return grid_dim;
}
dim3 NhwcBatchNorm::calc_bwd_grid(int *loop, const int grid_dim_x) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = grid_dim_x;
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop
} else {
grid_dim.y = 1;
}
} else {
grid_dim.x = max_grid_x;
grid_dim.y = 1;
int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD*PIXELS_PER_LDG*grid_dim.x;
int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD*PIXELS_PER_LDG*grid_dim.x;
*loop = div_up(nhw_in_regs, pixels_per_iteration);
}
return grid_dim;
}
void NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& bias_ != nullptr
&& minibatch_mean_ != nullptr
&& minibatch_variance_ != nullptr
&& population_mean_ != nullptr
&& population_variance_ != nullptr
&& X_ != nullptr
// && dX_ != nullptr
&& Y_ != nullptr
// && dY_ != nullptr
// && dscale_ != nullptr
// && dbias_ != nullptr
&& partial_sums_ != nullptr
&& partial_counts_ != nullptr
&& retired_ctas_ != nullptr;
if (!ptrs_are_set)
die();
// reset of retired_cta_count no longer needed
NhwcBatchNormFwdParams params;
_setFwdParams(&params);
params.my_data = my_data;
params.pair_datas[0] = pair_data;
params.pair_datas[1] = pair_data2;
params.pair_datas[2] = pair_data3;
params.magic = magic;
params.sync_iters = (bn_group==8)?3:(bn_group >> 1);
dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);
_fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);
}
void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& (bias_ != nullptr || !use_relu)
&& minibatch_mean_ != nullptr
&& minibatch_variance_ != nullptr
// && population_mean_ != nullptr
// && population_variance_ != nullptr
&& X_ != nullptr
&& dX_ != nullptr
// && Y_ != nullptr
&& dY_ != nullptr
&& dscale_ != nullptr
&& dbias_ != nullptr;
if (!ptrs_are_set)
die();
// reset of retired_cta_count no longer needed
NhwcBatchNormBwdParams params;
_setBwdParams(&params);
params.my_data = my_data;
params.pair_datas[0] = pair_data;
params.pair_datas[1] = pair_data2;
params.pair_datas[2] = pair_data3;
params.magic = magic;
params.sync_iters = (bn_group==8)?3:(bn_group >> 1);
params.wgrad_coeff = 1.0 / bn_group;
dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);
_bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include "batch_norm_add_relu.h"
#include <cuda.h>
#include <cassert>
#include <vector>
//FIXME move the common stuff to common h file
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
static size_t round_up_to_multiple(size_t x, int multiple) {
return ((x + multiple - 1) / multiple) * multiple;
}
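// For illustration: round_up_to_multiple(1000, 512) == 1024 and round_up_to_multiple(1024, 512) == 1024,
// so every workspace segment carved out below starts on a 512-byte boundary.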
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
Workspace(size_t size) : size(size), data(NULL) {
data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);
}
Workspace(const Workspace&) = delete;
// Moves must null out the source; a defaulted move would leave two objects freeing the same buffer.
Workspace(Workspace&& other) : size(other.size), data(other.data) {
other.size = 0;
other.data = NULL;
}
Workspace& operator=(Workspace&& other) {
if (this != &other) {
if (data) THCudaFree(at::globalContext().lazyInitCUDA(), data);
size = other.size;
data = other.data;
other.size = 0;
other.data = NULL;
}
return *this;
}
~Workspace() {
if (data) {
THCudaFree(at::globalContext().lazyInitCUDA(), data);
}
}
size_t size;
void* data;
};
// Return {y}
at::Tensor nhwc_bn_addrelu_fwd_train(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
void * my_data,
void * pair_data,
void * pair_data2,
void * pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop) {
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// generate a new magic number and use it for sync
int* magic = magic_tensor.data<int>();
*magic = (*magic + 1) & 0xff;
// Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options());
// Create wrapper on the stack so its cudnn descriptors are released when this call returns
NhwcBatchNormAddRelu bn_obj;
NhwcBatchNormAddRelu *bn = &bn_obj;
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
nullptr,
y.data<at::Half>(),
nullptr,
z.data<at::Half>(),
nullptr);
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// The first three workspace pointers come from caller-provided tensors (minibatch mean,
// minibatch inverse variance, ReLU bitmask) and the fourth from ret_cta; the remaining
// entries are carved out of a single allocated workspace below.
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 4; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
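// Illustrative layout (sizes are hypothetical): if workspace_bytes[4] (partial sums) were 86400 B
// and workspace_bytes[5] (partial counts) 2700 B, the offsets would be 0 and
// round_up_to_multiple(86400, 512) = 86528, for a total allocation of 89228 B.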
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.data<float>());
workspace.push_back(minibatch_inv_var.data<float>());
workspace.push_back(bitmask.data<int32_t>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = ret_cta.data<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 4; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
// Launch the fused add + ReLU training forward kernel
bn->fwd(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
return y;
}
at::Tensor nhwc_bn_addrelu_fwd_eval(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& ret_cta,
const int bn_group,
const float momentum,
const float epsilon) {
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options());
// Create wrapper on the stack so its cudnn descriptors are released when this call returns
NhwcBatchNormAddRelu bn_obj;
NhwcBatchNormAddRelu *bn = &bn_obj;
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
nullptr,
y.data<at::Half>(),
nullptr,
z.data<at::Half>(),
nullptr);
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// In eval mode the first three workspace slots (minibatch mean/variance, bitmask) are unused
// and left null, the retired-CTA slot comes from ret_cta, and the remaining entries are
// carved out of a single allocated workspace below.
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 4; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(nullptr);
workspace.push_back(nullptr);
workspace.push_back(nullptr);
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = ret_cta.data<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 4; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
// Launch the fused add + ReLU inference kernel
bn->fwdInference(stream);
return y;
}
std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
void * my_data,
void * pair_data,
void * pair_data2,
void * pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop) {
// shape
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// generate a new magic number and use it for sync
int* magic = magic_tensor.data<int>();
*magic = (*magic + 1) & 0xff;
// outputs
at::Tensor x_grad, z_grad, scale_grad, bias_grad;
// Allocate outputs
x_grad = at::empty_like(x);
z_grad = at::empty_like(x);
scale_grad = at::empty_like(scale);
bias_grad = at::empty_like(bias);
// Create wrapper on the stack so its cudnn descriptors are released when this call returns
NhwcBatchNormAddRelu bn_obj;
NhwcBatchNormAddRelu *bn = &bn_obj;
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
x_grad.data<at::Half>(),
nullptr,
dy.data<at::Half>(),
nullptr,
z_grad.data<at::Half>());
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {scale_grad.data<float>(), bias_grad.data<float>()});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// The first three workspace pointers come from caller-provided tensors (minibatch mean,
// minibatch inverse variance, ReLU bitmask) and the fourth from ret_cta; the remaining
// entries are carved out of a single allocated workspace below.
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 4; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.data<float>());
workspace.push_back(minibatch_inv_var.data<float>());
workspace.push_back(bitmask.data<int32_t>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = ret_cta.data<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 4; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
bn->dgrad(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
return std::vector<at::Tensor>{x_grad, z_grad, scale_grad, bias_grad};
}
int nhwc_bn_addrelu_fwd_occupancy() {
int device_id=-1;
cudaGetDevice(&device_id);
//max occupancy supported by the code is 2
return NhwcBatchNormAddRelu::smem_driven_fwd_occupancy(device_id, 2);
}
int nhwc_bn_addrelu_bwd_occupancy() {
int device_id=-1;
cudaGetDevice(&device_id);
//max occupancy supported by the code is 2
return NhwcBatchNormAddRelu::smem_driven_bwd_occupancy(device_id, 2);
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm_add_relu.h
* \brief CUDA NHWC Batch Normalization code with fused addition
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#include <cudnn.h>
#include <algorithm>
#include <vector>
#include <string>
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#define VERBOSE_DEFAULT false
class NhwcBatchNormAddRelu {
public:
NhwcBatchNormAddRelu() {
name_ = "nhwc_batchnormaddrelu";
createTensorDescriptor(&X_tensor_desc_);
createTensorDescriptor(&Y_tensor_desc_);
}
~NhwcBatchNormAddRelu() {
destroyTensorDescriptor(X_tensor_desc_);
destroyTensorDescriptor(Y_tensor_desc_);
}
void die() {
std::cerr << "batchnormaddrelu not initialized" << std::endl;
exit(-1);
}
void fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
void dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
void fwdInference(cudaStream_t stream);
dim3 calc_fwd_grid(int *loop, const int grid_dim_x);
dim3 calc_bwd_grid(int *loop, const int grid_dim_x);
void setInputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
int n, int c, int h, int w, int bn_group) {
m_ = n * h * w;
int m_bn_adjusted = m_ * bn_group;
c_ = c;
// factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
svar_inv_count_ = 1.f / m_bn_adjusted;
// factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).
int divisor = m_bn_adjusted - 1;
// nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.
rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor;
setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
}
void setOutputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
int n, int c, int h, int w) {
setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
}
const std::vector<size_t> numWorkspaceBytes() const;
void setWorkspacePointers(
const std::vector<void*>& workspace,
const std::vector<size_t>& num_workspace_bytes);
void setInputOutputPointers(void* X, void* dX, void* Y, void *dY, void* addend, void* dAddend) {
X_ = X;
dX_ = dX;
Y_ = Y;
dY_ = dY;
addend_ = addend;
dAddend_ = dAddend;
}
// Sets the pointers for the scale and weight (in that order) data and derivative buffers.
void setWeightPointers(const std::vector<void*>& weight_pointers,
const std::vector<void*>& deriv_pointers) {
assert(weight_pointers.size() == 2);
assert(deriv_pointers.size() == 2);
scale_ = static_cast<float*>(weight_pointers[0]);
bias_ = static_cast<float*>(weight_pointers[1]);
dscale_ = static_cast<float*>(deriv_pointers[0]);
dbias_ = static_cast<float*>(deriv_pointers[1]);
}
// Sets the pointers for the population mean and variance buffers, in that order.
void setParameterPointers(const std::vector<void*>& param_pointers) {
assert(param_pointers.size() == 2);
population_mean_ = static_cast<float*>(param_pointers[0]);
population_variance_ = static_cast<float*>(param_pointers[1]);
}
void setConstants(const double exp_avg_factor, const double eps) {
exp_avg_factor_ = exp_avg_factor;
eps_ = eps;
}
void processCudnnStatus(const cudnnStatus_t& status,
const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
if (status != CUDNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << cudnnGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudnnGetErrorString(status);
}
void checkCudaStatus(const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
cudaError_t status = cudaGetLastError();
if (status != cudaSuccess)
LOG(FATAL) << string << " " << cudaGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudaGetErrorString(status);
}
size_t size_retired_ctas(int grid_y) const {
// Note: a max_grid_y of about 160 covers known GPUs; 1024 leaves ample headroom.
const int max_grid_y = 1024;
if (grid_y > max_grid_y)
LOG(INFO) << "GPU capabilities exceed assumptions.";
const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);
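// With max_grid_y = 1024 this is 1024 * 2 * sizeof(int) = 8 KiB, regardless of the actual grid_y.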
// Since the region will be initialized once and used for many kernels,
// the idea is to return an ample size that will cover all uses.
return retired_cta_bytes;
}
cudnnTensorDescriptor_t X_tensor_desc_ = nullptr;
cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr;
void* X_ = nullptr;
void* dX_ = nullptr;
void* Y_ = nullptr;
void* dY_ = nullptr;
void* addend_ = nullptr;
void* dAddend_ = nullptr;
// Learned scale and bias weights.
float* scale_ = nullptr;
float* dscale_ = nullptr;
float* bias_ = nullptr;
float* dbias_ = nullptr;
// Computed population mean and variance parameters.
float* population_mean_ = nullptr;
float* population_variance_ = nullptr;
// Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).
float* minibatch_mean_ = nullptr;
float* minibatch_variance_ = nullptr;
int m_ = 0; // Number of values per channel that BN is normalizing.
int c_ = 0; // Number of channels over which BN is normalizing.
float svar_inv_count_ = 0.f; // factor to scale sum of squared errors to get saved variance
float rvar_inv_count_ = 0.f; // factor to scale sum of squared errors to get running variance
double exp_avg_factor_ = 0.;
double eps_ = 0.;
std::string name_;
private:
void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,
cudnnTensorFormat_t format,
cudnnDataType_t data_type,
int n, int c, int h, int w) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);
processCudnnStatus(status, "set tensor descriptor");
}
void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
status = cudnnCreateTensorDescriptor(descriptor);
processCudnnStatus(status, "create tensor_descriptor");
}
void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
status = cudnnDestroyTensorDescriptor(descriptor);
processCudnnStatus(status, "destroy tensor_descriptor");
}
protected:
float *partial_sums_ = nullptr;
int *partial_counts_ = nullptr;
int *retired_ctas_ = nullptr;
unsigned int *relu_bitmask_ = nullptr;
void _setFwdParams(NhwcBatchNormFwdParams *params) const;
void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;
void _setBwdParams(NhwcBatchNormBwdParams *params) const;
// @todo: ability to configure these?
// Kernel params
static const int USE_ONLINE_APPROACH = 1;
static const int THREADS_PER_CTA = 512;
static const int THREADS_PER_PIXEL = 16;
static const int C_ELEMENTS_PER_CTA = 64;
static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;
typedef uint16_t StorageType;
// increasing this to 6 causes spills in fwd kernel!
static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;
static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;
static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;
static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;
static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \
PIXELS_PER_THREAD_IN_SMEM_FWD;
static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \
PIXELS_PER_THREAD_IN_SMEM_BWD;
static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;
// Derived params
static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD*THREADS_PER_CTA*\
ELEMENTS_PER_LDG*sizeof(StorageType);
static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD*THREADS_PER_CTA*\
ELEMENTS_PER_LDG*2*sizeof(StorageType);
static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_FWD;
static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_BWD;
static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_FWD_INFERENCE;
// max grid.y in case of group bn is limited by exchange buffer size
static const int MAX_GBN_BLOCK_Y = 256;
// Helper function to launch the forward kernel.
// We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel
// version that was compiled with that occupancy in its launch bounds. This way, we avoid
// needless register spills.
void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
cudaLaunchKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(1, false, true, 2, coop);
else
LAUNCH_FWD_KERNEL(1, false, true, 1, coop);
} else {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(0, false, true, 2, coop);
else
LAUNCH_FWD_KERNEL(0, false, true, 1, coop);
}
#undef LAUNCH_FWD_KERNEL
}
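// Note on the launch path: cudaLaunchCooperativeKernel requires every CTA of the grid to be
// resident simultaneously, which is presumably why the kernel is compiled for the exact occupancy
// computed by the caller and why the shared-memory carveout is raised for the occupancy-2 variant;
// the non-cooperative cudaLaunchKernel branch has no such residency requirement.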
// Helper function to launch the backward kernel.
void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {
#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto bwd_add_relu_func = nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_add_relu_func, \
cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + \
" bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_ADD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1) {
if (occupancy >= 2)
LAUNCH_BWD_ADD_RELU_KERNEL(1, 2, coop);
else
LAUNCH_BWD_ADD_RELU_KERNEL(1, 1, coop);
} else {
if (occupancy >= 2)
LAUNCH_BWD_ADD_RELU_KERNEL(0, 2, coop);
else
LAUNCH_BWD_ADD_RELU_KERNEL(0, 1, coop);
}
#undef LAUNCH_BWD_ADD_RELU_KERNEL
}
public:
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
}
// Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
}
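// Worked example with the constants defined above: SMEM_SIZE_FWD = 10*512*4*sizeof(uint16_t) = 40960 B,
// and the reduction buffer adds 16*(512/32)*4*sizeof(float) = 4096 B, i.e. 45056 B per CTA.
// Assuming, say, a 96 KiB-per-SM part, 98304/45056 = 2, so the fwd kernel is expected to run the
// occupancy-2 variant; a 64 KiB-per-SM part would fall back to occupancy 1. The bwd numbers happen
// to come out the same (5*512*4*2*2 + 4096 = 45056 B).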
};
const std::vector<size_t> NhwcBatchNormAddRelu::numWorkspaceBytes() const {
assert(c_ > 0);
// choose the max memory required between fwd/bwd passes
int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);
int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);
int grid_x = max(grid_x_fwd, grid_x_bwd);
int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);
const size_t num_mean_bytes = c_ * sizeof(float);
const size_t num_variance_bytes = num_mean_bytes;
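// ReLU bitmask sizing: the bitmask appears to hold one mask bit per (pixel, channel) for each
// 64-channel CTA group, i.e. 2 x 32-bit words per pixel, with the pixel count rounded up to a
// multiple of 32.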
int elems_per_group = ((m_ + 31) & ~31) * 2;
int group_count = div_up(c_, C_ELEMENTS_PER_CTA);
const size_t bitmask_bytes = elems_per_group * group_count * sizeof(unsigned int);
const size_t size_sums = grid_y*grid_x*THREADS_PER_PIXEL*\
ELEMENTS_PER_LDG*2*sizeof(float);
const size_t size_counts = grid_y*grid_x*sizeof(int);
return {num_mean_bytes, num_variance_bytes, bitmask_bytes,
size_retired_ctas(grid_y), size_sums, size_counts};
}
void NhwcBatchNormAddRelu::setWorkspacePointers(
const std::vector<void*>& workspace,
const std::vector<size_t>& num_workspace_bytes) {
assert(workspace.size() == 6);
assert(num_workspace_bytes.size() == 6);
minibatch_mean_ = static_cast<float*>(workspace[0]);
minibatch_variance_ = static_cast<float*>(workspace[1]);
relu_bitmask_ = static_cast<unsigned int*>(workspace[2]);
retired_ctas_ = static_cast<int*>(workspace[3]);
partial_sums_ = static_cast<float*>(workspace[4]);
partial_counts_ = static_cast<int*>(workspace[5]);
}
void NhwcBatchNormAddRelu::_setFwdParams(NhwcBatchNormFwdParams *params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dst = static_cast<uint16_t*>(Y_);
params->gmem_src1 = static_cast<uint16_t*>(addend_);
params->gmem_bias = bias_;
params->gmem_scale = scale_;
params->gmem_running_mean = population_mean_;
params->gmem_running_var = population_variance_;
params->gmem_saved_mean = minibatch_mean_;
params->gmem_saved_var = minibatch_variance_;
params->gmem_relu_bitmask = relu_bitmask_;
params->nhw = m_;
params->c = c_;
params->svar_inv_count = svar_inv_count_;
params->rvar_inv_count = rvar_inv_count_;
params->gmem_sums = partial_sums_;
params->gmem_counts = partial_counts_;
params->gmem_retired_ctas = retired_ctas_;
params->var_eps = eps_;
params->outer_loops = 0;
params->exp_avg_factor = static_cast<float>(exp_avg_factor_);
params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNormAddRelu::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams
*params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dst = static_cast<uint16_t*>(Y_);
params->gmem_src1 = static_cast<uint16_t*>(addend_);
params->gmem_bias = bias_;
params->gmem_scale = scale_;
params->gmem_mean = population_mean_;
params->gmem_var = population_variance_;
params->nhw = m_;
params->c = c_;
params->var_eps = eps_;
}
void NhwcBatchNormAddRelu::_setBwdParams(NhwcBatchNormBwdParams *params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dy = static_cast<uint16_t*>(dY_);
params->gmem_dst = static_cast<uint16_t*>(dX_);
params->gmem_dst1 = static_cast<uint16_t*>(dAddend_);
params->gmem_relu_bitmask = relu_bitmask_;
params->gmem_dscale = dscale_;
params->gmem_dbias = dbias_;
params->gmem_scale = scale_;
params->gmem_bias = bias_;
params->gmem_saved_mean = minibatch_mean_;
params->gmem_saved_var = minibatch_variance_;
params->nhw = m_;
params->c = c_;
params->svar_inv_count = svar_inv_count_;
params->gmem_sums = partial_sums_;
params->gmem_retired_ctas = retired_ctas_;
params->outer_loops = 0;
params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNormAddRelu::fwdInference(cudaStream_t stream) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& bias_ != nullptr
// && minibatch_mean_ != nullptr
// && minibatch_variance_ != nullptr
&& population_mean_ != nullptr
&& population_variance_ != nullptr
&& X_ != nullptr
// && dX_ != nullptr
&& Y_ != nullptr
&& addend_ != nullptr
// && dY_ != nullptr
// && dscale_ != nullptr
// && dbias_ != nullptr
&& partial_sums_ != nullptr
&& partial_counts_ != nullptr;
if (!ptrs_are_set)
die();
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE);
grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);
// @todo: maybe just move this inside initialize routine?
NhwcBatchNormFwdInferenceParams params;
_setFwdInferenceParams(&params);
nhwc_batch_norm_fwd_inference
<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, true>
<<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
checkCudaStatus(name_ + " fwd_inference-relu kernel");
}
dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int *loop, const int grid_dim_x) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = grid_dim_x;
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop
} else {
grid_dim.y = 1;
}
} else {
grid_dim.x = max_grid_x;
grid_dim.y = 1;
int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD*PIXELS_PER_LDG*grid_dim.x;
int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD*PIXELS_PER_LDG*grid_dim.x;
*loop = div_up(nhw_in_regs, pixels_per_iteration);
}
return grid_dim;
}
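// Example with the constants above (PIXELS_PER_CTA_FWD = (512/16)*15 = 480): for N=64, H=W=28, C=128
// we get m_ = 50176, grid.x = div_up(50176, 480) = 105 and c_blks = 2. With grid_dim_x = 420 the first
// branch applies: loop = 1 and grid.y = min(2, 420/105) = 2. With grid_dim_x = 80 the second branch
// applies: grid = (80, 1), 10*32*80 = 25600 pixels are handled through smem, and the remaining 24576
// are covered in register chunks of 5*32*80 = 12800, so outer_loops = 2.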
dim3 NhwcBatchNormAddRelu::calc_bwd_grid(int *loop, const int grid_dim_x) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = grid_dim_x;
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop
} else {
grid_dim.y = 1;
}
} else {
grid_dim.x = max_grid_x;
grid_dim.y = 1;
int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD*PIXELS_PER_LDG*grid_dim.x;
int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD*PIXELS_PER_LDG*grid_dim.x;
*loop = div_up(nhw_in_regs, pixels_per_iteration);
}
return grid_dim;
}
void NhwcBatchNormAddRelu::fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& bias_ != nullptr
&& minibatch_mean_ != nullptr
&& minibatch_variance_ != nullptr
&& relu_bitmask_ != nullptr
&& population_mean_ != nullptr
&& population_variance_ != nullptr
&& X_ != nullptr
// && dX_ != nullptr
&& Y_ != nullptr
&& addend_ != nullptr
// && dY_ != nullptr
// && dscale_ != nullptr
// && dbias_ != nullptr
&& partial_sums_ != nullptr
&& partial_counts_ != nullptr
&& retired_ctas_ != nullptr;
if (!ptrs_are_set)
die();
// reset of retired_cta_count no longer needed
NhwcBatchNormFwdParams params;
_setFwdParams(&params);
params.my_data = my_data;
params.pair_datas[0] = pair_data;
params.pair_datas[1] = pair_data2;
params.pair_datas[2] = pair_data3;
params.magic = magic;
params.sync_iters = (bn_group==8)?3:(bn_group >> 1);
dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);
_fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);
}
void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& bias_ != nullptr
&& minibatch_mean_ != nullptr
&& minibatch_variance_ != nullptr
&& relu_bitmask_ != nullptr
// && population_mean_ != nullptr
// && population_variance_ != nullptr
&& X_ != nullptr
&& dX_ != nullptr
// && Y_ != nullptr
&& dY_ != nullptr
&& dAddend_ != nullptr
&& dscale_ != nullptr
&& dbias_ != nullptr
&& retired_ctas_ != nullptr;
if (!ptrs_are_set)
die();
// reset of retired_cta_count no longer needed
NhwcBatchNormBwdParams params;
_setBwdParams(&params);
params.my_data = my_data;
params.pair_datas[0] = pair_data;
params.pair_datas[1] = pair_data2;
params.pair_datas[2] = pair_data3;
params.magic = magic;
params.sync_iters = (bn_group==8)?3:(bn_group >> 1);
params.wgrad_coeff = 1.0 / bn_group;
dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);
_bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#include <ATen/cuda/CUDAContext.h>
#ifndef CUDA_UTILS_H
#define CUDA_UTILS_H
namespace at {
namespace cuda {
namespace utils {
static inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
return getDeviceProperties(device_id)->sharedMemPerMultiprocessor;
}
}
}
}
#endif
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/ArrayRef.h>
#include <ATen/ScalarType.h>
#include "ATen/Scalar.h"
#ifndef VERSION_GE_1_1
#include "ATen/Type.h"
#endif
#include "ATen/Tensor.h"
#include "ATen/Storage.h"
#include "ATen/Generator.h"
namespace py = pybind11;
int64_t get_buffer_size(
const int bn_sync_steps);
void* get_data_ptr(
const at::Tensor& data);
void* get_remote_data_ptr(
const at::Tensor& handle,
const int64_t offset);
void close_remote_data(
const at::Tensor& handle);
at::Tensor nhwc_bn_fwd_train(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
const bool fuse_relu,
void* my_data,
void* pair_data,
void* pair_data2,
void* pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop);
at::Tensor nhwc_bn_fwd_eval(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& ret_cta,
const int bn_group,
const float momentum,
const float epsilon,
const bool fuse_relu);
std::vector<at::Tensor> nhwc_bn_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
const bool fuse_relu,
void* my_data,
void* pair_data,
void* pair_data2,
void* pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop);
at::Tensor nhwc_bn_addrelu_fwd_train(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
void* my_data,
void* pair_data,
void* pair_data2,
void* pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop);
at::Tensor nhwc_bn_addrelu_fwd_eval(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& ret_cta,
const int bn_group,
const float momentum,
const float epsilon);
std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
void* my_data,
void* pair_data,
void* pair_data2,
void* pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop);
int nhwc_bn_fwd_occupancy();
int nhwc_bn_bwd_occupancy();
int nhwc_bn_addrelu_fwd_occupancy();
int nhwc_bn_addrelu_bwd_occupancy();
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_buffer_size", &get_buffer_size, "get_buffer_size");
m.def("get_data_ptr", &get_data_ptr, "get_data_ptr");
m.def("get_remote_data_ptr", &get_remote_data_ptr, "get_remote_data_ptr");
m.def("close_remote_data", &close_remote_data, "close_remote_data");
m.def("bn_fwd_nhwc", &nhwc_bn_fwd_train, "bn_fwd_nhwc");
m.def("bn_fwd_eval_nhwc", &nhwc_bn_fwd_eval, "bn_fwd_eval_nhwc");
m.def("bn_bwd_nhwc", &nhwc_bn_bwd, "bn_bwd_nhwc");
m.def("bn_fwd_nhwc_occupancy", &nhwc_bn_fwd_occupancy, "bn_fwd_nhwc_occupancy");
m.def("bn_bwd_nhwc_occupancy", &nhwc_bn_bwd_occupancy, "bn_bwd_nhwc_occupancy");
m.def("bn_addrelu_fwd_nhwc", &nhwc_bn_addrelu_fwd_train, "bn_addrelu_fwd_nhwc");
m.def("bn_addrelu_fwd_eval_nhwc", &nhwc_bn_addrelu_fwd_eval, "bn_addrelu_fwd_eval_nhwc");
m.def("bn_addrelu_bwd_nhwc", &nhwc_bn_addrelu_bwd, "bn_addrelu_bwd_nhwc");
m.def("bn_addrelu_fwd_nhwc_occupancy", &nhwc_bn_addrelu_fwd_occupancy, "bn_addrelu_fwd_nhwc_occupancy");
m.def("bn_addrelu_bwd_nhwc_occupancy", &nhwc_bn_addrelu_bwd_occupancy, "bn_addrelu_bwd_nhwc_occupancy");
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include <cuda.h>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <functional>
#include <unordered_map>
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
template<>
struct std::hash<cudaIpcMemHandle_t> {
size_t operator() (const cudaIpcMemHandle_t& handle) const {
size_t hash = 0;
uint8_t* ptr = (uint8_t*)&handle;
assert(sizeof(uint8_t) == 1);
for (int i=0; i<sizeof(cudaIpcMemHandle_t); i++) {
hash += *ptr;
ptr++;
}
return hash;
}
};
template<>
struct std::equal_to<cudaIpcMemHandle_t> {
bool operator() (const cudaIpcMemHandle_t &lhs,
const cudaIpcMemHandle_t &rhs) const {
return (std::memcmp((void*) &lhs,
(void*) &rhs,
sizeof(cudaIpcMemHandle_t)) == 0);
}
};
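// The hash is just a byte sum over the opaque cudaIpcMemHandle_t (a 64-byte struct in CUDA's API),
// which is cheap and collision-tolerant since equal_to falls back to a full memcmp on lookup.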
namespace {
namespace gpuipc {
//from: src/operator/nn/cudnn/nhwc_batch_norm_kernel.h
// The number of threads per pixel.
const int THREADS_PER_PIXEL = 16;
// The number of elements per ldg.
const int ELEMENTS_PER_LDG = 4;
// The number of reducing ops, each uses its own space : mean, var, dscale, dbias
const int REDUCE_OPS = 4;
// Maximum block.y supported - limited due to buffer allocation
const int MAX_BLOCK_Y = 256;
const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;
const int BYTES_PER_ELEM = 4;
// Buffer size per sync step
const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET*THREADS_PER_PIXEL*2*ELEMENTS_PER_LDG*BYTES_PER_ELEM;
};
class IpcMemHandleRegistry {
public:
void* getPtr(const cudaIpcMemHandle_t& handle, int64_t offset) {
if (registry_.count(handle) == 0) {
registry_.insert(std::make_pair(handle, RegistryEntry()));
registry_[handle].dev_ptr = ipcOpenMem(handle);
}
registry_[handle].ref_count++;
return (((uint8_t*)registry_[handle].dev_ptr) + offset);
}
void releasePtr(const cudaIpcMemHandle_t& handle) {
if (registry_.count(handle) == 0) {
// Unknown handle: nothing to release; avoid creating a spurious entry below.
return;
}
if (--registry_[handle].ref_count == 0) {
ipcCloseMem(registry_[handle].dev_ptr);
registry_.erase(handle);
}
}
struct RegistryEntry {
void* dev_ptr;
int ref_count;
RegistryEntry() : dev_ptr(NULL) , ref_count(0) {}
};
protected:
std::unordered_map<cudaIpcMemHandle_t, RegistryEntry> registry_;
void* ipcOpenMem(const cudaIpcMemHandle_t& handle) {
void *data;
cudaIpcOpenMemHandle(&data, handle, cudaIpcMemLazyEnablePeerAccess);
cudaCheckErrors("ipc init");
return data;
}
void ipcCloseMem(void* dev_ptr) {
cudaIpcCloseMemHandle(dev_ptr);
cudaCheckErrors("ipc close");
}
};
}
static IpcMemHandleRegistry ipc_mem_registry;
int64_t get_buffer_size(const int bn_sync_steps) {
return bn_sync_steps * gpuipc::SINGLE_SYNC_BUFFER_BYTES;
}
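// Worked size: MAX_OFFSET = 4*256 = 1024, so SINGLE_SYNC_BUFFER_BYTES = 1024*16*2*4*4 = 512 KiB per
// sync step; with the sync_iters convention used by fwd/dgrad (bn_group 2/4/8 -> 1/2/3 sync steps),
// a bn_group of 4 needs get_buffer_size(2) = 1 MiB.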
void* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset) {
cudaIpcMemHandle_t my_handle;
memcpy((unsigned char *)(&my_handle), handle.data<uint8_t>(), sizeof(my_handle));
return ipc_mem_registry.getPtr(my_handle, offset);
}
void close_remote_data(const at::Tensor& handle) {
cudaIpcMemHandle_t my_handle;
memcpy((unsigned char *)(&my_handle), handle.data<uint8_t>(), sizeof(my_handle));
ipc_mem_registry.releasePtr(my_handle);
}
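// Intended usage (the handle-producing side is not shown here): one rank allocates the sync buffer,
// exports it with cudaIpcGetMemHandle, and shares the handle bytes via a uint8 tensor; peers then
// call get_remote_data_ptr(handle, offset) to map it (opened once and ref-counted by the registry)
// and close_remote_data(handle) when they are done with it.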
void* get_data_ptr(
const at::Tensor& data) {
return data.data<uint8_t>();
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm_kernel.h
* \brief CUDA NHWC Batch Normalization code
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#include <stdint.h>
#include <algorithm>
#define DEVICE_FUNCTION static inline __device__
// CTA margin used by cooperative launch. Can be overridden by env var NHWC_BATCHNORM_LAUNCH_MARGIN.
#define NHWC_BATCHNORM_LAUNCH_MARGIN_MIN 3
#define NHWC_BATCHNORM_LAUNCH_MARGIN_DEFAULT NHWC_BATCHNORM_LAUNCH_MARGIN_MIN
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename T, int ELEMENTS_PER_LDG >
struct PackedStorage {
enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG };
typedef T Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int ELEMENTS_PER_LDG >
struct PackedStorage<uint16_t, ELEMENTS_PER_LDG> {
enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG/2 };
typedef int Type;
};
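// For fp16 storage, pairs of halves are packed into one 32-bit register, so with ELEMENTS_PER_LDG = 4
// a per-pixel access moves two ints (8 bytes); the generic template above keeps all four elements as
// unpacked floats (16 bytes).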
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void from_float(int (&dst)[N], const float (&src)[2*N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
uint16_t lo, hi;
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(lo) : "f"(src[2*i+0]));
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(hi) : "f"(src[2*i+1]));
asm volatile("mov.b32 %0, {%1, %2};" : "=r"(dst[i]) : "h"(lo), "h"(hi));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void from_float(float (&dst)[N], const float (&src)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = src[i];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void to_float(float (&dst)[2*N], int (&src)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;" : "=h"(lo), "=h"(hi) : "r"(src[i]));
asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+0]) : "h"(lo));
asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+1]) : "h"(hi));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void to_float(float (&dst)[N], float (&src)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = src[i];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void ldg(int (&dst)[1], const uint16_t *gmem) {
dst[0] = __ldg((const int*) gmem);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void ldg_stream(int (&dst)[1], const uint16_t *gmem) {
unsigned int tmp;
asm volatile ("ld.global.cs.nc.s32 %0, [%1];" : "=r"(tmp) : "l" ((const uint *)gmem));
dst[0] = tmp;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void ldg(int (&dst)[2], const uint16_t *gmem) {
int2 tmp = __ldg((const int2*) gmem);
dst[0] = tmp.x;
dst[1] = tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void ldg_stream(int (&dst)[2], const uint16_t *gmem) {
int2 tmp;
asm volatile ("ld.global.cs.nc.v2.s32 {%0,%1}, [%2];"
: "=r"(tmp.x), "=r"(tmp.y) : "l"((const int2 *)gmem));
dst[0] = tmp.x;
dst[1] = tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void ldg(float (&dst)[N], const uint16_t *gmem) {
int tmp[N/2];
ldg(tmp, gmem);
to_float(dst, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void ldg_stream(float (&dst)[N], const uint16_t *gmem) {
int tmp[N/2];
ldg_stream(tmp, gmem);
to_float(dst, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[1]) {
reinterpret_cast<int*>(gmem)[0] = src[0];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[1]) {
unsigned int tmp = src[0];
asm volatile ("st.global.cs.s32 [%0], %1;"
:: "l"((uint *)gmem) , "r"(tmp));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[2]) {
reinterpret_cast<int2*>(gmem)[0] = make_int2(src[0], src[1]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[2]) {
asm volatile ("st.global.cs.v2.s32 [%0], {%1,%2};"
:: "l"((uint *)gmem) , "r"(src[0]), "r"( src[1]));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void stg(uint16_t *gmem, float (&src)[N]) {
int tmp[N/2];
from_float(tmp, src);
stg(gmem, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void stg_stream(uint16_t *gmem, float (&src)[N]) {
int tmp[N/2];
from_float(tmp, src);
stg_stream(gmem, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_gmem(float (&dst)[2], const float *gmem, int idx) {
float2 tmp = __ldg(reinterpret_cast<const float2*>(&gmem[2*idx]));
dst[0] = tmp.x;
dst[1] = tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_gmem(float (&dst)[4], const float *gmem, int idx) {
float4 tmp = __ldg(reinterpret_cast<const float4*>(&gmem[4*idx]));
dst[0] = tmp.x;
dst[1] = tmp.y;
dst[2] = tmp.z;
dst[3] = tmp.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_smem(float (&x)[2], const float *smem, int idx) {
float2 tmp = *(const float2*) &smem[2*idx];
x[0] = tmp.x;
x[1] = tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_smem(int (&x)[1], const int *smem, int idx) {
x[0] = smem[idx];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_smem(float (&x)[4], const float *smem, int idx) {
float4 tmp = *(const float4*) &smem[4*idx];
x[0] = tmp.x;
x[1] = tmp.y;
x[2] = tmp.z;
x[3] = tmp.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_smem(int (&x)[2], const int *smem, int idx) {
int2 tmp = *(const int2*) &smem[2*idx];
x[0] = tmp.x;
x[1] = tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[2]) {
reinterpret_cast<float2*>(&gmem[2*idx])[0] = make_float2(src[0], src[1]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[4]) {
reinterpret_cast<float4*>(&gmem[4*idx])[0] = make_float4(src[0], src[1], src[2], src[3]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void scaled_write_to_gmem(float *gmem, int idx, const float (&src)[4], const float coeff) {
reinterpret_cast<float4*>(&gmem[4*idx])[0] = make_float4(src[0]*coeff, src[1]*coeff, src[2]*coeff, src[3]*coeff);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[2]) {
reinterpret_cast<float2*>(&smem[2*idx])[0] = make_float2(x[0], x[1]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[1]) {
smem[idx] = x[0];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[4]) {
reinterpret_cast<float4*>(&smem[4*idx])[0] = make_float4(x[0], x[1], x[2], x[3]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[2]) {
reinterpret_cast<int2*>(&smem[2*idx])[0] = make_int2(x[0], x[1]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void zero_array(int (&dst)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = 0;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void zero_array(float (&dst)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = 0.f;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void add(float (&x)[N], const float (&y)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
x[i] += y[i];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void multiply(float (&x)[N], const float (&y)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
x[i] *= y[i];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void scale_(float (&x)[N], float scalar) {
#pragma unroll
for (int i = 0; i < N; ++i) {
x[i] *= scalar;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void normalize(float (&x)[N], const float (&bias)[N],
const float (&scale)[N], const float (&m1)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
x[i] = bias[i] + scale[i] * (x[i] - m1[i]);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Storage>
DEVICE_FUNCTION Storage relu(Storage in) {
Storage zero = (Storage)0.f;
return (in < zero)? zero : in;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_activation(float (&x)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
x[i] = relu(x[i]);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
void* params_my_data, void** params_pair_datas, int off,
const int magic,
const int sync_iters) {
// The size of a warp.
const int THREADS_PER_WARP = 32;
// The number of warps in a CTA.
const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
// The number of threads per pixel.
const int THREADS_PER_PIXEL = 16;
// The number of elements per ldg.
const int ELEMENTS_PER_LDG = 4;
  // The number of reducing ops; each uses its own space: mean, var, dscale, dbias.
const int REDUCE_OPS = 4;
  // Maximum number of c-blocks supported - limited by the buffer allocation.
const int MAX_BLOCK_Y = 256;
const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;
// The warp decomposition.
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int lane_id = threadIdx.x % THREADS_PER_WARP;
// total size of data per sync iter
const int data_total = MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG*2;
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
}
  // The warp leaders write to SMEM.
if (lane_id < THREADS_PER_PIXEL) {
write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);
}
// The data is in SMEM. Do the final reduction.
__syncthreads();
// The 1st warp does all the work.
  // For the final reduction, each half-warp sequentially reduces the remaining values.
if (warp_id == 0) {
read_from_smem(x, smem, threadIdx.x);
#pragma unroll
for (int offset = 1;
offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
float y[ELEMENTS_PER_LDG];
// Read the mean and variance from the other pixel.
read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
// Compute the updated sum.
add(x, y);
}
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
}
// Make sure the data was read from SMEM.
__syncwarp();
// Store the final values.
if (threadIdx.x < THREADS_PER_PIXEL) {
// probably could do it earlier, before sync
for (int sync_iter=0; sync_iter < sync_iters; ++sync_iter) {
//float* params_pair_data = (reinterpret_cast<float**>(params_pair_datas))[sync_iter];
void* params_pair_data = params_pair_datas[sync_iter];
// skip the space consumed by previous sync iterations
const int xbuf_offset = sync_iter*data_total;
// data starts after flags, but have to skip previous
const int data_offset = xbuf_offset
+ off*ELEMENTS_PER_LDG*THREADS_PER_PIXEL*2
+ ELEMENTS_PER_LDG*threadIdx.x*2;
        // after the sums for this GPU have been computed, let CTA 0 broadcast the sum to the other GPU
if (blockIdx.x == 0) {
volatile float * write_data =
&((reinterpret_cast<float*>(params_pair_data))[data_offset]);
// write the data to memory region to be reflected to other GPU
asm volatile ("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};"
:: "l"(write_data) , "f"(x[0]), "r"(magic), "f"(x[2]), "r"(magic));
asm volatile ("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};"
:: "l"(write_data+4) , "f"(x[1]), "r"(magic), "f"(x[3]), "r"(magic));
}
// now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU
volatile float * read_data =
&((reinterpret_cast<float*>(params_my_data))[data_offset]);
float other[4];
uint32_t other_flag_a, other_flag_b;
do {
asm volatile ("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
: "=f"(other[0]), "=r"(other_flag_a), "=f"(other[2]), "=r"(other_flag_b) : "l"(read_data));
} while ((other_flag_a != magic) || (other_flag_b != magic));
do {
asm volatile ("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
: "=f"(other[1]), "=r"(other_flag_a), "=f"(other[3]), "=r"(other_flag_b) : "l"(read_data+4));
} while ((other_flag_a != magic) || (other_flag_b != magic));
add(x, other);
}
// finally, after syncing up and accounting for partial sums from
// other GPUs as required, write the result
write_to_smem(smem, threadIdx.x, x);
}
}
}
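// Note (illustrative sketch, not part of the kernel): the cross-GPU exchange above interleaves
// each partial sum with the `magic` flag so that a single vectorized load can distinguish
// "not written yet" from "ready". Assuming ELEMENTS_PER_LDG == 4, CTA 0 writes:
//
//   write_data[0..3] = { x[0], magic, x[2], magic }
//   write_data[4..7] = { x[1], magic, x[3], magic }
//
// and every CTA spins on ld.volatile of its own buffer until both flag slots equal `magic`
// before accumulating the four floats. The magic value is presumably changed between launches
// so stale data from a previous reduction is never mistaken for fresh data.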
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {
// The size of a warp.
const int THREADS_PER_WARP = 32;
// The number of warps in a CTA.
const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
// The number of threads per pixel.
const int THREADS_PER_PIXEL = 8;
// The number of elements per ldg.
const int ELEMENTS_PER_LDG = 4;
// The warp decomposition.
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int lane_id = threadIdx.x % THREADS_PER_WARP;
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id);
}
  // The warp leaders write to SMEM.
if (lane_id < THREADS_PER_PIXEL) {
write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);
}
// The data is in SMEM. Do the final reduction.
__syncthreads();
// The 1st warp does all the work.
  // For the final reduction, each half-warp sequentially reduces the remaining values.
if (warp_id == 0) {
read_from_smem(x, smem, threadIdx.x);
#pragma unroll
for (int offset = 1;
offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
float y[ELEMENTS_PER_LDG];
// Read the mean and variance from the other pixel.
read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
// Compute the updated sum.
add(x, y);
}
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id);
}
// Make sure the data was read from SMEM.
__syncwarp();
// Store the final values.
if (threadIdx.x < THREADS_PER_PIXEL) {
write_to_smem(smem, threadIdx.x, x);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG >
DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
// The size of a warp.
const int THREADS_PER_WARP = 32;
// The number of warps in a CTA.
const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
// The number of pixels computed by a single warp.
const int PIXELS_PER_WARP = THREADS_PER_WARP / THREADS_PER_PIXEL;
// The position in the warp.
const int nhw_in_warp = nhw % PIXELS_PER_WARP;
// The C in the warp.
const int c_in_warp = threadIdx.x % THREADS_PER_PIXEL;
// Store the values to shared memory.
write_to_smem(smem, threadIdx.x, x);
// Compute the parallel sums.
for (int offset = PIXELS_PER_WARP/2; offset > 0; offset /= 2) {
// NOP.
__syncwarp();
// Read the running sum from the other thread.
float y[ELEMENTS_PER_LDG];
if (nhw_in_warp < offset) {
read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL);
}
// Compute the updated sum.
add(x, y);
// NOP.
__syncwarp();
// Update the sum in SMEM.
if (offset > 1 && nhw_in_warp < offset) {
write_to_smem(smem, threadIdx.x, x);
}
}
// The warps are done. Do the final reduction at the CTA level.
__syncthreads();
  // The warp leaders write to SMEM.
const int idx = (threadIdx.x/THREADS_PER_WARP)*THREADS_PER_PIXEL + c_in_warp;
if (nhw_in_warp == 0) {
write_to_smem(smem, idx, x);
}
// The data is in SMEM. Do the final reduction.
__syncthreads();
// Read the 1st element to prepare the work.
if (nhw < WARPS_PER_CTA/2) {
read_from_smem(x, smem, threadIdx.x);
}
// We have the running mean and running m2. Let's build the mean/var of the CTA.
for (int offset = WARPS_PER_CTA/2; offset > 0; offset /= 2) {
// NOP.
__syncwarp();
// Read the mean and variance from the other pixel.
float y[ELEMENTS_PER_LDG];
if (nhw < offset) {
read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL);
}
// Compute the updated sum.
add(x, y);
// NOP.
__syncwarp();
// Store the mean/var for the different pixels.
if (nhw < offset) {
write_to_smem(smem, threadIdx.x, x);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG >
struct ParallelSums {
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatch(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
parallel_sums<THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG>(smem, x, nhw);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct ParallelSums<16, 4> {
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {
parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, 0, 0, 0, 0, 0);
}
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatchX(float *smem, float (&x)[4], int nhw, void* params_my_data, void** params_pair_datas, int off, const int magic, const unsigned int& sync_iters) {
parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, params_my_data, params_pair_datas, off, magic, sync_iters);
}
};
template<>
struct ParallelSums<8, 4> {
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {
parallel_sums_8x4<THREADS_PER_CTA>(smem, x, nhw);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline int div_up(int m, int n) {
return (m + n - 1) / n;
}
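// Usage sketch (assumption: the host-side launch code, which is not part of this section, uses
// this helper to size grids and c-block counts), e.g.
//   div_up(10, 4) == 3;  c_blks = div_up(c, THREADS_PER_PIXEL * ELEMENTS_PER_LDG);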
////////////////////////////////////////////////////////////////////////////////////////////////////
// It is expected that all threads in the CTA enter this function!
DEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expected_count, bool master) {
// Register the CTA.
if (threadIdx.x == 0) {
// Issue the membar.
__threadfence();
// Notify that the CTA is done.
int val_to_add = 1;
if (master) {
val_to_add = -(expected_count - 1);
}
atomicAdd(gmem_retired_ctas, val_to_add);
}
// Are all CTAs done?
if (threadIdx.x == 0) {
int retired_ctas = -1;
do {
__threadfence();
asm volatile ("ld.global.cg.b32 %0, [%1];"
: "=r"(retired_ctas) : "l"(gmem_retired_ctas));
} while (retired_ctas != 0);
}
__syncthreads();
}
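// Worked example of the counter protocol above (illustrative): with gridDim.x == 4 and the
// counter starting at 0, the three non-master CTAs each atomically add +1 and the master adds
// -(4 - 1) = -3, so the counter returns to 0 exactly when all four CTAs have registered; every
// CTA then spins on an ld.global.cg of the counter until it reads 0. This assumes the counter
// is zero-initialized by the caller (and it is left at zero again for the next round).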
////////////////////////////////////////////////////////////////////////////////////////////////////
struct NhwcBatchNormFwdInferenceParams {
// The input/output tensors.
uint16_t *gmem_src, *gmem_dst, *gmem_src1;
// the final mean and variance as calculated during the training process
float *gmem_mean, *gmem_var;
// The bias/scale.
float *gmem_bias, *gmem_scale;
// The dimensions.
int nhw, c;
// epsilon
float var_eps;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// No DESIRED_OCCUPANCY launch bounds needed, as this is not launched cooperatively
template<
typename Storage,
int THREADS_PER_CTA,
int THREADS_PER_PIXEL,
int ELEMENTS_PER_LDG,
bool USE_RELU,
bool USE_ADD_RELU
>
__global__ __launch_bounds__(THREADS_PER_CTA)
void nhwc_batch_norm_fwd_inference(NhwcBatchNormFwdInferenceParams params) {
// The number of pixels loaded in a single LDG.
const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
// The number of C elements per CTA.
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// The start position in the NHW dimension where the CTA starts.
const int cta_nhw_stride = gridDim.x * PIXELS_PER_LDG;
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
// thread's starting point in NHW
const int thread_nhw = thread_in_cta_nhw + blockIdx.x * PIXELS_PER_LDG;
// The position in the C dimension where the CTA starts.
const int cta_c = blockIdx.y * C_ELEMENTS_PER_CTA;
// Compute the C coordinate of the thread in the CTA.
const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
// Compute the C coordinate of the thread.
const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
// Is the thread working on a valid C dimension?
const int is_valid_c = thread_c < params.c;
float mean[ELEMENTS_PER_LDG], var[ELEMENTS_PER_LDG];
float scale[ELEMENTS_PER_LDG], bias[ELEMENTS_PER_LDG];
zero_array(mean);
zero_array(var);
zero_array(scale);
zero_array(bias);
if (is_valid_c) {
read_from_gmem(var, &params.gmem_var[cta_c], thread_in_cta_c);
read_from_gmem(scale, &params.gmem_scale[cta_c], thread_in_cta_c);
read_from_gmem(mean, &params.gmem_mean[cta_c], thread_in_cta_c);
read_from_gmem(bias, &params.gmem_bias[cta_c], thread_in_cta_c);
}
// Update the scale with the stddev and eps.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
scale[i] *= rsqrtf(var[i] + params.var_eps);
}
// The base pointers for reading/writing
uint16_t *const gmem_src = &params.gmem_src[thread_c];
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
const uint16_t *gmem_src1 = nullptr;
if (USE_ADD_RELU) {
gmem_src1 = &params.gmem_src1[thread_c];
}
// apply BN
for (int nhw = thread_nhw; nhw < params.nhw; nhw += cta_nhw_stride) {
float x_math[ELEMENTS_PER_LDG];
zero_array(x_math);
if (is_valid_c) {
ldg(x_math, &gmem_src[nhw*params.c]);
}
// Normalize and apply activation function
normalize(x_math, bias, scale, mean);
if (USE_ADD_RELU) {
float x1_math[ELEMENTS_PER_LDG];
ldg(x1_math, &gmem_src1[nhw*params.c]);
add(x_math, x1_math);
relu_activation(x_math);
} else if (USE_RELU) {
relu_activation(x_math);
}
if (is_valid_c) {
stg(&gmem_dst[nhw*params.c], x_math);
}
}
}
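// Launch sketch (assumption - the real host-side dispatch is not part of this file, and the
// numeric template values below are purely illustrative, e.g. Storage = uint16_t,
// THREADS_PER_CTA = 512, THREADS_PER_PIXEL = 16, ELEMENTS_PER_LDG = 4):
//
//   dim3 block(512);
//   dim3 grid(div_up(nhw, 512 / 16),   // NHW is covered by a grid-stride loop, so grid.x
//             div_up(c, 16 * 4));      // may also be capped at a smaller value.
//   nhwc_batch_norm_fwd_inference<uint16_t, 512, 16, 4, true, false>
//       <<<grid, block>>>(params);
//
// grid.y must cover all of C (there is no c-blk loop in this kernel), while grid.x only changes
// how many NHW iterations each CTA performs.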
////////////////////////////////////////////////////////////////////////////////////////////////////
struct NhwcBatchNormFwdParams {
// The input/output tensors.
uint16_t *gmem_src, *gmem_dst, *gmem_src1;
// The bias/scale.
float *gmem_bias, *gmem_scale;
  // running mean/var (see the BN API in the cuDNN docs)
  float *gmem_running_mean, *gmem_running_var;
  // saved mean/var (see the BN API in the cuDNN docs)
  float *gmem_saved_mean, *gmem_saved_var;
// ReLU bitmask
unsigned int *gmem_relu_bitmask;
// The dimensions.
int nhw, c;
// factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
float svar_inv_count;
// factor to scale sum of squared errors to get running variance. Should be 1/nhw or 1/(nhw-1).
float rvar_inv_count;
// The buffer to do the reduction for mean, stddev and count.
float *gmem_sums;
// The buffer to count items in the different CTAs.
int *gmem_counts;
// The counters of retired CTAs.
int *gmem_retired_ctas;
// The epsilon to apply to the computation of the variance.
float var_eps;
// outer loop count
int outer_loops;
// exponential average factor
float exp_avg_factor;
// number of CTAs along .x dimension
int c_blks;
void* my_data;
void* pair_datas[4];
int magic;
int sync_iters;
};
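// Note on the two variance scale factors above (illustrative arithmetic): for nhw = 1024 the
// saved variance uses svar_inv_count = 1/1024 (the biased estimate, which is what the backward
// pass expects), while the running variance may use rvar_inv_count = 1/1024 or 1/1023 depending
// on whether the caller wants the biased or the Bessel-corrected (unbiased) estimate.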
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Storage,
int THREADS_PER_CTA,
int THREADS_PER_PIXEL,
int PIXELS_PER_THREAD_IN_REGISTERS,
int PIXELS_PER_THREAD_IN_SMEM,
int ELEMENTS_PER_LDG,
int USE_ONLINE_APPROACH,
int OUTER_LOOPS_,
bool USE_RELU,
bool USE_ADD_RELU,
int DESIRED_OCCUPANCY
>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
void nhwc_batch_norm_fwd(NhwcBatchNormFwdParams params) {
// The number of pixels loaded in a single LDG.
const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
// The number of pixels computed per CTA stored in registers.
const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
// The number of pixels computed per CTA stored in SMEM.
const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;
// The number of C elements per CTA.
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
// The data type for packed storage in SMEM.
typedef typename PackedStorage_::Type PackedStorageType;
// The number of elements in the packed storage.
const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
// Registers to keep the data live for the persistent approach.
PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
// The position in the C dimension where the CTA starts.
const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
// Compute the C coordinate of the thread in the CTA.
const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
// Compute the C coordinate of the thread.
int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
// Is the thread working on a valid C dimension?
const int is_valid_c = thread_c < params.c;
// Clamp thread_c so that we load from valid locations even if we don't use the value
if (!is_valid_c)
thread_c = params.c - 4;
// Single pass numerically stable algorithm, see:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
//
// n = 0, mean = 0.0, M2 = 0.0
//
// for x in data:
// n += 1
// delta = x - mean
// mean += delta/n
// delta2 = x - mean
// M2 += delta*delta2
//
// if n < 2:
// return float('nan')
// else:
// return M2 / (n - 1)
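    // The per-thread partial results (count, mean, m2) produced by this update are later merged
    // with the standard parallel (Chan et al.) combination: rescale each mean back to a sum
    // (m1 = mean * count), sum the m1's, and correct each partial m2 by count * (mu_t - mu)^2
    // before summing the m2's. This is the identity
    //
    //   M2_total = sum_t [ M2_t + n_t * (mu_t - mu_total)^2 ]
    //
    // which is what the reductions further down (first across the CTA, then across CTAs via
    // gmem_sums/gmem_counts) implement.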
// Register to store the number of elements read so far.
float count = 0.f, mean[ELEMENTS_PER_LDG], m2[ELEMENTS_PER_LDG];
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
mean[i] = 0.f;
m2[i] = 0.f;
}
// The number of elements loaded by this CTA.
int cta_count = 0;
// The base pointer to load from.
const uint16_t *gmem_src = &params.gmem_src[thread_c];
// outer loops
int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;
// Load the batch of elements. Compute the mean/var across those elements.
const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;
if (OUTER_LOOPS_ != 1) {
      // We cannot load everything to store persistently, so let's make sure registers and
      // SMEM are fully utilized; the offset is evenly divisible by 32
int offset = (pixels_per_iteration * OUTER_LOOPS +
PIXELS_PER_CTA_IN_SMEM * gridDim.x - params.nhw) & ~31;
cta_nhw_regs -= offset;
cta_nhw_smem -= offset;
}
#pragma unroll 1
for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
// The nhw position.
int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count += max(min(nhw_regs + PIXELS_PER_CTA_IN_REGISTERS, params.nhw) -
max(nhw_regs, 0), 0);
// Load the data and compute the local mean/sum and the variance.
if (USE_ONLINE_APPROACH) {
// Read the elements from memory.
float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
zero_array(x_storage[i]);
is_valid[i] = 0.f;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
} else {
ldg(x_storage[i], &gmem_src[idx*params.c]);
}
is_valid[i] = 1.f;
}
}
// Do the math.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float.
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
// Update the count.
count += is_valid[i];
// Invert the count.
float inv_count = is_valid[i] ? 1.f / count : 0.f;
// Update the mean and m2 using deltas.
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
float delta0 = x_math[j] - mean[j];
mean[j] += delta0 * inv_count;
float delta1 = x_math[j] - mean[j];
m2[j] += delta0 * delta1 * is_valid[i];
}
}
} else {
// Read the elements from memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
zero_array(x_storage[i]);
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
} else {
ldg(x_storage[i], &gmem_src[idx*params.c]);
}
count += 1.f;
}
}
// Sum the elements in registers.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float.
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
// Update the mean and m2 using deltas.
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
mean[j] += x_math[j];
}
}
// Compute the mean.
float inv_count = 1.f / count;
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
mean[j] *= inv_count;
}
// Compute the variance.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float.
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
// Is it a valid pixel?
float is_valid = i < static_cast<int>(count) ? 1.f : 0.f;
// Update the mean and m2 using deltas.
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
m2[j] += (x_math[j] - mean[j]) * (x_math[j] - mean[j]) * is_valid;
}
}
}
}
// The elements to load and store in SMEM.
int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;
// Load elements from SMEM, update the CTA count.
int pixels_in_smem = min(smem_nhw + PIXELS_PER_CTA_IN_SMEM, params.nhw) - max(smem_nhw, 0);
if (pixels_in_smem > 0) {
cta_count += pixels_in_smem;
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
float is_pixel_valid = (((unsigned int)idx <
(unsigned int)params.nhw) && is_valid_c) ? 1.f : 0.f;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG];
ldg_stream(x_storage_local, &gmem_src[(is_pixel_valid ? idx : 0)*params.c]);
// The offset to store in SMEM.
const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
// Store in SMEM.
write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
// Update the count.
count += is_pixel_valid;
// Invert the count.
float inv_count = is_pixel_valid ? 1.f / count : 0.f;
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
// Update the mean and m2 using deltas.
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
float delta0 = x_math[j] - mean[j];
mean[j] += delta0 * inv_count;
float delta1 = x_math[j] - mean[j];
m2[j] += delta0 * delta1 * is_pixel_valid;
}
}
}
// We scale the mean by the number of elements. It brings more stability.
float m1[ELEMENTS_PER_LDG];
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
m1[i] = mean[i] * count;
}
    // Run the parallel sum across the CTA to get the local sum.
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, m1, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(m1, smem, thread_in_cta_c);
__syncthreads();
// Adjust the variance.
float inv_cta_count = 1.f / static_cast<float>(cta_count);
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
float mean_diff = m1[i]*inv_cta_count - mean[i];
m2[i] = m2[i] + mean_diff * mean_diff * count;
}
    // Run the parallel sum across the CTA to get the local adjusted variance.
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, m2, thread_in_cta_nhw);
// The workspace in global memory is distributed across the different CTA.
int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;
// Write the data for the CTA to global memory.
float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
if (threadIdx.x < THREADS_PER_PIXEL) {
const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;
write_to_gmem(&gmem_sums[ 0], idx, m1);
write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, m2);
}
// The memory location to store the number of pixels per CTA.
int *gmem_counts = &params.gmem_counts[c_blk_index*gridDim.x];
if (threadIdx.x == 0) {
gmem_counts[blockIdx.x] = cta_count;
}
// Read the bias and scale.
float bias[ELEMENTS_PER_LDG], scale[ELEMENTS_PER_LDG];
if (is_valid_c) {
read_from_gmem(bias, &params.gmem_bias[cta_c], thread_in_cta_c);
read_from_gmem(scale, &params.gmem_scale[cta_c], thread_in_cta_c);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);
// Reset the mean to compute the global mean.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
m1[i] = 0.f;
}
// Build the global mean.
#pragma unroll 1
for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
float tmp[ELEMENTS_PER_LDG];
read_from_gmem(tmp, gmem_sums, idx);
add(m1, tmp);
}
    if (params.sync_iters > 0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, m1, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+3, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, m1, thread_in_cta_nhw);
}
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(m1, smem, thread_in_cta_c);
__syncthreads();
// Normalize the mean.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
m1[i] = m1[i] * params.svar_inv_count;
}
// Reset the variance.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
m2[i] = 0.f;
}
// for add+relu fusion
const uint16_t *gmem_src1 = nullptr;
if (USE_ADD_RELU) {
gmem_src1 = &params.gmem_src1[thread_c];
}
// Build the global variance.
#pragma unroll 1
for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
// Read the means computed by different CTAs (again). Reuse tmp if we have 1 iteration.
float tmp_mean[ELEMENTS_PER_LDG], tmp_var[ELEMENTS_PER_LDG];
read_from_gmem(tmp_mean, &gmem_sums[ 0], idx);
read_from_gmem(tmp_var, &gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx);
// Read the number of pixels visited by a given CTA.
cta_count = __ldg(&gmem_counts[idx / THREADS_PER_PIXEL]);
// Compute the diff to update the variance.
float mean_diff[ELEMENTS_PER_LDG], inv_cta_count = 1.f / static_cast<float>(cta_count);
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
mean_diff[i] = m1[i] - tmp_mean[i]*inv_cta_count;
}
// Update the variance.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
m2[i] += tmp_var[i] + mean_diff[i]*mean_diff[i]*static_cast<float>(cta_count);
}
}
    if (params.sync_iters > 0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, m2, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+2, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, m2, thread_in_cta_nhw);
}
__syncthreads();
read_from_smem(m2, smem, thread_in_cta_c);
// Finalize the stddev.
    // because saved var and running var may have different denominators, we don't do it here
// scale_(m2, inv_count);
// store the saved mean/var
float svarinv[ELEMENTS_PER_LDG];
bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
svarinv[i] = rsqrtf(m2[i] * params.svar_inv_count + params.var_eps);
}
if (is_valid_for_saving) {
write_to_gmem(params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG, m1);
write_to_gmem(params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG, svarinv);
}
// store the running mean/var
float rmean[ELEMENTS_PER_LDG], rvar[ELEMENTS_PER_LDG];
zero_array(rmean);
zero_array(rvar);
if (params.exp_avg_factor != 1.f && is_valid_for_saving) {
read_from_gmem(rmean, params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG);
read_from_gmem(rvar, params.gmem_running_var, thread_c/ELEMENTS_PER_LDG);
}
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
rmean[i] = (1.f - params.exp_avg_factor) * rmean[i] + \
params.exp_avg_factor * m1[i];
rvar[i] = (1.f - params.exp_avg_factor) * rvar[i] + \
params.exp_avg_factor * (m2[i] * params.rvar_inv_count);
}
if (is_valid_for_saving) {
write_to_gmem(params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG, rmean);
write_to_gmem(params.gmem_running_var, thread_c/ELEMENTS_PER_LDG, rvar);
}
// Update the scale with the stddev and eps.
multiply(scale, svarinv);
// The base pointer to write to.
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask +
((params.nhw + 31) & ~31) * 2 * c_blk_index;
// Store the elements in registers.
#pragma unroll 1
for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {
// The value for nhw.
int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;
// Normalize the elements and write to memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid_nhw =
static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
const bool is_valid = is_valid_nhw && is_valid_c;
// Convert to float.
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
// Normalize and apply activation function
normalize(x_math, bias, scale, m1);
if (USE_ADD_RELU) {
float x1_math[ELEMENTS_PER_LDG];
ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]);
add(x_math, x1_math);
unsigned int relu_mask;
int lane_id = threadIdx.x & 31;
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
bool rectified = x_math[i] < 0.0F;
unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified);
if (lane_id == i) {
// Thread 0 remembers the relu_mask from the first time through this
// loop, Thread 1 the next, Thread 2 the next, and Thread 3 the last.
relu_mask = local_relu_mask;
}
if (rectified) {
x_math[i] = 0.0F;
}
}
if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {
gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask;
}
} else if (USE_RELU) {
relu_activation(x_math);
}
// Write back.
if (is_valid) {
stg_stream(&gmem_dst[idx*params.c], x_math);
}
}
// The next value of nhw.
out_nhw -= pixels_per_iteration;
// Read the next elements from memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
}
}
}
// Normalize the elements from SMEM and write them out.
if (pixels_in_smem > 0) {
#pragma unroll 2
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid_nhw =
static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
const bool is_valid = is_valid_nhw && is_valid_c;
// Read from SMEM.
const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG];
read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
// Normalize and apply activation function
normalize(x_math, bias, scale, m1);
if (USE_ADD_RELU) {
float x1_math[ELEMENTS_PER_LDG];
ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]);
add(x_math, x1_math);
unsigned int relu_mask;
int lane_id = threadIdx.x & 31;
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
bool rectified = x_math[i] < 0.0F;
unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified);
if (lane_id == i) {
relu_mask = local_relu_mask;
}
if (rectified) {
x_math[i] = 0.0F;
}
}
if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {
gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask;
}
} else if (USE_RELU) {
relu_activation(x_math);
}
// Write back.
if (is_valid) {
stg_stream(&gmem_dst[idx*params.c], x_math);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads();
}
}
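// Workspace sizing sketch, derived from the indexing in the kernel above (the buffers are
// allocated by host code that is not part of this file, so treat these expressions as an
// assumption / lower bound rather than the authoritative allocation):
//
//   gmem_sums          : c_blks * gridDim.x * C_ELEMENTS_PER_CTA * 2   floats
//   gmem_counts        : c_blks * gridDim.x                            ints
//   gmem_retired_ctas  : 2 * gridDim.y                                 ints
//   gmem_relu_bitmask  : at least ((nhw + 31) & ~31) * 2 * c_blks      unsigned ints
//                        (only used with USE_ADD_RELU; the per-c-blk stride comes from the
//                        gmem_relu_bitmask pointer arithmetic above)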
////////////////////////////////////////////////////////////////////////////////////////////////////
struct NhwcBatchNormBwdParams {
// The input/output tensors.
uint16_t *gmem_src, *gmem_dy, *gmem_dst, *gmem_dst1;
// dscale/dbias
float *gmem_dscale, *gmem_dbias;
// The scale and bias.
float *gmem_scale, *gmem_bias;
// The mean/inv-var saved from fwd pass
float *gmem_saved_mean, *gmem_saved_var;
// ReLU bitmask
unsigned int *gmem_relu_bitmask;
// The dimensions.
int nhw, c;
// factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
float svar_inv_count;
// The buffer to do the reduction for dscale and dbias
float *gmem_sums;
// The counters of retired CTAs.
int *gmem_retired_ctas;
// outer loop count
int outer_loops;
// number of CTAs along .x dimension
int c_blks;
void* my_data;
void* pair_datas[4];
int magic;
int sync_iters;
float wgrad_coeff;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&x)[N],
const float (&mean_var_scale_bias)[N],
const float (&var_scale)[N], bool valid_data) {
#pragma unroll
for (int j = 0; j < N; ++j) {
float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];
if ((y <= 0.f) && valid_data) {
dy[j] = 0.f;
}
}
}
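// Note (illustrative): this overload recomputes the forward output on the fly instead of
// storing it. With the folded constants prepared by the caller, var_scale[j] = scale * inv_std
// and mean_var_scale_bias[j] = bias - mean * scale * inv_std, so
//
//   y = x * var_scale + mean_var_scale_bias = scale * (x - mean) * inv_std + bias
//
// i.e. the pre-ReLU forward output. If y <= 0 the forward ReLU clamped that element, so the
// incoming gradient is zeroed.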
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&y)[N], bool valid_data) {
#pragma unroll
for (int j = 0; j < N; ++j) {
if ((y[j] <= 0.f) && valid_data) {
dy[j] = 0.f;
}
}
}
template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const bool (&rectified)[N], bool valid_data) {
#pragma unroll
for (int j = 0; j < N; ++j) {
if (rectified[j] && valid_data) {
dy[j] = 0.f;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N],
const float (&x)[N],
const float (&mean_var_scale_bias)[N],
const float (&var_scale)[N]) {
#pragma unroll
for (int j = 0; j < N; ++j) {
float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];
if (y <= 0.f) {
dy[j] = 0.f;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&y)[N]) {
#pragma unroll
for (int j = 0; j < N; ++j) {
if (y[j] <= 0.f) {
dy[j] = 0.f;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void bwd_update(float (&dscale)[N], float (&dbias)[N],
const float (&dy)[N], const float (&x)[N],
const float (&mean)[N], float inv_count) {
#pragma unroll
for (int j = 0; j < N; ++j) {
float delta0 = dy[j] - dbias[j];
dbias[j] += delta0 * inv_count;
delta0 = (dy[j] * (x[j] - mean[j])) - dscale[j];
dscale[j] += delta0 * inv_count;
}
}
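// Note (illustrative): dbias and dscale are maintained here as running means of dy and of
// dy * (x - mean), using the same incremental "delta / n" update as the Welford pass in the
// forward kernel; the callers multiply them back by `count` afterwards to turn the means into
// CTA-wide sums before the cross-CTA reduction.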
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void bwd_dx(float (&dx)[N], const float (&dy)[N],
const float (&var)[N], const float (&x)[N], const float (&mean)[N],
const float (&dscale)[N], const float (&dbias)[N], float inv_count) {
#pragma unroll
for (int j = 0; j < N; ++j) {
float tmp1 = dy[j] - (dbias[j]* inv_count);
float tmp2 = dscale[j] * inv_count;
float tmp3 = x[j] - mean[j];
dx[j] = var[j] * (tmp1 - (tmp2 * tmp3));
}
}
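// Note (illustrative): written out, the expression above is the usual batchnorm input gradient
//
//   dx = (scale * inv_std) * ( dy - mean(dy) - (x - mean) * inv_std^2 * mean(dy * (x - mean)) )
//
// assuming the caller passes var = scale * inv_std, dbias = sum(dy),
// dscale = sum(dy * (x - mean)) * inv_std^2 and inv_count = 1/nhw, which is how the backward
// kernels below prepare their arguments.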
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Storage,
int THREADS_PER_CTA,
int THREADS_PER_PIXEL,
int PIXELS_PER_THREAD_IN_REGISTERS,
int PIXELS_PER_THREAD_IN_SMEM,
int ELEMENTS_PER_LDG,
int USE_ONLINE_APPROACH,
int OUTER_LOOPS_,
int DESIRED_OCCUPANCY
>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
void nhwc_batch_norm_bwd(NhwcBatchNormBwdParams params) {
// The number of pixels loaded in a single LDG.
const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
// The number of pixels computed per CTA stored in registers.
const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
// The number of pixels computed per CTA stored in SMEM.
const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;
// The number of C elements per CTA.
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];
// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
// The data type for packed storage in SMEM.
typedef typename PackedStorage_::Type PackedStorageType;
// The number of elements in the packed storage.
const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
// Registers to keep the data live for the persistent approach.
PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
// The position in the C dimension where the CTA starts.
const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
// Compute the C coordinate of the thread in the CTA.
const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
// Compute the C coordinate of the thread.
const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
// Is the thread working on a valid C dimension?
const int is_valid_c = thread_c < params.c;
// Registers to store the mean used for entire duration
float mean[ELEMENTS_PER_LDG];
zero_array(mean);
if (is_valid_c) {
read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);
}
// accumulation related registers
float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];
zero_array(dscale);
zero_array(dbias);
// The number of elements loaded by this CTA.
int cta_count = 0;
// The base pointers to load from.
const uint16_t *gmem_src = &params.gmem_src[thread_c];
const uint16_t *gmem_dy = &params.gmem_dy[thread_c];
// outer loops
int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;
// Load the batch of elements. Compute sum across them
const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;
if (OUTER_LOOPS_ != 1) {
      // We cannot load everything to store persistently, so let's make sure registers and
      // SMEM are fully utilized
int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS -
PIXELS_PER_CTA_IN_SMEM * gridDim.x;
cta_nhw_regs += offset;
cta_nhw_smem += offset;
}
#pragma unroll 1
for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
// The nhw position.
int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));
// Read the elements from memory.
float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
zero_array(x_storage[i]);
zero_array(dy_storage[i]);
is_valid[i] = 0.f;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
} else {
ldg(x_storage[i], &gmem_src[idx*params.c]);
ldg(dy_storage[i], &gmem_dy[idx*params.c]);
}
is_valid[i] = 1.f;
}
}
// Do the math.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float and update
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
// Update the count.
count += is_valid[i];
// Invert the count.
float inv_count = is_valid[i] ? 1.f / count : 0.f;
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
}
}
// The elements to load and store in SMEM.
int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;
// Load elements from SMEM, update the CTA count.
int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw);
if (pixels_in_smem > 0) {
cta_count += pixels_in_smem;
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
bool is_pixel_valid = (((unsigned int)idx <
(unsigned int)params.nhw) && is_valid_c);
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
zero_array(x_storage_local);
zero_array(dy_storage_local);
if (is_pixel_valid) {
ldg_stream(x_storage_local, &gmem_src[idx*params.c]);
ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);
}
// The offset to store in SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
// Store in SMEM.
write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);
// Update the count.
count += is_pixel_valid;
// Invert the count.
float inv_count = is_pixel_valid ? 1.f / count : 0.f;
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
}
}
    // Scale the running means (dbias/dscale) by the number of elements. It brings more stability.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dbias[i] *= count;
dscale[i] *= count;
}
// dscale parallel sum
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
__syncthreads();
// The workspace in global memory is distributed across the different CTA.
int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;
// Write the data for the CTA to global memory.
float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
if (threadIdx.x < THREADS_PER_PIXEL) {
const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;
write_to_gmem(&gmem_sums[ 0], idx, dscale);
write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);
// Reset the accumulators for global summation
zero_array(dscale);
zero_array(dbias);
// Build the global accumulation
#pragma unroll 1
for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];
read_from_gmem(tmp1, gmem_sums, idx);
read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx);
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dscale[i] += tmp1[i];
dbias[i] += tmp2[i];
}
}
// dscale parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw);
}
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw);
}
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
// inv-var
float var[ELEMENTS_PER_LDG];
zero_array(var);
if (is_valid_c) {
read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);
}
// Normalize the dscale.
multiply(dscale, var);
// store dscale/dbias
bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
if (is_valid_for_saving) {
      if (params.sync_iters > 0) {
scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);
scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);
} else {
write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);
write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);
}
}
// scale
float scale[ELEMENTS_PER_LDG];
zero_array(scale);
if (is_valid_c) {
read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);
}
// Further normalize the dscale to be used in dx calculation
multiply(dscale, var);
// scale the inv-var as well, afterwards
multiply(var, scale);
// inverse count
float inv_count = params.svar_inv_count;
// The base pointer to write to.
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
// Store the elements in registers.
#pragma unroll 1
for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {
// The value for nhw.
int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;
// Normalize the elements and write to memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float.
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
// The next value of nhw.
out_nhw -= pixels_per_iteration;
// Read the next elements from memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
}
}
}
// Normalize the elements from SMEM and write them out.
if (pixels_in_smem > 0) {
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
if (is_valid) {
// Read from SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads();
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Storage,
int THREADS_PER_CTA,
int THREADS_PER_PIXEL,
int PIXELS_PER_THREAD_IN_REGISTERS,
int PIXELS_PER_THREAD_IN_SMEM,
int ELEMENTS_PER_LDG,
int USE_ONLINE_APPROACH,
int OUTER_LOOPS_,
int DESIRED_OCCUPANCY
>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
void nhwc_batch_norm_bwd_relu(NhwcBatchNormBwdParams params) {
// The number of pixels loaded in a single LDG.
const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
// The number of pixels computed per CTA stored in registers.
const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
// The number of pixels computed per CTA stored in SMEM.
const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;
// The number of C elements per CTA.
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];
// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
// The data type for packed storage in SMEM.
typedef typename PackedStorage_::Type PackedStorageType;
// The number of elements in the packed storage.
const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
// Registers to keep the data live for the persistent approach.
PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
// The position in the C dimension where the CTA starts.
const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
// Compute the C coordinate of the thread in the CTA.
const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
// Compute the C coordinate of the thread.
const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
// Is the thread working on a valid C dimension?
const int is_valid_c = thread_c < params.c;
// Registers to store the mean/var/scale/bias used for the entire duration
// Register usage optimizations:
// 1. Can combine bias - (mean * var * scale) into a single register
// 2. Can combine var * scale into a single register
float varscale[ELEMENTS_PER_LDG];
zero_array(varscale);
if (is_valid_c) {
read_from_gmem(varscale, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);
}
float tmp[ELEMENTS_PER_LDG];
zero_array(tmp);
if (is_valid_c) {
read_from_gmem(tmp, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);
}
multiply(varscale, tmp);
float mean[ELEMENTS_PER_LDG];
zero_array(mean);
if (is_valid_c) {
read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);
}
zero_array(tmp);
if (is_valid_c) {
read_from_gmem(tmp, params.gmem_bias, thread_c/ELEMENTS_PER_LDG);
}
float mean_var_scale_bias[ELEMENTS_PER_LDG];
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
mean_var_scale_bias[i] = tmp[i] - (mean[i] * varscale[i]);
}
// accumulation related registers
float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];
zero_array(dscale);
zero_array(dbias);
// The number of elements loaded by this CTA.
int cta_count = 0;
// The base pointers to load from.
const uint16_t *gmem_src = &params.gmem_src[thread_c];
const uint16_t *gmem_dy = &params.gmem_dy[thread_c];
// outer loops
int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;
// Load the batch of elements. Compute sum across them
const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;
if (OUTER_LOOPS_ != 1) {
      // We cannot load everything to store persistently, so let's make sure registers and
      // SMEM are fully utilized
int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS -
PIXELS_PER_CTA_IN_SMEM * gridDim.x;
cta_nhw_regs += offset;
cta_nhw_smem += offset;
}
#pragma unroll 1
for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
// The nhw position.
int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));
// Read the elements from memory.
float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
zero_array(x_storage[i]);
zero_array(dy_storage[i]);
is_valid[i] = 0.f;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
} else {
ldg(x_storage[i], &gmem_src[idx*params.c]);
ldg(dy_storage[i], &gmem_dy[idx*params.c]);
}
is_valid[i] = 1.f;
}
}
// Do the math.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float and update
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
// Update the count.
count += is_valid[i];
// Invert the count.
float inv_count = is_valid[i] ? 1.f / count : 0.f;
relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_valid[i]);
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
}
}
// The elements to load and store in SMEM.
int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;
// Load elements from SMEM, update the CTA count.
int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw);
if (pixels_in_smem > 0) {
cta_count += pixels_in_smem;
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
bool is_pixel_valid = (((unsigned int)idx <
(unsigned int)params.nhw) && is_valid_c);
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
zero_array(x_storage_local);
zero_array(dy_storage_local);
if (is_pixel_valid) {
ldg_stream(x_storage_local, &gmem_src[idx*params.c]);
ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);
}
// The offset to store in SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
// Store in SMEM.
write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);
// Update the count.
count += is_pixel_valid;
// Invert the count.
float inv_count = is_pixel_valid ? 1.f / count : 0.f;
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_pixel_valid);
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
}
}
// dscale/dbias were accumulated as running means for numerical stability; scale them by the
// element count to turn them back into sums before the cross-CTA reduction.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dbias[i] *= count;
dscale[i] *= count;
}
// dscale parallel sum
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
__syncthreads();
// The workspace in global memory is distributed across the different CTA.
int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;
// Write the data for the CTA to global memory.
float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
if (threadIdx.x < THREADS_PER_PIXEL) {
const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;
write_to_gmem(&gmem_sums[ 0], idx, dscale);
write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);
// Reset the accumulators for global summation
zero_array(dscale);
zero_array(dbias);
// Build the global accumulation
#pragma unroll 1
for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];
read_from_gmem(tmp1, gmem_sums, idx);
read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx);
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dscale[i] += tmp1[i];
dbias[i] += tmp2[i];
}
}
// dscale parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw);
}
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw);
}
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
// Normalize the dscale.
float var[ELEMENTS_PER_LDG];
zero_array(var);
if (is_valid_c) {
read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);
}
multiply(dscale, var);
// store dscale/dbias
bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
if (is_valid_for_saving) {
if (params.sync_iters>0)
{
scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);
scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);
} else {
write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);
write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);
}
}
// Further normalize the dscale to be used in dx calculation
float scale[ELEMENTS_PER_LDG];
zero_array(scale);
if (is_valid_c) {
read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);
}
multiply(dscale, var);
// scale the inv-var as well, afterwards
multiply(var, scale);
// inverse count
float inv_count = params.svar_inv_count;
// The base pointer to write to.
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
// Store the elements in registers.
#pragma unroll 1
for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {
// The value for nhw.
int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;
// Normalize the elements and write to memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float.
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
// The next value of nhw.
out_nhw -= pixels_per_iteration;
// Read the next elements from memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
}
}
}
// Normalize the elements from SMEM and write them out.
if (pixels_in_smem > 0) {
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
if (is_valid) {
// Read from SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads();
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Storage,
int THREADS_PER_CTA,
int THREADS_PER_PIXEL,
int PIXELS_PER_THREAD_IN_REGISTERS,
int PIXELS_PER_THREAD_IN_SMEM,
int ELEMENTS_PER_LDG,
int USE_ONLINE_APPROACH,
int OUTER_LOOPS_,
int DESIRED_OCCUPANCY
>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
void nhwc_batch_norm_bwd_add_relu(NhwcBatchNormBwdParams params) {
// The number of pixels loaded in a single LDG.
const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
// The number of pixels computed per CTA stored in registers.
const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
// The number of pixels computed per CTA stored in SMEM.
const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;
// The number of C elements per CTA.
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];
// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
// The data type for packed storage in SMEM.
typedef typename PackedStorage_::Type PackedStorageType;
// The number of elements in the packed storage.
const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
// Registers to keep the data live for the persistent approach.
PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
// The position in the C dimension where the CTA starts.
const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
// Compute the C coordinate of the thread in the CTA.
const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
// Compute the C coordinate of the thread.
const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
// Is the thread working on a valid C dimension?
const int is_valid_c = thread_c < params.c;
float mean[ELEMENTS_PER_LDG];
zero_array(mean);
if (is_valid_c) {
read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);
}
// accumulation related registers
float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];
zero_array(dscale);
zero_array(dbias);
// The number of elements loaded by this CTA.
int cta_count = 0;
// The base pointers to load from.
const uint16_t *gmem_src = &params.gmem_src[thread_c];
const uint16_t *gmem_dy = &params.gmem_dy[thread_c];
uint16_t *gmem_dst1 = &params.gmem_dst1[thread_c];
// outer loops
int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;
// Load the batch of elements. Compute sum across them
const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;
if (OUTER_LOOPS_ != 1) {
// We cannot load everything to store persistently, so let's make sure registers and
// smem are fully utilized; the offset is evenly divisible by 32
int offset = (pixels_per_iteration * OUTER_LOOPS + PIXELS_PER_CTA_IN_SMEM * gridDim.x -
params.nhw) & ~31;
cta_nhw_regs -= offset;
cta_nhw_smem -= offset;
}
const unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask +
((params.nhw + 31) & ~31) * 2 * c_blk_index;
#pragma unroll 1
for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
// The nhw position.
int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));
int lane_id = threadIdx.x & 31;
// Read the elements from memory.
float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
unsigned int relu_mask[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
zero_array(x_storage[i]);
zero_array(dy_storage[i]);
is_valid[i] = 0.f;
const bool is_valid_nhw =
static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
if (is_valid_nhw) {
if (is_valid_c) {
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
} else {
ldg(x_storage[i], &gmem_src[idx*params.c]);
ldg(dy_storage[i], &gmem_dy[idx*params.c]);
}
is_valid[i] = 1.f;
}
if (lane_id < ELEMENTS_PER_LDG) {
relu_mask[i] = gmem_relu_bitmask[idx * 2 + lane_id];
}
}
}
// Do the math.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
// Convert to float and update
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
bool rectified[ELEMENTS_PER_LDG];
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask[i], j) &
(1U << lane_id)) != 0);
}
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
// Update the count.
count += is_valid[i];
// Invert the count.
float inv_count = is_valid[i] ? 1.f / count : 0.f;
relu_bwd(dy_math, rectified, is_valid[i]);
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
// Lastly we need 'dy' only for BN, so store the 'relu-dgrad'ed version
from_float(dy_storage[i], dy_math);
// dZ for elementwise add
if (is_valid[i]) {
if (loop_i == OUTER_LOOPS - 1) {
stg_stream(&gmem_dst1[idx*params.c], dy_storage[i]);
} else {
stg(&gmem_dst1[idx*params.c], dy_storage[i]);
}
}
}
}
// The elements to load and store in SMEM.
int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;
// Load elements from SMEM, update the CTA count.
int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw);
if (pixels_in_smem > 0) {
cta_count += pixels_in_smem;
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_pixel_valid_nhw =
static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
const bool is_pixel_valid = is_pixel_valid_nhw && is_valid_c;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
unsigned int relu_mask;
int lane_id = threadIdx.x & 31;
zero_array(x_storage_local);
zero_array(dy_storage_local);
if (is_pixel_valid_nhw) {
if (is_valid_c) {
ldg_stream(x_storage_local, &gmem_src[idx*params.c]);
ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);
}
if (lane_id < ELEMENTS_PER_LDG) {
relu_mask = gmem_relu_bitmask[idx * 2 + lane_id];
}
}
bool rectified[ELEMENTS_PER_LDG];
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask, j) &
(1U << lane_id)) != 0);
}
// The offset to store in SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
// Store in SMEM.
write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
// Update the count.
count += is_pixel_valid;
// Invert the count.
float inv_count = is_pixel_valid ? 1.f / count : 0.f;
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
relu_bwd(dy_math, rectified, is_pixel_valid);
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
from_float(dy_storage_local, dy_math);
// dZ for elementwise add
if (is_pixel_valid) {
stg_stream(&gmem_dst1[idx*params.c], dy_storage_local);
}
// only store the 'relu-dgrad'ed version!
write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);
}
}
// dscale/dbias were accumulated as running means for numerical stability; scale them by the
// element count to turn them back into sums before the cross-CTA reduction.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dbias[i] *= count;
dscale[i] *= count;
}
// dscale parallel sum
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
__syncthreads();
// The workspace in global memory is distributed across the different CTA.
int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;
// Write the data for the CTA to global memory.
float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
if (threadIdx.x < THREADS_PER_PIXEL) {
const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;
write_to_gmem(&gmem_sums[ 0], idx, dscale);
write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);
// Reset the accumulators for global summation
zero_array(dscale);
zero_array(dbias);
// Build the global accumulation
#pragma unroll 1
for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];
read_from_gmem(tmp1, gmem_sums, idx);
read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx);
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dscale[i] += tmp1[i];
dbias[i] += tmp2[i];
}
}
// dscale parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw);
}
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw);
}
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
// Normalize the dscale.
float var[ELEMENTS_PER_LDG];
zero_array(var);
if (is_valid_c) {
read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);
}
multiply(dscale, var);
// store dscale/dbias
bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
if (is_valid_for_saving) {
if (params.sync_iters>0)
{
scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);
scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);
} else {
write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);
write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);
}
}
// Further normalize the dscale to be used in dx calculation
float scale[ELEMENTS_PER_LDG];
zero_array(scale);
if (is_valid_c) {
read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);
}
multiply(dscale, var);
// scale the inv-var as well, afterwards
multiply(var, scale);
// inverse count
float inv_count = params.svar_inv_count;
// The base pointer to write to.
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
// Store the elements in registers.
#pragma unroll 1
for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {
// The value for nhw.
int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;
// Normalize the elements and write to memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
// Convert to float.
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
if (is_valid) {
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
// The next value of nhw.
out_nhw -= pixels_per_iteration;
// Read the next elements from memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
float y[ELEMENTS_PER_LDG];
zero_array(y);
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dst1[idx*params.c]);
}
}
}
// Normalize the elements from SMEM and write them out.
if (pixels_in_smem > 0) {
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
if (is_valid) {
// Read from SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads();
}
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
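For orientation, the per-channel math implemented by the NHWC batch-norm backward kernels above can be summarized by a short NumPy reference. This is a hedged sketch rather than the kernels' exact code path: it assumes the saved variance buffer holds the inverse standard deviation 1/sqrt(var + eps), as produced by the fused forward pass, and that the reduction runs over the N, H and W dimensions of an NHWC tensor.
# Hedged NumPy reference for the per-channel gradients of fused BN(+ReLU) backward.
# All names here are illustrative; `inv_std` is assumed to be 1/sqrt(var + eps).
import numpy as np

def bn_relu_bwd_reference(dy, x, mean, inv_std, scale, bias, fuse_relu=True):
    y = (x - mean) * inv_std * scale + bias      # recompute the forward output
    if fuse_relu:
        dy = dy * (y > 0)                        # ReLU backward: zero dy where y <= 0
    n = x.shape[0] * x.shape[1] * x.shape[2]     # NHWC: reduce over N, H, W
    dbias = dy.sum(axis=(0, 1, 2))
    dscale = (dy * (x - mean) * inv_std).sum(axis=(0, 1, 2))
    dx = scale * inv_std * (dy - dbias / n - (x - mean) * inv_std * dscale / n)
    return dx, dscale, dbias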
#include <torch/extension.h>
// CUDA forward declarations
std::vector<at::Tensor> softmax_xentropy_cuda(
const at::Tensor &input,
const at::Tensor &labels,
const float smoothing,
const bool half_to_float);
at::Tensor softmax_xentropy_backward_cuda(
const at::Tensor &grad_loss,
const at::Tensor &logits,
const at::Tensor &max_log_sum_exp,
const at::Tensor &labels,
const float smoothing);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> softmax_xentropy_forward(
const at::Tensor &input,
const at::Tensor &labels,
const float smoothing,
const bool half_to_float) {
CHECK_CUDA(input);
CHECK_INPUT(labels);
return softmax_xentropy_cuda(input, labels, smoothing, half_to_float);
}
at::Tensor softmax_xentropy_backward(
const at::Tensor &grad_loss,
const at::Tensor &logits,
const at::Tensor &max_log_sum_exp,
const at::Tensor &labels,
const float smoothing) {
CHECK_CUDA(grad_loss);
CHECK_CUDA(logits);
CHECK_INPUT(max_log_sum_exp);
CHECK_INPUT(labels);
return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)");
m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)");
}
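The bindings above are normally consumed through the SoftmaxCrossEntropyLoss autograd function in apex.contrib.xentropy, but they can also be exercised directly. A minimal, hedged sketch, assuming the extension was built under the name xentropy_cuda (the name used by the apex.contrib.xentropy package):
import torch
import xentropy_cuda  # module name assumed from the apex.contrib.xentropy build

logits = torch.randn(8, 1000, dtype=torch.half, device='cuda')
labels = torch.randint(0, 1000, (8,), device='cuda')
# forward returns the per-sample losses and max + log(sum(exp)), saved for backward
losses, max_log_sum_exp = xentropy_cuda.forward(logits, labels, 0.1, True)
# backward recomputes softmax from the saved statistics and applies label smoothing
grad_logits = xentropy_cuda.backward(torch.ones_like(losses), logits,
                                     max_log_sum_exp, labels, 0.1)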
/**
* From PyTorch:
*
* Copyright (c) 2016- Facebook, Inc (Adam Paszke)
* Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
* Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
* Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
* Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
* Copyright (c) 2011-2013 NYU (Clement Farabet)
* Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
* Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
* Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
*
* From Caffe2:
*
* Copyright (c) 2016-present, Facebook Inc. All rights reserved.
*
* All contributions by Facebook:
* Copyright (c) 2016 Facebook Inc.
*
* All contributions by Google:
* Copyright (c) 2015 Google Inc.
* All rights reserved.
*
* All contributions by Yangqing Jia:
* Copyright (c) 2015 Yangqing Jia
* All rights reserved.
*
* All contributions from Caffe:
* Copyright(c) 2013, 2014, 2015, the respective contributors
* All rights reserved.
*
* All other contributions:
* Copyright(c) 2015, 2016 the respective contributors
* All rights reserved.
*
* Caffe2 uses a copyright model similar to Caffe: each contributor holds
* copyright over their contributions to Caffe2. The project versioning records
* all such contribution and copyright details. If a contributor wants to further
* mark their specific copyright on a particular contribution, they should
* indicate their copyright solely in the commit message of the change when it is
* committed.
*
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
* and IDIAP Research Institute nor the names of its contributors may be
* used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <THC/THC.h>
#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>
#include "type_shim.h"
using Tensor = at::Tensor;
using TensorList = at::TensorList;
using ScalarType = at::ScalarType;
using at::acc_type;
template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxForwardEpilogue {
__device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
: logsum(max_input + std::log(sum)) {}
__device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp)
: logsum(max_log_sum_exp) {}
__device__ __forceinline__ OutT operator()(T input) const {
return static_cast<OutT>(input - logsum);
}
const AccumT logsum;
};
template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxBackwardEpilogue {
__device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)
: sum(sum) {}
__device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);
}
const AccumT sum;
};
const int max_threads = 1024;
inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
uint64_t block_size = 1;
uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
while (block_size < max_block_size) block_size *= 2;
// Launch at least a single warp - the kernel assumes that.
block_size = std::max(block_size, static_cast<uint64_t>(32));
return dim3(block_size);
}
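As a worked example of the heuristic above (a hedged Python transcription, assuming max_threads = 1024): the block size is the smallest power of two that reaches min(dim_size / ILP, 1024), clamped to at least one warp.
# Illustrative transcription of SoftMax_getBlockSize; names mirror the C++ above.
def softmax_block_size(ilp, dim_size, max_threads=1024):
    block_size = 1
    max_block_size = min(dim_size // ilp, max_threads)
    while block_size < max_block_size:
        block_size *= 2
    return max(block_size, 32)  # launch at least a single warp

# e.g. softmax_block_size(2, 32320) == 1024 and softmax_block_size(2, 100) == 64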
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
////////////////////////////////////////////////////////////////////////////////
// Regular kernel (fast when dim_size is large; requires inner_size == 1)
////////////////////////////////////////////////////////////////////////////////
template <typename T, typename AccumT>
struct MaxFloat
{
__device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
return ::max(max, (AccumT)v);
}
};
template<typename T, typename AccumT>
struct AddFloat
{
__device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
return sum + v;
}
};
template<typename T, typename AccumT>
struct SumExpFloat
{
__device__ __forceinline__ SumExpFloat(AccumT v)
: max_k(v) {}
__device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
return sum + std::exp(v - max_k);
}
const AccumT max_k;
};
template <template<typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT
blockReduce(AccumT* smem, AccumT val,
const Reduction<AccumT>& r,
AccumT defaultVal)
{
// To avoid RaW races from chaining blockReduce calls together, we need a sync here
__syncthreads();
smem[threadIdx.x] = val;
__syncthreads();
AccumT warpVal = defaultVal;
// First warp will perform per-warp reductions for the remaining warps
uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
if (threadIdx.x < 32) {
int lane = threadIdx.x % 32;
if (lane < blockDim.x / 32) {
#pragma unroll
for (int i = 0; i < 32; ++i) {
warpVal = r(warpVal, smem[lane * 32 + i]);
}
__syncwarp(mask);
smem[lane] = warpVal;
}
}
__syncthreads();
// First thread will perform a reduction of the above per-warp reductions
AccumT blockVal = defaultVal;
if (threadIdx.x == 0) {
for (int i = 0; i < blockDim.x / 32; ++i) {
blockVal = r(blockVal, smem[i]);
}
smem[0] = blockVal;
}
// Sync and broadcast
__syncthreads();
return smem[0];
}
template <template<typename> class Reduction1, template<typename> class Reduction2, typename AccumT>
__device__ __forceinline__ void
blockReduce(AccumT* smem,
AccumT* reducVal1,
AccumT val1,
const Reduction1<AccumT>& r1,
AccumT defaultVal1,
AccumT* reducVal2,
AccumT val2,
const Reduction2<AccumT>& r2,
AccumT defaultVal2)
{
// To avoid RaW races from chaining blockReduce calls together, we need a sync here
__syncthreads();
smem[threadIdx.x] = val1;
smem[blockDim.x + threadIdx.x] = val2;
__syncthreads();
AccumT warpVal1 = defaultVal1;
AccumT warpVal2 = defaultVal2;
// First warp will perform per-warp reductions for the remaining warps
uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
if (threadIdx.x < 32) {
int lane = threadIdx.x % 32;
if (lane < blockDim.x / 32) {
#pragma unroll
for (int i = 0; i < 32; ++i) {
warpVal1 = r1(warpVal1, smem[lane * 32 + i]);
warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]);
}
__syncwarp(mask);
smem[lane] = warpVal1;
smem[lane + blockDim.x] = warpVal2;
}
}
__syncthreads();
// First thread will perform a reduction of the above per-warp reductions
AccumT blockVal1 = defaultVal1;
AccumT blockVal2 = defaultVal2;
if (threadIdx.x == 0) {
for (int i = 0; i < blockDim.x / 32; ++i) {
blockVal1 = r1(blockVal1, smem[i]);
blockVal2 = r2(blockVal2, smem[i + blockDim.x]);
}
smem[0] = blockVal1;
smem[blockDim.x] = blockVal2;
}
// Sync and broadcast
__syncthreads();
*reducVal1 = smem[0];
*reducVal2 = smem[blockDim.x];
__syncthreads();
}
template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>
__device__ __forceinline__ AccumT
ilpReduce(T* data,
int size,
const Reduction<T, AccumT>& r,
AccumT defaultVal)
{
AccumT threadVal = defaultVal;
int offset = threadIdx.x;
int last = size % (ILP * blockDim.x);
// Body (unroll by ILP times)
for (; offset < size - last; offset += blockDim.x * ILP) {
T tmp[ILP];
#pragma unroll
for (int j = 0; j < ILP; ++j)
tmp[j] = data[offset + j * blockDim.x];
#pragma unroll
for (int j = 0; j < ILP; ++j)
threadVal = r(threadVal, tmp[j]);
}
// Epilogue
for (; offset < size; offset += blockDim.x)
threadVal = r(threadVal, data[offset]);
return threadVal;
}
template <template<typename, typename> class Reduction1, template<typename, typename> class Reduction2, int ILP, typename T, typename AccumT>
__device__ __forceinline__ void
ilpReduce(T* data,
int size,
AccumT* reducVal1,
const Reduction1<T, AccumT>& r1,
AccumT defaultVal1,
AccumT* reducVal2,
const Reduction2<T, AccumT>& r2,
AccumT defaultVal2)
{
AccumT threadVal1 = defaultVal1;
AccumT threadVal2 = defaultVal2;
int offset = threadIdx.x;
int last = size % (ILP * blockDim.x);
// Body (unroll by ILP times)
for (; offset < size - last; offset += blockDim.x * ILP) {
T tmp[ILP];
#pragma unroll
for (int j = 0; j < ILP; ++j)
tmp[j] = data[offset + j * blockDim.x];
#pragma unroll
for (int j = 0; j < ILP; ++j) {
threadVal1 = r1(threadVal1, tmp[j]);
threadVal2 = r2(threadVal2, tmp[j]);
}
}
// Epilogue
for (; offset < size; offset += blockDim.x) {
threadVal1 = r1(threadVal1, data[offset]);
threadVal2 = r2(threadVal2, data[offset]);
}
*reducVal1 = threadVal1;
*reducVal2 = threadVal2;
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyForward(
accscalar_t *losses,
outscalar_t *max_log_sum_exp,
scalar_t *input,
int64_t *labels,
int64_t classes,
const float smoothing)
{
extern __shared__ unsigned char smem[];
auto sdata = reinterpret_cast<accscalar_t*>(smem);
// forward pointers to batch[blockIdx.x]
// each block handles a sample in the mini-batch
input += blockIdx.x * classes;
//output += blockIdx.x * classes;
int64_t label = labels[blockIdx.x];
// find the max and sum
accscalar_t threadMax, threadSum, max_k, sum_k;
ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(
input, classes,
&threadMax, MaxFloat<scalar_t, accscalar_t>(),
-at::numeric_limits<accscalar_t>::max(),
&threadSum, AddFloat<scalar_t, accscalar_t>(),
static_cast<accscalar_t>(0));
blockReduce<Max, Add, accscalar_t>(
sdata,
&max_k, threadMax, Max<accscalar_t>(),
-at::numeric_limits<accscalar_t>::max(),
&sum_k, threadSum, Add<accscalar_t>(),
static_cast<accscalar_t>(0));
// reduce all values
accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k), static_cast<accscalar_t>(0));
accscalar_t sumAll = blockReduce<Add, accscalar_t>(
sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));
Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);
// calculate per element loss with label smoothing
// reserve max + log_sum_exp for bprop
if (threadIdx.x == 0) {
accscalar_t log_prob = epilogue(static_cast<accscalar_t>(input[label]));
losses[blockIdx.x] = (max_k + std::log(sumAll) - sum_k / classes) \
* smoothing - log_prob * (1 - smoothing);
max_log_sum_exp[blockIdx.x] = max_k + std::log(sumAll);
}
}
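For reference, the per-sample value written by the thread-0 epilogue above is the label-smoothed cross entropy expressed through logsumexp. A hedged PyTorch sketch of the same arithmetic for a single row of logits:
import torch

def smoothed_xent_reference(x, label, smoothing):
    # x: 1-D logits for one sample, label: target class index
    lse = torch.logsumexp(x.float(), dim=0)      # max_k + log(sumAll)
    nll = lse - x[label].float()                 # -log p(label)
    smooth = lse - x.float().mean()              # mean over classes of -log p(c)
    return (1.0 - smoothing) * nll + smoothing * smooth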
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyBackward(
scalar_t *gradInput,
scalar_t *logits,
outscalar_t *max_log_sum_exp,
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
int classes)
{
gradInput += blockIdx.x * classes;
logits += blockIdx.x * classes;
accscalar_t smooth_positives = 1.0 - smoothing;
accscalar_t smooth_negatives = smoothing / classes;
accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
int64_t label = labels[blockIdx.x];
accscalar_t coeff = max_log_sum_exp[blockIdx.x];
int offset = threadIdx.x;
int last = classes % (ILP * blockDim.x);
for (; offset < classes - last; offset += blockDim.x * ILP) {
accscalar_t tmpLogits[ILP];
#pragma unroll
for (int j = 0; j < ILP; ++j) {
tmpLogits[j] = static_cast<accscalar_t>(logits[offset + j * blockDim.x]);
}
#pragma unroll
for (int j = 0; j < ILP; ++j)
gradInput[offset + j * blockDim.x] = tmpGradOutput * (
std::exp(tmpLogits[j] - coeff) - static_cast<accscalar_t>(
(offset + j * blockDim.x == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
for (; offset < classes; offset += blockDim.x)
gradInput[offset] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(logits[offset]) - coeff) -
static_cast<accscalar_t>((offset == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
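The backward kernel above reconstructs the softmax from the saved max + log-sum-exp and subtracts the smoothed target distribution. A hedged PyTorch sketch of the per-sample gradient:
import torch

def smoothed_xent_backward_reference(grad_out, x, max_log_sum_exp, label, smoothing):
    # x: 1-D logits for one sample; grad_out: upstream gradient for its loss
    probs = torch.exp(x.float() - max_log_sum_exp)          # softmax, recomputed
    target = torch.full_like(probs, smoothing / x.numel())  # smooth_negatives
    target[label] += 1.0 - smoothing                        # smooth_positives
    return grad_out * (probs - target)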
template<template<typename, typename, typename> class Epilogue>
std::vector<Tensor> host_softmax_xentropy(
const Tensor & input_,
const Tensor & labels_,
const float smoothing,
const bool half_to_float){
if (half_to_float) AT_ASSERTM(input_.type().scalarType() == ScalarType::Half,"conversion is supported for Half type only");
AT_ASSERTM(labels_.type().scalarType() == ScalarType::Long,"Label type should be CUDA Long");
auto input = input_.contiguous();
Tensor max_log_sum_exp = at::empty_like(labels_, half_to_float ? input.options().dtype(ScalarType::Float) : input.options());
Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float));
static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
std::is_same<acc_type<at::Half, true>, double>::value,
"accscalar_t for half should be float or double");
AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported");
AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional");
AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples");
AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0");
const int64_t dim = 1;
int64_t outer_size = 1;
int64_t dim_size = input.size(dim);
int64_t inner_size = 1;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
for (int64_t i = 0; i < dim; ++i)
outer_size *= input.size(i);
for (int64_t i = dim + 1; i < input.dim(); ++i)
inner_size *= input.size(i);
// This kernel spawns one block per element in the batch.
// XXX: it assumes that inner_size == 1
AT_CHECK(inner_size == 1, "Currently only inner size 1 supported");
const int ILP = 2;
dim3 grid(outer_size);
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
using namespace at;
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "host_softmax_xentropy",
using accscalar_t = at::acc_type<scalar_t_0, true>;
if (!half_to_float) {
cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>
<<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
losses.data<accscalar_t>(), max_log_sum_exp.data<scalar_t_0>(),
input.data<scalar_t_0>(), labels_.data<int64_t>(),
dim_size, smoothing
);
} else {
cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
<<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
losses.data<accscalar_t>(), max_log_sum_exp.data<accscalar_t>(),
input.data<scalar_t_0>(), labels_.data<int64_t>(),
dim_size, smoothing
);
}
);
THCudaCheck(cudaGetLastError());
std::vector<at::Tensor> ret = {losses, max_log_sum_exp};
return ret;
}
template<template<typename, typename, typename> class Epilogue>
Tensor host_softmax_xentropy_backward(
const at::Tensor &grad_loss,
const at::Tensor &logits_,
const at::Tensor &max_log_sum_exp,
const at::Tensor &labels,
const float smoothing,
bool half_to_float) {
const int64_t dim = 1;
Tensor gI = at::empty_like(logits_);
if (grad_loss.numel() == 0) {
return gI;
}
auto grad = grad_loss.contiguous();
auto logits = logits_.contiguous();
static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
std::is_same<acc_type<at::Half, true>, double>::value,
"accscalar_t for half should be float or double");
if (grad.dim() == 0) grad = grad.view(1);
AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported");
AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional");
AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0");
AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples");
AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples");
int64_t outer_size = 1;
int64_t dim_size = logits.size(dim);
int64_t inner_size = 1;
for (int64_t i = 0; i < dim; ++i)
outer_size *= logits.size(i);
for (int64_t i = dim + 1; i < logits.dim(); ++i)
inner_size *= logits.size(i);
// See descriptions of kernels above.
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_CHECK(inner_size == 1, "Currently only inner size 1 supported");
const int ILP = 2;
dim3 grid(outer_size);
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
DISPATCH_FLOAT_AND_HALF(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
using accscalar_t = acc_type<scalar_t_0, true>;
if (!half_to_float) {
cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
gI.data<scalar_t_0>(), logits.data<scalar_t_0>(),
max_log_sum_exp.data<scalar_t_0>(),
grad.data<scalar_t_0>(), labels.data<int64_t>(),
smoothing, dim_size
);
} else {
cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
gI.data<scalar_t_0>(), logits.data<scalar_t_0>(),
max_log_sum_exp.data<accscalar_t>(),
grad.data<accscalar_t>(), labels.data<int64_t>(),
smoothing, dim_size
);
}
);
THCudaCheck(cudaGetLastError());
return gI;
}
std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const bool half_to_float){
return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, half_to_float);
}
at::Tensor softmax_xentropy_backward_cuda(
const at::Tensor &grad_loss,
const at::Tensor &logits,
const at::Tensor &max_log_sum_exp,
const at::Tensor &labels,
const float smoothing) {
bool half_to_float = grad_loss.type().scalarType() != logits.type().scalarType();
if (half_to_float) {
AT_ASSERTM((grad_loss.type().scalarType() == ScalarType::Float && logits.type().scalarType() == ScalarType::Half), "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
}
return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, half_to_float);
}
try:
import torch
import bnp
from .batch_norm import BatchNorm2d_NHWC
del torch
del bnp
del batch_norm
except ImportError as err:
print("apex was installed without --bnp flag, contrib.groupbn is not available")
import torch
import numpy as np
from torch.nn.modules.batchnorm import _BatchNorm
import bnp
class bn_NHWC_impl(torch.autograd.Function):
@staticmethod
def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
if is_train:
ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv)
ctx.epsilon = epsilon
ctx.momentum = mom
ctx.ret_cta = ret_cta
ctx.fuse_relu = fuse_relu
ctx.my_data = my_data
ctx.pair_data = pair_data
ctx.magic = magic
ctx.pair_data2 = pair_data2
ctx.pair_data3 = pair_data3
ctx.bn_group = bn_group
ctx.bwd_occup = bwd_occup
ctx.bwd_grid_x = bwd_grid_x
ctx.multi_stream = multi_stream
res = bnp.bn_fwd_nhwc(x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, fwd_occup, fwd_grid_x, multi_stream)
return res
else:
return bnp.bn_fwd_eval_nhwc(x, s, b, rm, riv, ret_cta, bn_group, mom, epsilon, fuse_relu)
@staticmethod
def backward(ctx, grad_y):
x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables
epsilon = ctx.epsilon
mom = ctx.momentum
ret_cta = ctx.ret_cta
fuse_relu = ctx.fuse_relu
my_data = ctx.my_data
pair_data = ctx.pair_data
magic = ctx.magic
pair_data2 = ctx.pair_data2
pair_data3 = ctx.pair_data3
bn_group = ctx.bn_group
bwd_occup = ctx.bwd_occup
bwd_grid_x = ctx.bwd_grid_x
multi_stream = ctx.multi_stream
dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)
return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class bn_addrelu_NHWC_impl(torch.autograd.Function):
@staticmethod
def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom, epsilon, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
if is_train:
bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y)
ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)
ctx.epsilon = epsilon
ctx.momentum = mom
ctx.ret_cta = ret_cta
ctx.my_data = my_data
ctx.pair_data = pair_data
ctx.magic = magic
ctx.pair_data2 = pair_data2
ctx.pair_data3 = pair_data3
ctx.bn_group = bn_group
ctx.bwd_occup = bwd_occup
ctx.bwd_grid_x = bwd_grid_x
ctx.multi_stream = multi_stream
res = bnp.bn_addrelu_fwd_nhwc(x, z, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, fwd_occup, fwd_grid_x, multi_stream)
return res
else:
return bnp.bn_addrelu_fwd_eval_nhwc(x, z, s, b, rm, riv, ret_cta, bn_group, mom, epsilon)
@staticmethod
def backward(ctx, grad_y):
x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables
epsilon = ctx.epsilon
mom = ctx.momentum
ret_cta = ctx.ret_cta
my_data = ctx.my_data
pair_data = ctx.pair_data
magic = ctx.magic
pair_data2 = ctx.pair_data2
pair_data3 = ctx.pair_data3
bn_group = ctx.bn_group
bwd_occup = ctx.bwd_occup
bwd_grid_x = ctx.bwd_grid_x
multi_stream = ctx.multi_stream
dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)
return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class BatchNorm2d_NHWC(_BatchNorm):
# if using BatchNorm2d_NHWC simultaneously with multiple streams, set multi_stream to True
def __init__(self, num_features, fuse_relu=False, bn_group=1, max_cta_per_sm=2, cta_launch_margin=12, multi_stream=False):
super(BatchNorm2d_NHWC, self).__init__(num_features)
self.fuse_relu = fuse_relu
self.multi_stream = multi_stream
self.minibatch_mean = torch.cuda.FloatTensor(num_features)
self.minibatch_riv = torch.cuda.FloatTensor(num_features)
#default to distributed bn disabled
self.bn_group = bn_group
self.max_cta_per_sm = max_cta_per_sm #used only in training fwd and bwd
self.cta_launch_margin = cta_launch_margin #used only in training fwd and bwd
self.my_data = None
self.pair_data = None
self.pair_data2 = None
self.pair_data3 = None
self.local_rank = 0
self.magic = torch.IntTensor([0])
#calculate cta per sm occupancies
assert(max_cta_per_sm>0) # won't be able to do much with 0 CTAs :)
self.fwd_occupancy = min(bnp.bn_fwd_nhwc_occupancy(), max_cta_per_sm)
self.bwd_occupancy = min(bnp.bn_bwd_nhwc_occupancy(), max_cta_per_sm)
self.addrelu_fwd_occupancy = min(bnp.bn_addrelu_fwd_nhwc_occupancy(), max_cta_per_sm)
self.addrelu_bwd_occupancy = min(bnp.bn_addrelu_bwd_nhwc_occupancy(), max_cta_per_sm)
#calculate grid dimensions based on occupancy numbers
mp_count = torch.cuda.get_device_properties(None).multi_processor_count
self.fwd_grid_dim_x = max(mp_count*self.fwd_occupancy - cta_launch_margin , 1)
self.bwd_grid_dim_x = max(mp_count*self.bwd_occupancy - cta_launch_margin , 1)
self.addrelu_fwd_grid_dim_x = max(mp_count*self.addrelu_fwd_occupancy - cta_launch_margin , 1)
self.addrelu_bwd_grid_dim_x = max(mp_count*self.addrelu_bwd_occupancy - cta_launch_margin , 1)
self.grid_dim_y = (num_features + 63) // 64
# allocate scratch space used by implementation
# TODO: scratch space that should not be exposed to user code. Only one-time initialization is
# needed and the same buffer could be reused in future iterations. It is allocated here, instead
# of being requested from the cache allocator, to avoid unnecessary re-initialization in future iterations.
self.ret_cta = torch.cuda.ByteTensor(8192).fill_(0)
#FIXME: turn pair handles into an array
if bn_group>1:
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
assert(world_size >= bn_group)
assert(world_size % bn_group == 0)
bn_sync_steps = 1
if (bn_group==4):
bn_sync_steps = 2
if (bn_group==8):
bn_sync_steps = 3
self.ipc_buffer = torch.cuda.ByteTensor(bnp.get_buffer_size(bn_sync_steps))
self.my_data = bnp.get_data_ptr(self.ipc_buffer)
# we are walking on very thin ice here by utilizing internal `_share_cuda_()`
self.storage = self.ipc_buffer.storage()
self.share_cuda = self.storage._share_cuda_()
internal_cuda_mem = self.share_cuda
# internal_cuda_mem[1]: ipc_mem_handle
my_handle = torch.cuda.ByteTensor(np.frombuffer(internal_cuda_mem[1], dtype=np.uint8))
# internal_cuda_mem[3]: offset
my_offset = torch.cuda.IntTensor([internal_cuda_mem[3]])
handles_all = torch.empty(world_size, my_handle.size(0), dtype=my_handle.dtype, device=my_handle.device)
handles_l = list(handles_all.unbind(0))
torch.distributed.all_gather(handles_l, my_handle)
offsets_all = torch.empty(world_size, my_offset.size(0), dtype=my_offset.dtype, device=my_offset.device)
offsets_l = list(offsets_all.unbind(0))
torch.distributed.all_gather(offsets_l, my_offset)
#whom do I actually care about? that would be local_rank XOR 1
self.pair_handle = handles_l[local_rank ^ 1].cpu().contiguous()
pair_offset = offsets_l[local_rank ^ 1].cpu()
self.pair_data = bnp.get_remote_data_ptr(self.pair_handle, pair_offset)
if bn_group>2:
self.pair_handle2 = handles_l[local_rank ^ 2].cpu().contiguous()
pair_offset2 = offsets_l[local_rank ^ 2].cpu()
self.pair_data2 = bnp.get_remote_data_ptr(self.pair_handle2, pair_offset2)
if bn_group>4:
self.pair_handle3 = handles_l[local_rank ^ 4].cpu().contiguous()
pair_offset3 = offsets_l[local_rank ^ 4].cpu()
self.pair_data3 = bnp.get_remote_data_ptr(self.pair_handle3, pair_offset3)
#FIXME: get magic value into C code and eliminate from here
self.magic = torch.IntTensor([2])
self.local_rank = local_rank
def forward(self, x, z=None):
if z is not None:
assert(self.fuse_relu==True)
return bn_addrelu_NHWC_impl.apply(x, z,
self.weight, self.bias,
self.running_mean, self.running_var,
self.minibatch_mean, self.minibatch_riv, self.grid_dim_y, self.ret_cta,
self.momentum,
self.eps, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
self.addrelu_fwd_occupancy, self.addrelu_fwd_grid_dim_x,
self.addrelu_bwd_occupancy, self.addrelu_bwd_grid_dim_x,
self.multi_stream)
else:
return bn_NHWC_impl.apply(x,
self.weight, self.bias,
self.running_mean, self.running_var,
self.minibatch_mean, self.minibatch_riv, self.ret_cta,
self.momentum,
self.eps, self.fuse_relu, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
self.fwd_occupancy, self.fwd_grid_dim_x,
self.bwd_occupancy, self.bwd_grid_dim_x,
self.multi_stream)
def __del__(self):
if self.bn_group>1:
bnp.close_remote_data(self.pair_handle)
if self.bn_group>2:
bnp.close_remote_data(self.pair_handle2)
if self.bn_group>4:
bnp.close_remote_data(self.pair_handle3)
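A minimal usage sketch for the layer above; hedged, since it assumes apex was built with the --bnp flag, a CUDA device is present, and activations are stored as explicit NHWC (N, H, W, C) half tensors while the affine parameters stay in fp32:
import torch
from apex.contrib.groupbn import BatchNorm2d_NHWC

bn = BatchNorm2d_NHWC(64, fuse_relu=True).cuda()   # fp32 scale/bias, CUDA-side buffers
x = torch.randn(32, 56, 56, 64, dtype=torch.half, device='cuda')  # NHWC activations
y = bn(x)                        # fused BN + ReLU forward
# z = ...; y = bn(x, z)          # fused add + BN + ReLU path (requires fuse_relu=True)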
import torch
from apex.contrib import xentropy as label_smoothing
import unittest
import warnings
import random
import numpy as np
import time
def label_smoothing_raw(x, target, padding_idx, smoothing):
logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)
non_pad_mask = (target != padding_idx)
nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
nll_loss = nll_loss.squeeze(1)[non_pad_mask]
smooth_loss = -logprobs.mean(dim=-1)[non_pad_mask]
loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss
return loss
def label_smoothing_opt_1(x, target, padding_idx, smoothing):
logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)
pad_mask = (target == padding_idx)
ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
smooth_loss = logprobs.mean(dim=-1)
loss = (smoothing - 1.0) * ll_loss - smoothing * smooth_loss
loss.masked_fill_(pad_mask, 0)
return loss
class LabelSmoothingTest(unittest.TestCase):
def setUp(self, seed=1234):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Set pytorch print precision
torch.set_printoptions(precision=10)
def gen_test_inputs(self, N, T, H, smoothing, padding_idx):
logits = torch.randn((N*T, H), dtype=torch.half, device='cuda',
requires_grad=True)
labels = torch.randint(0, H, [N*T], device='cuda')
for i in random.sample(range(N*T), N*T//6):
labels[i] = padding_idx
half_to_float = (logits.dtype == torch.half)
return logits, labels, half_to_float
def print_max_diff_elem(self, ref, tst):
ref, tst = ref.flatten(), tst.flatten()
diff = (ref - tst).abs().max()
idx = (ref - tst).abs().argmax()
print("Max atol idx: {}, diff: {:.6f}, ref: {:.6f}, tst: {:.6f}".format(
idx, diff, ref[idx], tst[idx]))
def test_label_smoothing_function(self):
# Set label smoothing configuration
smoothing, padding_idx = 0.1, 0
N, T, H = 128, 74, 32320
iters = 10
loss_func = label_smoothing.SoftmaxCrossEntropyLoss.apply
for i in range(iters):
logits, labels, half_to_float = self.gen_test_inputs(
N, T, H, smoothing, padding_idx)
# Run original softmax cross entropy with label smoothing
logits.grad = None
losses = label_smoothing_raw(logits, labels, padding_idx, smoothing)
loss = losses.sum()
loss.backward()
ref_loss = loss.clone().detach()
ref_grad = logits.grad.clone().detach()
# Run optimized softmax cross entropy with label smoothing
logits.grad = None
losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float)
loss = losses.sum()
loss.backward()
val_loss = loss.clone().detach()
val_grad = logits.grad.clone().detach()
# Validate
self.print_max_diff_elem(ref_grad, val_grad)
self.assertTrue(torch.allclose(ref_loss, val_loss, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_grad, val_grad, atol=1e-5, rtol=1e-5))
def test_label_smoothing_perf(self):
# Set label smoothing configuration
smoothing, padding_idx = 0.1, 0
N, T, H = 128, 74, 32320
iters = 1000
loss_func = label_smoothing.SoftmaxCrossEntropyLoss.apply
print()
logits, labels, half_to_float = self.gen_test_inputs(
N, T, H, smoothing, padding_idx)
# Run original softmax cross entropy with label smoothing
torch.cuda.synchronize()
ts = time.time()
for i in range(iters):
logits.grad = None
losses = label_smoothing_raw(logits, labels, padding_idx, smoothing)
loss = losses.sum() / N
loss.backward()
torch.cuda.synchronize()
print("Raw time {:.2f} s elapsed for {} iterations, norm {:.4f}".format(
time.time() - ts, iters, logits.grad.norm()))
# Run optimized softmax cross entropy with label smoothing
torch.cuda.synchronize()
ts = time.time()
for i in range(iters):
logits.grad = None
losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float)
loss = losses.sum() / N
loss.backward()
torch.cuda.synchronize()
print("Opt time {:.2f} s elapsed for {} iterations, norm {:.4f}".format(
time.time() - ts, iters, logits.grad.norm()))
if __name__ == '__main__':
unittest.main()
try:
import torch
import xentropy_cuda
from .softmax_xentropy import SoftmaxCrossEntropyLoss
del torch
del xentropy_cuda
del softmax_xentropy
except ImportError as err:
print("apex was installed without --xentropy flag, contrib.xentropy is not available")