OpenDAS / apex / Commits / 4d6ed501

Commit 4d6ed501, authored Aug 12, 2019 by Deyu Fu
Merge branch 'multi_tensor_sgd' into deyuf/fused_optimizer_v2
Parents: 690b1f71, 9f64bf27

Changes: 40 files; showing 20 changed files with 6397 additions and 142 deletions (+6397, -142)
README.md                                            +4    -4
apex/amp/_initialize.py                              +4    -25
apex/amp/_process_optimizer.py                       +229  -90
apex/amp/handle.py                                   +9    -11
apex/amp/scaler.py                                   +19   -12
apex/contrib/__init__.py                             +0    -0
apex/contrib/csrc/groupbn/batch_norm.cu              +331  -0
apex/contrib/csrc/groupbn/batch_norm.h               +734  -0
apex/contrib/csrc/groupbn/batch_norm_add_relu.cu     +343  -0
apex/contrib/csrc/groupbn/batch_norm_add_relu.h      +681  -0
apex/contrib/csrc/groupbn/cuda_utils.h               +20   -0
apex/contrib/csrc/groupbn/interface.cpp              +175  -0
apex/contrib/csrc/groupbn/ipc.cu                     +130  -0
apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h   +2685 -0
apex/contrib/csrc/xentropy/interface.cpp             +52   -0
apex/contrib/csrc/xentropy/xentropy_kernel.cu        +610  -0
apex/contrib/groupbn/__init__.py                     +9    -0
apex/contrib/groupbn/batch_norm.py                   +225  -0
apex/contrib/test/test_label_smoothing.py            +128  -0
apex/contrib/xentropy/__init__.py                    +9    -0
README.md (view file @ 4d6ed501)

# Introduction

This repository holds NVIDIA-maintained utilities to streamline
mixed precision and distributed training in Pytorch.
Some of the code here will be included in upstream Pytorch eventually.
The intention of Apex is to make up-to-date utilities available to
users as quickly as possible.

## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)
...
@@ -29,7 +29,7 @@ different flags to `amp.initialize`.

## 2. Distributed Training

`apex.parallel.DistributedDataParallel` is a module wrapper, similar to
`torch.nn.parallel.DistributedDataParallel`.  It enables convenient multiprocess distributed training,
optimized for NVIDIA's NCCL communication library.
...
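Since this part of the README describes `amp.initialize` and `apex.parallel.DistributedDataParallel`, a minimal usage sketch may help; the model, optimizer, and `opt_level` below are illustrative assumptions, not part of this diff, and `torch.distributed` is assumed to be initialized already (e.g. via `torch.distributed.launch`).

```python
import torch
from apex import amp
from apex.parallel import DistributedDataParallel

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Patch the model/optimizer for mixed precision, then wrap for multiprocess data parallelism.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
model = DistributedDataParallel(model)
```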
apex/amp/_initialize.py (view file @ 4d6ed501)
...
@@ -124,29 +124,13 @@ def check_optimizers(optimizers):
        raise RuntimeError("An incoming optimizer is an instance of {}. ".format(bad_optim_type) +
                           "The optimizer(s) passed to amp.initialize() must be bare \n"
                           "instances of either ordinary Pytorch optimizers, or Apex fused \n"
-                          "optimizers (currently just FusedAdam, but FusedSGD will be added \n"
-                          "soon).  You should not manually wrap your optimizer in either \n"
+                          "optimizers (FusedAdam or FusedSGD).  \n"
+                          "You should not manually wrap your optimizer in either \n"
                           "apex.fp16_utils.FP16_Optimizer or apex.optimizers.FP16_Optimizer. \n"
                           "amp.initialize will take care of that for you (if necessary) based \n"
                           "on the specified opt_level (and optional overridden properties).")

-def wrap_fused_adam(optimizer, properties):
-    msg = 'Currently, the usage of FusedAdam is restricted to '\
-          'amp.initialize(..., opt_level="O2", keep_batchnorm_fp32=False, '\
-          'loss_scale=float or "dynamic").  We are working on enabling more general usage.'
-
-    assert properties.master_weights is True, msg
-    assert properties.cast_model_type is torch.float16, msg
-    assert (properties.keep_batchnorm_fp32 is False or
-            properties.keep_batchnorm_fp32 is None), msg
-
-    if properties.loss_scale == "dynamic":
-        return FP16_Optimizer_for_fused(optimizer, dynamic_loss_scale=True)
-    else:
-        return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)
-
def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
    from apex.parallel import DistributedDataParallel as apex_DDP
    from .amp import init as amp_init
...
@@ -176,7 +160,6 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
    if not _amp_state.allow_incoming_model_not_fp32:
        check_params_fp32(models)

    # In the future, when FP16_Optimizer can be deprecated and master weights can
    # become an attribute, remember to stash master weights before casting the model.
...
@@ -207,7 +190,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
            model.forward = patch_forward(model.forward)

        # State dict trick to recast any preexisting per-param state tensors
        for optimizer in optimizers:
            optimizer.load_state_dict(optimizer.state_dict())
    elif cast_model_outputs is not None:
...
@@ -223,11 +206,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
            model.forward = patch_forward(model.forward)

    for i, optimizer in enumerate(optimizers):
-        # Still need to special case this for the first pass
-        if isinstance(optimizer, FusedAdam):
-            optimizers[i] = wrap_fused_adam(optimizer, properties)
-        else:
-            optimizers[i] = _process_optimizer(optimizer, properties)
+        optimizers[i] = _process_optimizer(optimizer, properties)

    _amp_state.loss_scalers = []
    for _ in range(num_losses):
...
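With `wrap_fused_adam` removed above, a fused optimizer is expected to go through `amp.initialize` like any other optimizer. A minimal sketch of what that call looks like; the model, learning rate, and `opt_level` here are illustrative assumptions rather than part of the diff.

```python
import torch
from apex import amp
from apex.optimizers import FusedAdam

model = torch.nn.Linear(10, 10).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3)

# Pass the bare FusedAdam instance; amp installs master weights and loss scaling internally.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic")
```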
apex/amp/_process_optimizer.py (view file @ 4d6ed501)
...
@@ -3,6 +3,7 @@ from ..fp16_utils import master_params_to_model_params
 from ..multi_tensor_apply import multi_tensor_applier
 from ._amp_state import maybe_print
 import torch
+from ..optimizers import FusedAdam, FusedSGD

 class AmpOptimizerState(object):
...
@@ -10,6 +11,20 @@ class AmpOptimizerState(object):
     pass

+def _master_params_to_model_params(self):
+    stash = self._amp_stash
+    if multi_tensor_applier.available:
+        if len(stash.all_fp16_params) > 0:
+            multi_tensor_applier(
+                stash.multi_tensor_scale,
+                stash.dummy_overflow_buf,
+                [stash.all_fp32_from_fp16_params, stash.all_fp16_params],
+                1.0)
+    else:
+        for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
+            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
+
 def lazy_init_with_master_weights(self):
     stash = self._amp_stash
     stash.fp16_groups = []
...
@@ -60,6 +75,8 @@ def lazy_init_with_master_weights(self):
     for group in stash.fp32_from_fp32_groups:
         stash.all_fp32_from_fp32_params += group

+    # all_fp16_grad_stash is only needed for fused optimizers.
+    stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
     # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
     stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]
...
@@ -73,15 +90,55 @@ def lazy_init_with_master_weights(self):
     self.load_state_dict(self.state_dict())

+def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):
+    grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0
+
+    if scale_override is not None:
+        grads_have_scale, stashed_have_scale, out_scale = scale_override
+
+    # This is a lot of python overhead...
+    grads_needing_unscale = []
+    grads_needing_unscale_with_stash = []
+    stashed = []
+    for param, stashed_grad in zip(params, stashed_grads):
+        if param.grad is None and stashed_grad is not None:
+            param.grad = stashed_grad
+        elif param.grad is not None and stashed_grad is None:
+            grads_needing_unscale.append(param.grad)
+        elif param.grad is not None and stashed_grad is not None:
+            grads_needing_unscale_with_stash.append(param.grad)
+            stashed.append(stashed_grad)
+        else:  # param.grad is None and stashed_grad is None
+            continue
+
+    # unscale() implements grads*(1/scale), so "scale" should be grads_have_scale/out_scale.
+    if len(grads_needing_unscale) > 0:
+        scaler.unscale(
+            grads_needing_unscale,
+            grads_needing_unscale,
+            None,  # unused_scale, currently present to avoid API breakage elsewhere
+            models_are_masters=True,
+            scale_override=grads_have_scale/out_scale)
+
+    if len(grads_needing_unscale_with_stash) > 0:
+        scaler.unscale_with_stashed(
+            grads_needing_unscale_with_stash,
+            stashed,
+            grads_needing_unscale_with_stash,
+            scale_override=(grads_have_scale, stashed_have_scale, out_scale))
+
+    # Clear the stash.
+    for i in range(len(stashed_grads)):
+        stashed_grads[i] = None
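The comment above pins down the scale bookkeeping: `unscale()` multiplies by `1/scale`, so moving gradients from `grads_have_scale` to `out_scale` means passing the ratio `grads_have_scale/out_scale`. A tiny arithmetic sketch with made-up values:

```python
# Suppose the incoming grads were produced from a loss scaled by 65536,
# and we want the result expressed at out_scale = 1.0 (i.e. fully unscaled).
grads_have_scale, stashed_have_scale, out_scale = 65536.0, 1.0, 1.0

scale_arg = grads_have_scale / out_scale   # 65536.0; unscale() divides by this
g_scaled = 0.5 * grads_have_scale          # a gradient value as produced by the scaled backward
g_out = g_scaled * (1.0 / scale_arg)       # 0.5, now expressed at out_scale
assert abs(g_out - 0.5) < 1e-6
```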
 def prepare_backward_with_master_weights(self):
     stash = self._amp_stash

-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()

     for i, param in enumerate(stash.all_fp16_params):
-        # Set up to leverage grad copy elision:
+        # Set up to leverage grad copy elision.
+        # This may behave differently from an unpatched optimizer if zero_grad is used and the param is unused.
         param.grad = None

     # for i, param in enumerate(stash.all_fp32_from_fp16_params):
...
@@ -96,6 +153,8 @@ def prepare_backward_with_master_weights(self):
 def post_backward_with_master_weights(self, scaler):
     stash = self._amp_stash

+    self._amp_lazy_init()
+
     # This is a lot of python overhead...
     fp16_grads_needing_unscale = []
     new_fp32_grads = []
...
@@ -129,37 +188,10 @@ def post_backward_with_master_weights(self, scaler):
                                     preexisting_fp32_grads)

     # fp32 params can be treated as they would be in the "no_master_weights" case.
-    grads_needing_unscale = []
-    grads_needing_unscale_with_stash = []
-    stashed = []
-    for param, stashed_grad in zip(stash.all_fp32_from_fp32_params,
-                                   stash.all_fp32_from_fp32_grad_stash):
-        if param.grad is None and stashed_grad is not None:
-            param.grad = stashed_grad
-        elif param.grad is not None and stashed_grad is None:
-            grads_needing_unscale.append(param.grad)
-        elif param.grad is not None and stashed_grad is not None:
-            grads_needing_unscale_with_stash.append(param.grad)
-            stashed.append(stashed_grad)
-        else:  # param.grad is None and stashed_grad is None:
-            continue
-
-    if len(grads_needing_unscale) > 0:
-        scaler.unscale(
-            grads_needing_unscale,
-            grads_needing_unscale,
-            scaler.loss_scale(),
-            models_are_masters=True)
-
-    if len(grads_needing_unscale_with_stash) > 0:
-        scaler.unscale_with_stashed(
-            grads_needing_unscale_with_stash,
-            stashed,
-            grads_needing_unscale_with_stash)
-
-    # Clear the stash.
-    for i in range(len(stash.all_fp32_from_fp32_grad_stash)):
-        stash.all_fp32_from_fp32_grad_stash[i] = None
+    post_backward_models_are_masters(scaler,
+                                     stash.all_fp32_from_fp32_params,
+                                     stash.all_fp32_from_fp32_grad_stash)

 def lazy_init_no_master_weights(self):
...
@@ -184,9 +216,7 @@ def lazy_init_no_master_weights(self):
 def prepare_backward_no_master_weights(self):
     stash = self._amp_stash

-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()

     for i, param in enumerate(stash.all_fp16_params):
         stash.all_fp16_grad_stash[i] = param.grad
...
@@ -202,55 +232,141 @@ def prepare_backward_no_master_weights(self):
 def post_backward_no_master_weights(self, scaler):
     stash = self._amp_stash

+    self._amp_lazy_init()
+
     split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
                    (stash.all_fp32_params, stash.all_fp32_grad_stash))

     for params, stashed_grads in split_types:
-        # This is a lot of python overhead...
-        grads_needing_unscale = []
-        grads_needing_unscale_with_stash = []
-        stashed = []
-        for param, stashed_grad in zip(params, stashed_grads):
-            if param.grad is None and stashed_grad is not None:
-                param.grad = stashed_grad
-            elif param.grad is not None and stashed_grad is None:
-                grads_needing_unscale.append(param.grad)
-            elif param.grad is not None and stashed_grad is not None:
-                grads_needing_unscale_with_stash.append(param.grad)
-                stashed.append(stashed_grad)
-            else:  # param.grad is None and stashed_grad is None
-                continue
-
-        if len(grads_needing_unscale) > 0:
-            scaler.unscale(
-                grads_needing_unscale,
-                grads_needing_unscale,
-                scaler.loss_scale(),
-                models_are_masters=True)
-
-        if len(grads_needing_unscale_with_stash) > 0:
-            scaler.unscale_with_stashed(
-                grads_needing_unscale_with_stash,
-                stashed,
-                grads_needing_unscale_with_stash)
-
-        # Clear the stash.
-        for i in range(len(stashed_grads)):
-            stashed_grads[i] = None
-
-def _master_params_to_model_params(self):
-    stash = self._amp_stash
-    if multi_tensor_applier.available:
-        if len(stash.all_fp16_params) > 0:
-            multi_tensor_applier(
-                stash.multi_tensor_scale,
-                stash.dummy_overflow_buf,
-                [stash.all_fp32_from_fp16_params, stash.all_fp16_params],
-                1.0)
-    else:
-        for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
-            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
+        post_backward_models_are_masters(scaler, params, stashed_grads)
+
+#####################################################################################
+# FusedAdam versions
+#####################################################################################
+
+def prepare_backward_with_master_weights_FusedAdam(self):
+    stash = self._amp_stash
+
+    self._amp_lazy_init()
+
+def post_backward_with_master_weights_FusedAdam(self, scaler):
+    stash = self._amp_stash
+
+    self._amp_lazy_init()
+
+    stash.scale = scaler.loss_scale()
+    stash.grads = [[param.grad.data for param in group] for group in stash.fp16_groups]
+    stash.output_params = [[param for param in group] for group in stash.fp16_groups]
+
+    norm_groups = []
+    skip = False
+    for grad_group in stash.grads:
+        norm, _ = multi_tensor_applier(
+            stash.multi_tensor_l2norm,
+            stash.dummy_overflow_buf,
+            [grad_group],
+            False)
+        # Still syncing here for now.
+        norm = float(norm)
+        norm_groups.append(norm)
+        if norm == float('inf') or norm == -float('inf') or norm != norm:
+            skip = True
+
+    if skip:
+        scaler._overflow_buf.fill_(1.)
+        scaler._has_overflow = True
+
+    stash.grad_norms = norm_groups
+
+def prepare_backward_no_master_weights_FusedAdam(self):
+    stash = self._amp_stash
+
+    self._amp_lazy_init()
+
+def post_backward_no_master_weights_FusedAdam(self, scaler):
+    stash = self._amp_stash
+
+    self._amp_lazy_init()
+
+    stash.scale = scaler.loss_scale()
+    stash.grads = None
+    stash.output_params = None
+    stash.grad_norms = None
+
+#####################################################################################
+# FusedSGD versions
+# Eat this ugly code duplication for now.  First make it work, then make it clean.
+# It's difficult to anticipate what can be unified between the FusedAdam and FusedSGD
+# implementations until I have them both working.
+#####################################################################################
+
+# FusedSGD never explicitly materializes the fp32 gradients for "fp32 from fp16" master params
+# outside the kernel, so we must accumulate directly into the model grads.
+def prepare_backward_with_master_weights_FusedSGD(self):
+    if self.materialize_master_grads:
+        prepare_backward_with_master_weights(self)
+    else:
+        stash = self._amp_stash
+
+        self._amp_lazy_init()
+
+        for i, param in enumerate(stash.all_fp16_params):
+            stash.all_fp16_grad_stash[i] = param.grad
+            # Set up to leverage grad copy elision:
+            param.grad = None
+
+        for i, param in enumerate(stash.all_fp32_from_fp32_params):
+            stash.all_fp32_from_fp32_grad_stash[i] = param.grad
+            # Set up to leverage grad copy elision:
+            param.grad = None
+
+def post_backward_with_master_weights_FusedSGD(self, scaler):
+    if self.materialize_master_grads:
+        post_backward_with_master_weights(self, scaler)
+    else:
+        stash = self._amp_stash
+
+        self._amp_lazy_init()
+
+        grads_have_scale = scaler.loss_scale()
+        stashed_have_scale = self.most_recent_scale
+        out_scale = grads_have_scale
+        if self.scale_set_by_backward:
+            out_scale = min(grads_have_scale, self.most_recent_scale)
+
+        split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
+                       (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
+
+        # unscale_with_stashed() implements grads*1/scale + stashed_grads*1.
+        # stashed_grads are scaled by self.most_recent_scale.
+        for params, stashed_grads in split_types:
+            post_backward_models_are_masters(scaler, params, stashed_grads,
+                                             (grads_have_scale, stashed_have_scale, out_scale))
+
+        self.most_recent_scale = out_scale
+        self.scale_set_by_backward = True
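The `most_recent_scale` bookkeeping above chooses a common output scale when fresh fp16 grads (at the current loss scale) must be combined with stashed grads left over at an older scale. A small numeric sketch of that choice; the values are purely illustrative.

```python
# Stashed grads were produced when the loss scale was 32768; the new backward used 65536.
most_recent_scale = 32768.0          # scale the stashed grads still carry
grads_have_scale = 65536.0           # scale of the freshly computed grads
scale_set_by_backward = True

out_scale = grads_have_scale
if scale_set_by_backward:
    out_scale = min(grads_have_scale, most_recent_scale)   # 32768.0

# Coefficients later handed to the axpby unscale: bring both operands to out_scale.
a = out_scale / grads_have_scale     # 0.5, applied to the new grads
b = out_scale / most_recent_scale    # 1.0, applied to the stashed grads
print(a, b)
```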
+def prepare_backward_no_master_weights_FusedSGD(self):
+    prepare_backward_no_master_weights(self)
+
+def post_backward_no_master_weights_FusedSGD(self, scaler):
+    post_backward_no_master_weights(self, scaler)
+
+def _amp_lazy_init(self):
+    stash = self._amp_stash
+
+    if not stash.lazy_init_called:
+        self._lazy_init_maybe_master_weights()
+        stash.lazy_init_called = True

 def _process_optimizer(optimizer, properties):
...
@@ -266,7 +382,8 @@ def _process_optimizer(optimizer, properties):
     for name in ("_lazy_init_maybe_master_weights",
                  "_master_params_to_model_params",
                  "_prepare_amp_backward",
-                 "_post_amp_backward"):
+                 "_post_amp_backward",
+                 "_amp_lazy_init"):
         if hasattr(optimizer, name):
             raise RuntimeError("Incoming optimizer already has {} defined.".format(name))
...
@@ -274,6 +391,7 @@ def _process_optimizer(optimizer, properties):
     if multi_tensor_applier.available:
         import amp_C
         optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
+        optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
         optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0]);

     if properties.master_weights:
...
@@ -288,7 +406,8 @@ def _process_optimizer(optimizer, properties):
             if closure is not None:
                 raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
             retval = old_step()
-            self._master_params_to_model_params()
+            if not (isinstance(self, FusedAdam) or isinstance(self, FusedSGD)):
+                self._master_params_to_model_params()
             # Clear the master grads that wouldn't be zeroed by model.zero_grad()
             for param in self._amp_stash.all_fp32_from_fp16_params:
                 param.grad = None
...
@@ -298,9 +417,7 @@ def _process_optimizer(optimizer, properties):
         old_zero_grad = optimizer.zero_grad
         def new_zero_grad(self):
             stash = self._amp_stash
-            if not stash.lazy_init_called:
-                self._lazy_init_maybe_master_weights()
-                stash.lazy_init_called = True
+            self._amp_lazy_init()
             # Zero the model grads.
             for param in stash.all_fp16_params:
                 if param.grad is not None:
...
@@ -315,20 +432,42 @@ def _process_optimizer(optimizer, properties):
                 param.grad = None
         optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)

-        optimizer._prepare_amp_backward = types.MethodType(
-            prepare_backward_with_master_weights, optimizer)
-
-        optimizer._post_amp_backward = types.MethodType(
-            post_backward_with_master_weights, optimizer)
+        if isinstance(optimizer, FusedAdam):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_with_master_weights_FusedAdam, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_with_master_weights_FusedAdam, optimizer)
+        elif isinstance(optimizer, FusedSGD):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_with_master_weights_FusedSGD, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_with_master_weights_FusedSGD, optimizer)
+        else:
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_with_master_weights, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_with_master_weights, optimizer)
     else:
         optimizer._lazy_init_maybe_master_weights = types.MethodType(
             lazy_init_no_master_weights, optimizer)

-        optimizer._prepare_amp_backward = types.MethodType(
-            prepare_backward_no_master_weights, optimizer)
-
-        optimizer._post_amp_backward = types.MethodType(
-            post_backward_no_master_weights, optimizer)
+        if isinstance(optimizer, FusedAdam):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_no_master_weights_FusedAdam, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_no_master_weights_FusedAdam, optimizer)
+        elif isinstance(optimizer, FusedSGD):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_no_master_weights_FusedSGD, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_no_master_weights_FusedSGD, optimizer)
+        else:
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_no_master_weights, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_no_master_weights, optimizer)
+
+    optimizer._amp_lazy_init = types.MethodType(_amp_lazy_init, optimizer)

     old_add_param_group = optimizer.add_param_group
...
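`_process_optimizer` installs these hooks by binding plain functions onto the optimizer instance with `types.MethodType`. A stripped-down sketch of the same pattern; the `DummyOptimizer` class and the hook body below are placeholders for illustration, not apex code.

```python
import types

class DummyOptimizer:
    pass

def _post_amp_backward(self, scaler):
    # "self" is the optimizer instance the function was bound to.
    print("post-backward hook running for", type(self).__name__, "with scaler", scaler)

opt = DummyOptimizer()
opt._post_amp_backward = types.MethodType(_post_amp_backward, opt)
opt._post_amp_backward(scaler="loss_scaler_placeholder")
```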
apex/amp/handle.py (view file @ 4d6ed501)
...
@@ -6,8 +6,6 @@ from . import utils
 from .opt import OptimWrapper
 from .scaler import LossScaler
 from ._amp_state import _amp_state, master_params, maybe_print
-from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
-from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
 from ..parallel.LARC import LARC
...
@@ -89,13 +87,8 @@ def scale_loss(loss,
     if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
         optimizers = [optimizers]

-    # this is what happens when i have to support tools from different sources under the same API...
-    # TODO: Rewrite FusedAdam to use multi-tensor apply and the same loss scaler.
-    if isinstance(optimizers, FP16_Optimizer_for_fused):
-        loss_scale = optimizers.cur_scale
-    else:
-        loss_scaler = _amp_state.loss_scalers[loss_id]
-        loss_scale = loss_scaler.loss_scale()
+    loss_scaler = _amp_state.loss_scalers[loss_id]
+    loss_scale = loss_scaler.loss_scale()

     if ((not _amp_state.opt_properties.master_weights)
         and (not loss_scaler.dynamic)
...
@@ -120,8 +113,8 @@ def scale_loss(loss,
             for optimizer in optimizers:
                 optimizer._amp_stash.params_have_scaled_gradients = True
     else:
-        # FusedAdam and FusedSGD will take care of unscaling as part of their step() methods.
-        if not isinstance(optimizers, FP16_Optimizer_for_fused):
+        # FusedAdam and FusedSGD may take care of unscaling as part of their step() methods.
+        # if not isinstance(optimizers, FP16_Optimizer_for_fused):
         loss_scaler.clear_overflow_state()
         for optimizer in optimizers:
             optimizer._post_amp_backward(loss_scaler)
...
@@ -142,10 +135,15 @@ def scale_loss(loss,
             maybe_print(("Gradient overflow.  Skipping step, loss scaler " +
                          "{} reducing loss scale to {}").format(loss_id,
                          loss_scaler.loss_scale()))
+            # TODO:  I don't like the special casing for different optimizer implementations.
+            # Maybe skip should delegate to a method owned by the optimizers themselves.
             if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
                 # Clear the master grads that wouldn't be zeroed by model.zero_grad()
                 for param in opt._amp_stash.all_fp32_from_fp16_params:
                     param.grad = None
+            if hasattr(opt, "most_recent_scale"):
+                opt.most_recent_scale = 1.0
+                opt.scale_set_by_backward = False
             opt.step = opt_step
             opt._amp_stash.already_patched = False
         return skip_step
...
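For context, the `_prepare_amp_backward` / `_post_amp_backward` hooks above run inside the `amp.scale_loss` context manager. A typical call site looks like the sketch below; the surrounding training loop and the fact that `amp.initialize` has already been called are assumptions, not part of the diff.

```python
from apex import amp

# Inside the training loop, after the forward pass has produced `loss`:
optimizer.zero_grad()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()   # gradients are produced at the scaled loss; unscaling runs when the context exits
optimizer.step()
```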
apex/amp/scaler.py (view file @ 4d6ed501)
...
@@ -16,7 +16,7 @@ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=F
         master_grad.mul_(scale)
     return False

-def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, scale, check_overflow=False):
+def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):
     # Exception handling for 18.04 compatibility
     if check_overflow:
         cpu_sum = float(model_grad.float().sum())
...
@@ -26,9 +26,8 @@ def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, scale, ch
     # if master_grad is not model_grad: # copy_ probably internally short-circuits this
     #     master_grad.copy_(model_grad)
     assert stashed_grad.dtype == master_grad.dtype
-    converted_model_grad = model_grad.to(master_grad.dtype)
-    stashed_grad.add_(scale, converted_model_grad)
-    master_grad.data = stashed_grad.data
+    converted_model_grad = model_grad.data.to(master_grad.dtype)
+    master_grad.data = a*converted_model_grad.data + b*stashed_grad.data
     return False
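The rewritten Python fallback computes `master = a*model + b*stashed` directly, with `a` and `b` chosen by the caller; in `unscale_with_stashed` they are `out_scale/grads_have_scale` and `out_scale/stashed_have_scale`. A small torch sketch of the same axpby, using made-up tensors:

```python
import torch

grads_have_scale, stashed_have_scale, out_scale = 65536.0, 1.0, 1.0
a = out_scale / grads_have_scale    # unscales the fresh (scaled) grads
b = out_scale / stashed_have_scale  # leaves the stashed grads as they are

model_grad = torch.full((4,), 0.5 * grads_have_scale)  # grad as produced by the scaled backward
stashed_grad = torch.full((4,), 0.25)                  # previously accumulated, already unscaled
master_grad = torch.empty(4)

master_grad.data = a * model_grad.data.to(master_grad.dtype) + b * stashed_grad.data
print(master_grad)  # tensor([0.7500, 0.7500, 0.7500, 0.7500])
```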
 class LossScaler(object):
...
@@ -92,11 +91,13 @@ class LossScaler(object):
                 break

     # unused_scale keeps some of the old API alive for hopefully a short time.
-    def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False):
+    def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False, scale_override=None):
         if self._has_overflow:
             return

         scale = self._loss_scale
+        if scale_override is not None:
+            scale = scale_override

         if scale == 1.0 and models_are_masters and not self.dynamic:
             return
...
@@ -126,7 +127,8 @@ class LossScaler(object):
                                     model_grads,
                                     stashed_master_grads,
                                     master_grads,
-                                    scale):
+                                    a,
+                                    b):
         for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
             if model is None and stashed is None:
                 continue
...
@@ -141,7 +143,8 @@ class LossScaler(object):
                 self._has_overflow = axpby_check_overflow_python(model,
                                                                  stashed,
                                                                  master,
-                                                                 1./scale,
+                                                                 a,
+                                                                 b,
                                                                  self.dynamic)
                 if self._has_overflow and self.dynamic:
                     break
...
@@ -149,11 +152,14 @@ class LossScaler(object):
     def unscale_with_stashed(self,
                              model_grads,
                              stashed_master_grads,
-                             master_grads):
+                             master_grads,
+                             scale_override=None):
         if self._has_overflow:
             return

-        scale = self._loss_scale
+        grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0
+        if scale_override is not None:
+            grads_have_scale, stashed_have_scale, out_scale = scale_override

         if LossScaler.has_fused_kernel:
             if (not LossScaler.warned_unscaling_non_fp32_grad
...
@@ -167,14 +173,15 @@ class LossScaler(object):
             multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,
                                  self._overflow_buf,
                                  [model_grads, stashed_master_grads, master_grads],
-                                 1./scale,
-                                 1.0,
+                                 out_scale/grads_have_scale,   # 1./scale,
+                                 out_scale/stashed_have_scale, # 1.0,
                                  0) # check only arg 0, aka the incoming model grads, for infs
         else:
             self.unscale_with_stashed_python(model_grads,
                                              stashed_master_grads,
                                              master_grads,
-                                             scale)
+                                             out_scale/grads_have_scale,
+                                             out_scale/stashed_have_scale)

         # Defer to update_scale
         # If the fused kernel is available, we only need one D2H memcopy and sync.
...
apex/contrib/__init__.py (new file, mode 0 → 100644; empty) (view file @ 4d6ed501)
apex/contrib/csrc/groupbn/batch_norm.cu (new file, mode 0 → 100644) (view file @ 4d6ed501)
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include "batch_norm.h"
#include <cuda.h>
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
static size_t round_up_to_multiple(size_t x, int multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
  Workspace(size_t size) : size(size), data(NULL) {
    data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);
  }
  Workspace(const Workspace&) = delete;
  Workspace(Workspace&&) = default;
  Workspace& operator=(Workspace&&) = default;
  ~Workspace() {
    if (data) {
      THCudaFree(at::globalContext().lazyInitCUDA(), data);
    }
  }

  size_t size;
  void* data;
};

// Return {y}
at::Tensor nhwc_bn_fwd_train(
                       const at::Tensor& x,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       const bool fuse_relu,
                       void* my_data,
                       void* pair_data,
                       void* pair_data2,
                       void* pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
                       const int occupancy,
                       const int grid_dim_x,
                       const bool coop) {

  const int N = x.size(0);
  const int H = x.size(1);
  const int W = x.size(2);
  const int C = x.size(3);

  // generating new magic number and use that for sync
  int* magic = magic_tensor.data<int>();
  *magic = (*magic + 1) & 0xff;

  // Allocate output tensor
  at::Tensor y = at::empty({N, H, W, C}, x.options());

  // Create wrapper
  NhwcBatchNorm *bn = new NhwcBatchNorm();

  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.data<at::Half>(),
                             nullptr,
                             y.data<at::Half>(),
                             nullptr);

  bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
  bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(minibatch_mean.data<float>());
  workspace.push_back(minibatch_inv_var.data<float>());

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[2];
  void* retired_ctas = ret_cta.data<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  // Don't fuse in ReLU for now at least
  bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3,
          bn_group, *magic, occupancy, grid_dim_x, coop);

  return y;
}
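The workspace handling above packs several per-kernel buffers into one allocation, aligning the start of each sub-buffer to 512 bytes. A small Python sketch of the same offset computation; the byte counts are made-up example values, not sizes reported by the real kernel.

```python
def round_up_to_multiple(x, multiple):
    return ((x + multiple - 1) // multiple) * multiple

# Hypothetical per-buffer sizes, standing in for numWorkspaceBytes()[3:].
workspace_bytes = [1000, 300, 4096]

total, offsets = 0, []
for alloc_bytes in workspace_bytes:
    total = round_up_to_multiple(total, 512)  # align the start of this sub-buffer
    offsets.append(total)
    total += alloc_bytes

print(offsets, total)  # [0, 1024, 1536] and 5632 bytes for the single allocation
```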
at::Tensor nhwc_bn_fwd_eval(
                       const at::Tensor& x,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& ret_cta,
                       const int bn_group,
                       const float momentum,
                       const float epsilon,
                       const bool fuse_relu) {

  const int N = x.size(0);
  const int H = x.size(1);
  const int W = x.size(2);
  const int C = x.size(3);

  // Allocate output tensor
  at::Tensor y = at::empty({N, H, W, C}, x.options());

  // Create wrapper
  NhwcBatchNorm *bn = new NhwcBatchNorm();

  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.data<at::Half>(),
                             nullptr,
                             y.data<at::Half>(),
                             nullptr);

  bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
  bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(nullptr);
  workspace.push_back(nullptr);

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[2];
  void* retired_ctas = ret_cta.data<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  // Don't fuse in ReLU for now at least
  bn->fwdInference(stream, fuse_relu);

  return y;
}

std::vector<at::Tensor> nhwc_bn_bwd(
                       const at::Tensor& x,
                       const at::Tensor& dy,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       const bool fuse_relu,
                       void* my_data,
                       void* pair_data,
                       void* pair_data2,
                       void* pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
                       const int occupancy,
                       const int grid_dim_x,
                       const bool coop) {
  // shape
  const int N = x.size(0);
  const int H = x.size(1);
  const int W = x.size(2);
  const int C = x.size(3);

  // generating new magic number and use that for sync
  int* magic = magic_tensor.data<int>();
  *magic = (*magic + 1) & 0xff;

  // outputs
  at::Tensor x_grad, scale_grad, bias_grad;

  // Allocate outputs
  x_grad = at::empty_like(x);
  scale_grad = at::empty_like(scale);
  bias_grad = at::empty_like(bias);

  // Create wrapper
  NhwcBatchNorm *bn = new NhwcBatchNorm();

  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.data<at::Half>(),
                             x_grad.data<at::Half>(),
                             nullptr,
                             dy.data<at::Half>());

  bn->setWeightPointers({scale.data<float>(), bias.data<float>()},
                        {scale_grad.data<float>(), bias_grad.data<float>()});
  bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(minibatch_mean.data<float>());
  workspace.push_back(minibatch_inv_var.data<float>());

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[2];
  void* retired_ctas = ret_cta.data<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3,
            bn_group, *magic, occupancy, grid_dim_x, coop);

  return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};
}

int nhwc_bn_fwd_occupancy() {
  int device_id = -1;
  cudaGetDevice(&device_id);

  //max occupancy supported by the code is 2
  return NhwcBatchNorm::smem_driven_fwd_occupancy(device_id, 2);
}

int nhwc_bn_bwd_occupancy() {
  int device_id = -1;
  cudaGetDevice(&device_id);

  //max occupancy supported by the code is 2
  return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2);
}
apex/contrib/csrc/groupbn/batch_norm.h (new file, mode 0 → 100644) (view file @ 4d6ed501)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm.h
* \brief CUDA NHWC Batch Normalization code
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#include <cudnn.h>
#include <algorithm>
#include <vector>
#include <string>
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#define VERBOSE_DEFAULT false
class NhwcBatchNorm {
 public:
  NhwcBatchNorm() {
    name_ = "nhwc_batchnorm";
    createTensorDescriptor(&X_tensor_desc_);
    createTensorDescriptor(&Y_tensor_desc_);
  }

  ~NhwcBatchNorm() {
    destroyTensorDescriptor(X_tensor_desc_);
    destroyTensorDescriptor(Y_tensor_desc_);
  }

  void die() {
    std::cerr << "batchnorm not initialized" << std::endl;
    exit(-1);
  }

  void fwd(cudaStream_t stream, bool use_relu,
           void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
           const int bn_group, const int magic, const int occupancy,
           const int grid_dim_x, const bool coop);
  void dgrad(cudaStream_t stream, bool use_relu,
             void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
             const int bn_group, const int magic, const int occupancy,
             const int grid_dim_x, const bool coop);
  void fwdInference(cudaStream_t stream, bool use_relu);
  dim3 calc_fwd_grid(int *loop, const int grid_dim_x);
  dim3 calc_bwd_grid(int *loop, const int grid_dim_x);

  void setInputDescriptor(const cudnnTensorFormat_t format,
                          const cudnnDataType_t data_type,
                          int n, int c, int h, int w, int bn_group) {
    m_ = n * h * w;
    int m_bn_adjusted = m_ * bn_group;
    c_ = c;
    // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.
    svar_inv_count_ = 1.f / m_bn_adjusted;
    // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).
    int divisor = m_bn_adjusted - 1;
    // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.
    rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor;
    setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
  }
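The two count factors above are the usual biased/unbiased variance normalizers: the saved (minibatch) variance divides the sum of squared errors by nhw, while the running variance uses nhw-1. A quick Python check with a hypothetical channel of nhw values:

```python
vals = [1.0, 2.0, 3.0, 4.0]          # pretend these are the nhw values of one channel
m = len(vals)                        # nhw (times bn_group, if several ranks contribute)
mean = sum(vals) / m
sse = sum((v - mean) ** 2 for v in vals)

svar_inv_count = 1.0 / m                          # factor for the saved (biased) variance
rvar_inv_count = 1.0 / (m - 1) if m > 1 else 1.0  # factor for the running (unbiased) variance

print(sse * svar_inv_count, sse * rvar_inv_count)  # 1.25 vs. about 1.667
```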
  void setOutputDescriptor(const cudnnTensorFormat_t format,
                           const cudnnDataType_t data_type,
                           int n, int c, int h, int w) {
    setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
  }

  const std::vector<size_t> numWorkspaceBytes() const;

  void setWorkspacePointers(
      const std::vector<void*>& workspace,
      const std::vector<size_t>& num_workspace_bytes);

  void setInputOutputPointers(void* X, void* dX, void* Y, void* dY) {
    X_  = X;
    dX_ = dX;
    Y_  = Y;
    dY_ = dY;
  }

  // Sets the pointers for the scale and weight (in that order) data and derivative buffers.
  void setWeightPointers(const std::vector<void*>& weight_pointers,
                         const std::vector<void*>& deriv_pointers) {
    assert(weight_pointers.size() == 2);
    assert(deriv_pointers.size() == 2);
    scale_  = static_cast<float*>(weight_pointers[0]);
    bias_   = static_cast<float*>(weight_pointers[1]);
    dscale_ = static_cast<float*>(deriv_pointers[0]);
    dbias_  = static_cast<float*>(deriv_pointers[1]);
  }

  // Sets the pointers for the population mean and variance buffers, in that order.
  void setParameterPointers(const std::vector<void*>& param_pointers) {
    assert(param_pointers.size() == 2);
    population_mean_     = static_cast<float*>(param_pointers[0]);
    population_variance_ = static_cast<float*>(param_pointers[1]);
  }

  void setConstants(const double exp_avg_factor, const double eps) {
    exp_avg_factor_ = exp_avg_factor;
    eps_ = eps;
  }

  void processCudnnStatus(const cudnnStatus_t& status,
                          const std::string& string = std::string(),
                          bool verbose = VERBOSE_DEFAULT) {
    if (status != CUDNN_STATUS_SUCCESS)
      LOG(FATAL) << string << " " << cudnnGetErrorString(status);
    else if (verbose)
      LOG(INFO) << string << " " << cudnnGetErrorString(status);
  }

  void checkCudaStatus(const std::string& string = std::string(),
                       bool verbose = VERBOSE_DEFAULT) {
    cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess)
      LOG(FATAL) << string << " " << cudaGetErrorString(status);
    else if (verbose)
      LOG(INFO) << string << " " << cudaGetErrorString(status);
  }

  size_t size_retired_ctas(int grid_y) const {
    // Note that the value of max_grid_y to handle known GPUs is about 160.
    const int max_grid_y = 1024;
    if (grid_y > max_grid_y)
      LOG(INFO) << "GPU capabilities exceeds assumptions.";
    const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);
    // Since the region will be initialized once and used for many kernels,
    // the idea is to return an ample size that will cover all uses.
    return retired_cta_bytes;
  }

  cudnnTensorDescriptor_t X_tensor_desc_ = nullptr;
  cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr;

  void* X_  = nullptr;
  void* dX_ = nullptr;
  void* Y_  = nullptr;
  void* dY_ = nullptr;

  // Learned scale and bias weights.
  float* scale_  = nullptr;
  float* dscale_ = nullptr;
  float* bias_   = nullptr;
  float* dbias_  = nullptr;

  // Computed population mean and variance parameters.
  float* population_mean_     = nullptr;
  float* population_variance_ = nullptr;

  // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).
  float* minibatch_mean_     = nullptr;
  float* minibatch_variance_ = nullptr;

  int m_ = 0;  // Number of values per channel that BN is normalizing.
  int c_ = 0;  // Number of channels over which BN is normalizing.

  float svar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get saved variance
  float rvar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get running variance

  double exp_avg_factor_ = 0.;
  double eps_            = 0.;
  std::string name_;

 private:
  void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,
                           cudnnTensorFormat_t format,
                           cudnnDataType_t data_type,
                           int n, int c, int h, int w) {
    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
    status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);
    processCudnnStatus(status, "set tensor descriptor");
  }

  void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {
    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
    status = cudnnCreateTensorDescriptor(descriptor);
    processCudnnStatus(status, "create tensor_descriptor");
  }

  void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {
    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
    status = cudnnDestroyTensorDescriptor(descriptor);
    processCudnnStatus(status, "destroy tensor_descriptor");
  }

 protected:
  float *partial_sums_   = nullptr;
  int   *partial_counts_ = nullptr;
  int   *retired_ctas_   = nullptr;

  void _setFwdParams(NhwcBatchNormFwdParams *params) const;
  void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;
  void _setBwdParams(NhwcBatchNormBwdParams *params) const;

  // @todo: ability to configure these?
  // Kernel params
  static const int USE_ONLINE_APPROACH = 1;
  static const int THREADS_PER_CTA = 512;
  static const int THREADS_PER_PIXEL = 16;
  static const int C_ELEMENTS_PER_CTA = 64;
  static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
  static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;

  typedef uint16_t StorageType;
  //typedef float StorageType;
  // increasing this to 6 causes spills in fwd kernel!
  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;
  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;
  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;
  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;

  static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \
      PIXELS_PER_THREAD_IN_SMEM_FWD;
  static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \
      PIXELS_PER_THREAD_IN_SMEM_BWD;
  static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;

  // Derived params
  static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD * THREADS_PER_CTA * \
      ELEMENTS_PER_LDG * sizeof(StorageType);
  static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD * THREADS_PER_CTA * \
      ELEMENTS_PER_LDG * 2 * sizeof(StorageType);
  static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
  static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA / THREADS_PER_PIXEL * \
      PIXELS_PER_THREAD_FWD;
  static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA / THREADS_PER_PIXEL * \
      PIXELS_PER_THREAD_BWD;
  static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA / THREADS_PER_PIXEL * \
      PIXELS_PER_THREAD_FWD_INFERENCE;

  // max grid.y in case of group bn is limited by exchange buffer size
  static const int MAX_GBN_BLOCK_Y = 256;

  // Helper function to launch the forward kernel.

  // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel
  // version that was compiled with that occupancy in its launch bounds.  This way, we avoid
  // needless register spills.
  void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
                          dim3 grid_dim, int outer_loops, bool use_relu,
                          const int occupancy, const bool coop) {
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
cudaLaunchKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
    if (outer_loops == 1 && use_relu) {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(1, true, false, 2, coop);
      else
        LAUNCH_FWD_KERNEL(1, true, false, 1, coop);
    } else if (outer_loops == 1 && !use_relu) {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(1, false, false, 2, coop);
      else
        LAUNCH_FWD_KERNEL(1, false, false, 1, coop);
    } else if (use_relu) {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(0, true, false, 2, coop);
      else
        LAUNCH_FWD_KERNEL(0, true, false, 1, coop);
    } else {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(0, false, false, 2, coop);
      else
        LAUNCH_FWD_KERNEL(0, false, false, 1, coop);
    }
#undef LAUNCH_FWD_KERNEL
  }

  // Helper function to launch the backward kernel.
  void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
                          dim3 grid_dim, int outer_loops, bool use_relu,
                          const int occupancy, const bool coop) {
#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_func = nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd coop serial kernel"); \
} while (0)
#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
    if (outer_loops == 1 && use_relu) {
      if (occupancy >= 2)
        LAUNCH_BWD_RELU_KERNEL(1, 2, coop);
      else
        LAUNCH_BWD_RELU_KERNEL(1, 1, coop);
    } else if (outer_loops == 1 && !use_relu) {
      if (occupancy >= 2)
        LAUNCH_BWD_KERNEL(1, 2, coop);
      else
        LAUNCH_BWD_KERNEL(1, 1, coop);
    } else if (use_relu) {
      if (occupancy >= 2)
        LAUNCH_BWD_RELU_KERNEL(0, 2, coop);
      else
        LAUNCH_BWD_RELU_KERNEL(0, 1, coop);
    } else {
      if (occupancy >= 2)
        LAUNCH_BWD_KERNEL(0, 2, coop);
      else
        LAUNCH_BWD_KERNEL(0, 1, coop);
    }
#undef LAUNCH_BWD_KERNEL
  }
 public:
  // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
  static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
    using namespace at::cuda::utils;
    int fwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float);
    int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
    return std::min(max_cta_per_sm, occupancy);
  }

  // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
  static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
    using namespace at::cuda::utils;
    int bwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float);
    int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
    return std::min(max_cta_per_sm, occupancy);
  }
};

const std::vector<size_t> NhwcBatchNorm::numWorkspaceBytes() const {
  assert(c_ > 0);

  // choose the max memory required between fwd/bwd passes
  int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);
  int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);
  int grid_x = max(grid_x_fwd, grid_x_bwd);
  int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);

  const size_t num_mean_bytes = c_ * sizeof(float);
  const size_t num_variance_bytes = num_mean_bytes;
  const size_t size_sums = grid_y * grid_x * THREADS_PER_PIXEL * ELEMENTS_PER_LDG * 2 * sizeof(float);
  const size_t size_counts = grid_y * grid_x * sizeof(int);

  return {num_mean_bytes, num_variance_bytes, size_retired_ctas(grid_y), size_sums, size_counts};
}

void NhwcBatchNorm::setWorkspacePointers(
      const std::vector<void*>& workspace,
      const std::vector<size_t>& num_workspace_bytes) {
  assert(workspace.size() == 5);
  assert(num_workspace_bytes.size() == 5);

  minibatch_mean_     = static_cast<float*>(workspace[0]);
  minibatch_variance_ = static_cast<float*>(workspace[1]);
  retired_ctas_       = static_cast<int*>(workspace[2]);
  partial_sums_       = static_cast<float*>(workspace[3]);
  partial_counts_     = static_cast<int*>(workspace[4]);
}

void NhwcBatchNorm::_setFwdParams(NhwcBatchNormFwdParams *params) const {
  params->gmem_src          = static_cast<uint16_t*>(X_);
  params->gmem_dst          = static_cast<uint16_t*>(Y_);
  params->gmem_src1         = nullptr;
  params->gmem_bias         = bias_;
  params->gmem_scale        = scale_;
  params->gmem_running_mean = population_mean_;
  params->gmem_running_var  = population_variance_;
  params->gmem_saved_mean   = minibatch_mean_;
  params->gmem_saved_var    = minibatch_variance_;
  params->gmem_relu_bitmask = nullptr;
  params->nhw               = m_;
  params->c                 = c_;
  params->svar_inv_count    = svar_inv_count_;
  params->rvar_inv_count    = rvar_inv_count_;
  params->gmem_sums         = partial_sums_;
  params->gmem_counts       = partial_counts_;
  params->gmem_retired_ctas = retired_ctas_;
  params->var_eps           = eps_;
  params->outer_loops       = 0;
  params->exp_avg_factor    = static_cast<float>(exp_avg_factor_);
  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);
}

void NhwcBatchNorm::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const {
  params->gmem_src   = static_cast<uint16_t*>(X_);
  params->gmem_dst   = static_cast<uint16_t*>(Y_);
  params->gmem_src1  = nullptr;
  params->gmem_bias  = bias_;
  params->gmem_scale = scale_;
  params->gmem_mean  = population_mean_;
  params->gmem_var   = population_variance_;
  params->nhw        = m_;
  params->c          = c_;
  params->var_eps    = eps_;
}

void NhwcBatchNorm::_setBwdParams(NhwcBatchNormBwdParams *params) const {
  params->gmem_src          = static_cast<uint16_t*>(X_);
  params->gmem_dy           = static_cast<uint16_t*>(dY_);
  params->gmem_dst          = static_cast<uint16_t*>(dX_);
  params->gmem_dst1         = nullptr;
  params->gmem_relu_bitmask = nullptr;
  params->gmem_dscale       = dscale_;
  params->gmem_dbias        = dbias_;
  params->gmem_scale        = scale_;
  params->gmem_bias         = bias_;
  params->gmem_saved_mean   = minibatch_mean_;
  params->gmem_saved_var    = minibatch_variance_;
  params->nhw               = m_;
  params->c                 = c_;
  params->svar_inv_count    = svar_inv_count_;
  params->gmem_sums         = partial_sums_;
  params->gmem_retired_ctas = retired_ctas_;
  params->outer_loops       = 0;
  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);
}

void NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) {
  bool ptrs_are_set =
      X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr
      && scale_ != nullptr && bias_ != nullptr
      // && minibatch_mean_ != nullptr && minibatch_variance_ != nullptr
      && population_mean_ != nullptr && population_variance_ != nullptr
      && X_ != nullptr  // && dX_ != nullptr
      && Y_ != nullptr  // && dY_ != nullptr && dscale_ != nullptr && dbias_ != nullptr
      && partial_sums_ != nullptr && partial_counts_ != nullptr;

  if (!ptrs_are_set)
    die();

  dim3 grid_dim;
  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE);
  grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);

  // @todo: maybe just move this inside initialize routine?
  NhwcBatchNormFwdInferenceParams params;
  _setFwdInferenceParams(&params);

  if (use_relu) {
    nhwc_batch_norm_fwd_inference
        <StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, true, false>
        <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
    checkCudaStatus(name_ + " fwd_inference-relu kernel");
  } else {
    nhwc_batch_norm_fwd_inference
        <StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, false>
        <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
    checkCudaStatus(name_ + " fwd_inference kernel");
  }
}

dim3 NhwcBatchNorm::calc_fwd_grid(int *loop, const int grid_dim_x) {
  dim3 grid_dim;
  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);
  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
  unsigned int max_grid_x = grid_dim_x;
  if (grid_dim.x <= max_grid_x) {
    *loop = 1;
    if (max_grid_x / grid_dim.x > 1) {
      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
      assert(grid_dim.y < MAX_GBN_BLOCK_Y);  //FIXME: turn into a loop
    } else {
      grid_dim.y = 1;
    }
  } else {
    grid_dim.x = max_grid_x;
    grid_dim.y = 1;
    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD * PIXELS_PER_LDG * grid_dim.x;
    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD * PIXELS_PER_LDG * grid_dim.x;
    *loop = div_up(nhw_in_regs, pixels_per_iteration);
  }
  return grid_dim;
}

dim3 NhwcBatchNorm::calc_bwd_grid(int *loop, const int grid_dim_x) {
  dim3 grid_dim;
  grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);
  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
  unsigned int max_grid_x = grid_dim_x;
  if (grid_dim.x <= max_grid_x) {
    *loop = 1;
    if (max_grid_x / grid_dim.x > 1) {
      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
      assert(grid_dim.y < MAX_GBN_BLOCK_Y);  //FIXME: turn into a loop
    } else {
      grid_dim.y = 1;
    }
  } else {
    grid_dim.x = max_grid_x;
    grid_dim.y = 1;
    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD * PIXELS_PER_LDG * grid_dim.x;
    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD * PIXELS_PER_LDG * grid_dim.x;
    *loop = div_up(nhw_in_regs, pixels_per_iteration);
  }
  return grid_dim;
}

void NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu,
                        void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
                        const int bn_group, const int magic, const int occupancy,
                        const int grid_dim_x, const bool coop) {
  bool ptrs_are_set =
      X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr
      && scale_ != nullptr && bias_ != nullptr
      && minibatch_mean_ != nullptr && minibatch_variance_ != nullptr
      && population_mean_ != nullptr && population_variance_ != nullptr
      && X_ != nullptr  // && dX_ != nullptr
      && Y_ != nullptr  // && dY_ != nullptr && dscale_ != nullptr && dbias_ != nullptr
      && partial_sums_ != nullptr && partial_counts_ != nullptr
      && retired_ctas_ != nullptr;

  if (!ptrs_are_set)
    die();

  // reset of retired_cta_count no longer needed

  NhwcBatchNormFwdParams params;
  _setFwdParams(&params);
  params.my_data = my_data;
  params.pair_datas[0] = pair_data;
  params.pair_datas[1] = pair_data2;
  params.pair_datas[2] = pair_data3;
  params.magic = magic;
  params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1);

  dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);
  _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);
}

void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu,
                          void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
                          const int bn_group, const int magic, const int occupancy,
                          const int grid_dim_x, const bool coop) {
  bool ptrs_are_set =
      X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr
      && scale_ != nullptr && (bias_ != nullptr || !use_relu)
      && minibatch_mean_ != nullptr && minibatch_variance_ != nullptr
      // && population_mean_ != nullptr && population_variance_ != nullptr
      && X_ != nullptr && dX_ != nullptr
      // && Y_ != nullptr
      && dY_ != nullptr && dscale_ != nullptr && dbias_ != nullptr;

  if (!ptrs_are_set)
    die();

  // reset of retired_cta_count no longer needed

  NhwcBatchNormBwdParams params;
  _setBwdParams(&params);
  params.my_data = my_data;
  params.pair_datas[0] = pair_data;
  params.pair_datas[1] = pair_data2;
  params.pair_datas[2] = pair_data3;
  params.magic = magic;
  params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1);
  params.wgrad_coeff = 1.0 / bn_group;

  dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);
  _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
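The grid-sizing logic above (calc_fwd_grid/calc_bwd_grid) is easiest to see with concrete numbers. The following standalone sketch repeats the same arithmetic with plain ints and assumed sizes (a 64x14x14x256 NHWC activation and an 80-CTA grid_dim_x cap); div_up_sketch and all the values are illustrative only, not part of the extension.

#include <algorithm>
#include <cstdio>

static int div_up_sketch(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int pixels_per_cta_fwd = 480;   // THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_FWD = 32 * 15
  const int c_elements_per_cta = 64;
  int m = 64 * 14 * 14;                 // N*H*W pixels per channel (example)
  int c = 256;                          // channels (example)
  int max_grid_x = 80;                  // assumed grid_dim_x cap
  int grid_x = div_up_sketch(m, pixels_per_cta_fwd);   // 27
  int c_blks = div_up_sketch(c, c_elements_per_cta);   // 4
  if (grid_x <= max_grid_x) {
    // single-pass case: spare CTAs are spread over channel blocks via grid.y
    int grid_y = (max_grid_x / grid_x > 1) ? std::min(c_blks, max_grid_x / grid_x) : 1;
    printf("outer_loops = 1, grid = (%d, %d)\n", grid_x, grid_y);
  } else {
    printf("grid.x clamped to %d, multiple outer loops needed\n", max_grid_x);
  }
  return 0;
}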
apex/contrib/csrc/groupbn/batch_norm_add_relu.cu  (new file, mode 100644)
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include "batch_norm_add_relu.h"
#include <cuda.h>
//FIXME move the common stuff to common h file
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
static size_t round_up_to_multiple(size_t x, int multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}
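As a quick illustration of the helper above: each extra workspace region is placed at the next 512-byte boundary before its size is added to the running total. The byte counts below are made-up examples, not sizes the extension actually requests.

#include <cstdio>
#include <vector>

static size_t round_up_to_multiple_sketch(size_t x, int multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

int main() {
  std::vector<size_t> region_bytes = {1000, 40, 4096};  // hypothetical per-kernel workspace requests
  size_t total = 0;
  for (size_t bytes : region_bytes) {
    total = round_up_to_multiple_sketch(total, 512);    // align the start of this region
    printf("region of %4zu bytes starts at offset %zu\n", bytes, total);
    total += bytes;
  }
  printf("total allocation: %zu bytes\n", total);       // 5632 for these inputs
  return 0;
}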
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
  Workspace(size_t size) : size(size), data(NULL) {
    data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);
  }
  Workspace(const Workspace&) = delete;
  Workspace(Workspace&&) = default;
  Workspace& operator=(Workspace&&) = default;
  ~Workspace() {
    if (data) {
      THCudaFree(at::globalContext().lazyInitCUDA(), data);
    }
  }

  size_t size;
  void* data;
};
// Return {y}
at::Tensor nhwc_bn_addrelu_fwd_train(
                       const at::Tensor& x, const at::Tensor& z,
                       const at::Tensor& scale, const at::Tensor& bias,
                       const at::Tensor& running_mean, const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var,
                       const at::Tensor& bitmask, const at::Tensor& ret_cta,
                       const float momentum, const float epsilon,
                       void * my_data, void * pair_data, void * pair_data2, void * pair_data3,
                       const int bn_group, const at::Tensor& magic_tensor,
                       const int occupancy, const int grid_dim_x, const bool coop) {

  const int N = x.size(0);
  const int H = x.size(1);
  const int W = x.size(2);
  const int C = x.size(3);

  // generating new magic number and use that for sync
  int* magic = magic_tensor.data<int>();
  *magic = (*magic + 1) & 0xff;

  // Allocate output tensor
  at::Tensor y = at::empty({N, H, W, C}, x.options());

  // Create wrapper
  NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();

  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.data<at::Half>(), nullptr,
                             y.data<at::Half>(), nullptr,
                             z.data<at::Half>(), nullptr);

  bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
  bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(minibatch_mean.data<float>());
  workspace.push_back(minibatch_inv_var.data<float>());
  workspace.push_back(bitmask.data<int32_t>());

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[3];
  void* retired_ctas = ret_cta.data<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  // Don't fuse in ReLU for now at least
  bn->fwd(stream, my_data, pair_data, pair_data2, pair_data3,
          bn_group, *magic, occupancy, grid_dim_x, coop);

  return y;
}

at::Tensor nhwc_bn_addrelu_fwd_eval(
                       const at::Tensor& x, const at::Tensor& z,
                       const at::Tensor& scale, const at::Tensor& bias,
                       const at::Tensor& running_mean, const at::Tensor& running_inv_var,
                       const at::Tensor& ret_cta,
                       const int bn_group, const float momentum, const float epsilon) {

  const int N = x.size(0);
  const int H = x.size(1);
  const int W = x.size(2);
  const int C = x.size(3);

  // Allocate output tensor
  at::Tensor y = at::empty({N, H, W, C}, x.options());

  // Create wrapper
  NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();

  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.data<at::Half>(), nullptr,
                             y.data<at::Half>(), nullptr,
                             z.data<at::Half>(), nullptr);

  bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
  bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(nullptr);
  workspace.push_back(nullptr);
  workspace.push_back(nullptr);

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[3];
  void* retired_ctas = ret_cta.data<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  // Don't fuse in ReLU for now at least
  bn->fwdInference(stream);

  return y;
}

std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
                       const at::Tensor& x, const at::Tensor& dy,
                       const at::Tensor& scale, const at::Tensor& bias,
                       const at::Tensor& running_mean, const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var,
                       const at::Tensor& bitmask, const at::Tensor& ret_cta,
                       const float momentum, const float epsilon,
                       void * my_data, void * pair_data, void * pair_data2, void * pair_data3,
                       const int bn_group, const at::Tensor& magic_tensor,
                       const int occupancy, const int grid_dim_x, const bool coop) {
  // shape
  const int N = x.size(0);
  const int H = x.size(1);
  const int W = x.size(2);
  const int C = x.size(3);

  // generating new magic number and use that for sync
  int* magic = magic_tensor.data<int>();
  *magic = (*magic + 1) & 0xff;

  // outputs
  at::Tensor x_grad, z_grad, scale_grad, bias_grad;

  // Allocate outputs
  x_grad = at::empty_like(x);
  z_grad = at::empty_like(x);
  scale_grad = at::empty_like(scale);
  bias_grad = at::empty_like(bias);

  // Create wrapper
  NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();

  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.data<at::Half>(), x_grad.data<at::Half>(),
                             nullptr, dy.data<at::Half>(),
                             nullptr, z_grad.data<at::Half>());

  bn->setWeightPointers({scale.data<float>(), bias.data<float>()},
                        {scale_grad.data<float>(), bias_grad.data<float>()});
  bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(minibatch_mean.data<float>());
  workspace.push_back(minibatch_inv_var.data<float>());
  workspace.push_back(bitmask.data<int32_t>());

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[3];
  void* retired_ctas = ret_cta.data<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  bn->dgrad(stream, my_data, pair_data, pair_data2, pair_data3,
            bn_group, *magic, occupancy, grid_dim_x, coop);

  return std::vector<at::Tensor>{x_grad, z_grad, scale_grad, bias_grad};
}

int nhwc_bn_addrelu_fwd_occupancy() {
  int device_id = -1;
  cudaGetDevice(&device_id);

  //max occupancy supported by the code is 2
  return NhwcBatchNormAddRelu::smem_driven_fwd_occupancy(device_id, 2);
}

int nhwc_bn_addrelu_bwd_occupancy() {
  int device_id = -1;
  cudaGetDevice(&device_id);

  //max occupancy supported by the code is 2
  return NhwcBatchNormAddRelu::smem_driven_bwd_occupancy(device_id, 2);
}
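One detail worth calling out from the training entry points above: the sync tag stored in magic_tensor is incremented and wrapped to 8 bits on every call, so each multi-GPU reduction round gets a fresh value. A minimal sketch of just that update (the starting value is arbitrary):

#include <cstdio>

int main() {
  int magic = 254;                  // pretend this is the value held by magic_tensor
  for (int call = 0; call < 4; ++call) {
    magic = (magic + 1) & 0xff;     // same update as nhwc_bn_addrelu_fwd_train / _bwd
    printf("call %d -> magic %d\n", call, magic);
  }
  return 0;                         // prints 255, 0, 1, 2
}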
apex/contrib/csrc/groupbn/batch_norm_add_relu.h  (new file, mode 100644)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm_add_relu.h
* \brief CUDA NHWC Batch Normalization code with fused addition
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#include <cudnn.h>
#include <algorithm>
#include <vector>
#include <string>
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#define VERBOSE_DEFAULT false
class NhwcBatchNormAddRelu {
 public:
  NhwcBatchNormAddRelu() {
    name_ = "nhwc_batchnormaddrelu";
    createTensorDescriptor(&X_tensor_desc_);
    createTensorDescriptor(&Y_tensor_desc_);
  }

  ~NhwcBatchNormAddRelu() {
    destroyTensorDescriptor(X_tensor_desc_);
    destroyTensorDescriptor(Y_tensor_desc_);
  }

  void die() {
    std::cerr << "batchnormaddrelu not initialized" << std::endl;
    exit(-1);
  }

  void fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
           const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
  void dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
             const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
  void fwdInference(cudaStream_t stream);
  dim3 calc_fwd_grid(int *loop, const int grid_dim_x);
  dim3 calc_bwd_grid(int *loop, const int grid_dim_x);

  void setInputDescriptor(const cudnnTensorFormat_t format,
                          const cudnnDataType_t data_type,
                          int n, int c, int h, int w, int bn_group) {
    m_ = n * h * w;
    int m_bn_adjusted = m_ * bn_group;
    c_ = c;
    // factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
    svar_inv_count_ = 1.f / m_bn_adjusted;
    // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).
    int divisor = m_bn_adjusted - 1;
    // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.
    rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor;
    setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
  }

  void setOutputDescriptor(const cudnnTensorFormat_t format,
                           const cudnnDataType_t data_type,
                           int n, int c, int h, int w) {
    setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
  }

  const std::vector<size_t> numWorkspaceBytes() const;

  void setWorkspacePointers(const std::vector<void*>& workspace,
                            const std::vector<size_t>& num_workspace_bytes);

  void setInputOutputPointers(void* X, void* dX, void* Y, void *dY, void* addend, void* dAddend) {
    X_ = X;
    dX_ = dX;
    Y_ = Y;
    dY_ = dY;
    addend_ = addend;
    dAddend_ = dAddend;
  }

  // Sets the pointers for the scale and weight (in that order) data and derivative buffers.
  void setWeightPointers(const std::vector<void*>& weight_pointers,
                         const std::vector<void*>& deriv_pointers) {
    assert(weight_pointers.size() == 2);
    assert(deriv_pointers.size()  == 2);
    scale_  = static_cast<float*>(weight_pointers[0]);
    bias_   = static_cast<float*>(weight_pointers[1]);
    dscale_ = static_cast<float*>(deriv_pointers[0]);
    dbias_  = static_cast<float*>(deriv_pointers[1]);
  }

  // Sets the pointers for the population mean and variance buffers, in that order.
  void setParameterPointers(const std::vector<void*>& param_pointers) {
    assert(param_pointers.size() == 2);
    population_mean_     = static_cast<float*>(param_pointers[0]);
    population_variance_ = static_cast<float*>(param_pointers[1]);
  }

  void setConstants(const double exp_avg_factor, const double eps) {
    exp_avg_factor_ = exp_avg_factor;
    eps_ = eps;
  }

  void processCudnnStatus(const cudnnStatus_t& status,
                          const std::string& string = std::string(),
                          bool verbose = VERBOSE_DEFAULT) {
    if (status != CUDNN_STATUS_SUCCESS)
      LOG(FATAL) << string << " " << cudnnGetErrorString(status);
    else if (verbose)
      LOG(INFO) << string << " " << cudnnGetErrorString(status);
  }

  void checkCudaStatus(const std::string& string = std::string(),
                       bool verbose = VERBOSE_DEFAULT) {
    cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess)
      LOG(FATAL) << string << " " << cudaGetErrorString(status);
    else if (verbose)
      LOG(INFO) << string << " " << cudaGetErrorString(status);
  }

  size_t size_retired_ctas(int grid_y) const {
    // Note that the value of max_grid_y to handle known GPUs is about 160.
    const int max_grid_y = 1024;
    if (grid_y > max_grid_y)
      LOG(INFO) << "GPU capabilities exceeds assumptions.";
    const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);
    // Since the region will be initialized once and used for many kernels,
    // the idea is to return an ample size that will cover all uses.
    return retired_cta_bytes;
  }

  cudnnTensorDescriptor_t X_tensor_desc_ = nullptr;
  cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr;

  void *X_ = nullptr;
  void *dX_ = nullptr;
  void *Y_ = nullptr;
  void *dY_ = nullptr;
  void *addend_ = nullptr;
  void *dAddend_ = nullptr;

  // Learned scale and bias weights.
  float *scale_ = nullptr;
  float *dscale_ = nullptr;
  float *bias_ = nullptr;
  float *dbias_ = nullptr;

  // Computed population mean and variance parameters.
  float *population_mean_ = nullptr;
  float *population_variance_ = nullptr;

  // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).
  float *minibatch_mean_ = nullptr;
  float *minibatch_variance_ = nullptr;

  int m_ = 0;  // Number of values per channel that BN is normalizing.
  int c_ = 0;  // Number of channels over which BN is normalizing.

  float svar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get saved variance
  float rvar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get running variance

  double exp_avg_factor_ = 0.;
  double eps_ = 0.;
  std::string name_;

 private:
  void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,
                           cudnnTensorFormat_t format,
                           cudnnDataType_t data_type,
                           int n, int c, int h, int w) {
    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
    status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);
    processCudnnStatus(status, "set tensor descriptor");
  }

  void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {
    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
    status = cudnnCreateTensorDescriptor(descriptor);
    processCudnnStatus(status, "create tensor_descriptor");
  }

  void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {
    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
    status = cudnnDestroyTensorDescriptor(descriptor);
    processCudnnStatus(status, "destroy tensor_descriptor");
  }

 protected:
  float *partial_sums_ = nullptr;
  int *partial_counts_ = nullptr;
  int *retired_ctas_ = nullptr;
  unsigned int *relu_bitmask_ = nullptr;

  void _setFwdParams(NhwcBatchNormFwdParams *params) const;
  void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;
  void _setBwdParams(NhwcBatchNormBwdParams *params) const;

  // @todo: ability to configure these?
  // Kernel params
  static const int USE_ONLINE_APPROACH = 1;
  static const int THREADS_PER_CTA = 512;
  static const int THREADS_PER_PIXEL = 16;
  static const int C_ELEMENTS_PER_CTA = 64;
  static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
  static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;

  typedef uint16_t StorageType;
  // increasing this to 6 causes spills in fwd kernel!
  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;
  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;
  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;
  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;

  static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + PIXELS_PER_THREAD_IN_SMEM_FWD;
  static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + PIXELS_PER_THREAD_IN_SMEM_BWD;
  static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;

  // Derived params
  static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD * THREADS_PER_CTA * ELEMENTS_PER_LDG * sizeof(StorageType);
  static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD * THREADS_PER_CTA * ELEMENTS_PER_LDG * 2 * sizeof(StorageType);
  static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
  static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_FWD;
  static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_BWD;
  static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_FWD_INFERENCE;

  // max grid.y in case of group bn is limited by exchange buffer size
  static const int MAX_GBN_BLOCK_Y = 256;

  // Helper function to launch the forward kernel.
  // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel
  // version that was compiled with that occupancy in its launch bounds. This way, we avoid
  // needless register spills.
  void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
                          dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
cudaLaunchKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
    if (outer_loops == 1) {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(1, false, true, 2, coop);
      else
        LAUNCH_FWD_KERNEL(1, false, true, 1, coop);
    } else {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(0, false, true, 2, coop);
      else
        LAUNCH_FWD_KERNEL(0, false, true, 1, coop);
    }
#undef LAUNCH_FWD_KERNEL
  }

  // Helper function to launch the backward kernel.
  void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
                          dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {
#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto bwd_add_relu_func = nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_add_relu_func, \
cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + \
" bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using BWD_ADD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
    if (outer_loops == 1) {
      if (occupancy >= 2)
        LAUNCH_BWD_ADD_RELU_KERNEL(1, 2, coop);
      else
        LAUNCH_BWD_ADD_RELU_KERNEL(1, 1, coop);
    } else {
      if (occupancy >= 2)
        LAUNCH_BWD_ADD_RELU_KERNEL(0, 2, coop);
      else
        LAUNCH_BWD_ADD_RELU_KERNEL(0, 1, coop);
    }
#undef LAUNCH_BWD_KERNEL
  }

 public:
  // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
  static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
    using namespace at::cuda::utils;
    int fwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float);
    int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
    return std::min(max_cta_per_sm, occupancy);
  }

  // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
  static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
    using namespace at::cuda::utils;
    int bwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float);
    int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
    return std::min(max_cta_per_sm, occupancy);
  }
};

const std::vector<size_t> NhwcBatchNormAddRelu::numWorkspaceBytes() const {
  assert(c_ > 0);

  // choose the max memory required between fwd/bwd passes
  int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);
  int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);
  int grid_x = max(grid_x_fwd, grid_x_bwd);
  int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);

  const size_t num_mean_bytes = c_ * sizeof(float);
  const size_t num_variance_bytes = num_mean_bytes;

  int elems_per_group = ((m_ + 31) & ~31) * 2;
  int group_count = div_up(c_, C_ELEMENTS_PER_CTA);
  const size_t bitmask_bytes = elems_per_group * group_count * sizeof(unsigned int);

  const size_t size_sums = grid_y * grid_x * THREADS_PER_PIXEL * ELEMENTS_PER_LDG * 2 * sizeof(float);
  const size_t size_counts = grid_y * grid_x * sizeof(int);

  return {num_mean_bytes, num_variance_bytes, bitmask_bytes,
          size_retired_ctas(grid_y), size_sums, size_counts};
}

void NhwcBatchNormAddRelu::setWorkspacePointers(
      const std::vector<void*>& workspace,
      const std::vector<size_t>& num_workspace_bytes) {
  assert(workspace.size() == 6);
  assert(num_workspace_bytes.size() == 6);

  minibatch_mean_     = static_cast<float*>(workspace[0]);
  minibatch_variance_ = static_cast<float*>(workspace[1]);
  relu_bitmask_       = static_cast<unsigned int*>(workspace[2]);
  retired_ctas_       = static_cast<int*>(workspace[3]);
  partial_sums_       = static_cast<float*>(workspace[4]);
  partial_counts_     = static_cast<int*>(workspace[5]);
}

void NhwcBatchNormAddRelu::_setFwdParams(NhwcBatchNormFwdParams *params) const {
  params->gmem_src          = static_cast<uint16_t*>(X_);
  params->gmem_dst          = static_cast<uint16_t*>(Y_);
  params->gmem_src1         = static_cast<uint16_t*>(addend_);
  params->gmem_bias         = bias_;
  params->gmem_scale        = scale_;
  params->gmem_running_mean = population_mean_;
  params->gmem_running_var  = population_variance_;
  params->gmem_saved_mean   = minibatch_mean_;
  params->gmem_saved_var    = minibatch_variance_;
  params->gmem_relu_bitmask = relu_bitmask_;
  params->nhw               = m_;
  params->c                 = c_;
  params->svar_inv_count    = svar_inv_count_;
  params->rvar_inv_count    = rvar_inv_count_;
  params->gmem_sums         = partial_sums_;
  params->gmem_counts       = partial_counts_;
  params->gmem_retired_ctas = retired_ctas_;
  params->var_eps           = eps_;
  params->outer_loops       = 0;
  params->exp_avg_factor    = static_cast<float>(exp_avg_factor_);
  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);
}

void NhwcBatchNormAddRelu::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const {
  params->gmem_src   = static_cast<uint16_t*>(X_);
  params->gmem_dst   = static_cast<uint16_t*>(Y_);
  params->gmem_src1  = static_cast<uint16_t*>(addend_);
  params->gmem_bias  = bias_;
  params->gmem_scale = scale_;
  params->gmem_mean  = population_mean_;
  params->gmem_var   = population_variance_;
  params->nhw        = m_;
  params->c          = c_;
  params->var_eps    = eps_;
}

void NhwcBatchNormAddRelu::_setBwdParams(NhwcBatchNormBwdParams *params) const {
  params->gmem_src          = static_cast<uint16_t*>(X_);
  params->gmem_dy           = static_cast<uint16_t*>(dY_);
  params->gmem_dst          = static_cast<uint16_t*>(dX_);
  params->gmem_dst1         = static_cast<uint16_t*>(dAddend_);
  params->gmem_relu_bitmask = relu_bitmask_;
  params->gmem_dscale       = dscale_;
  params->gmem_dbias        = dbias_;
  params->gmem_scale        = scale_;
  params->gmem_bias         = bias_;
  params->gmem_saved_mean   = minibatch_mean_;
  params->gmem_saved_var    = minibatch_variance_;
  params->nhw               = m_;
  params->c                 = c_;
  params->svar_inv_count    = svar_inv_count_;
  params->gmem_sums         = partial_sums_;
  params->gmem_retired_ctas = retired_ctas_;
  params->outer_loops       = 0;
  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);
}

void NhwcBatchNormAddRelu::fwdInference(cudaStream_t stream) {
  bool ptrs_are_set =
      X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr
      && scale_ != nullptr && bias_ != nullptr
      // && minibatch_mean_ != nullptr && minibatch_variance_ != nullptr
      && population_mean_ != nullptr && population_variance_ != nullptr
      && X_ != nullptr  // && dX_ != nullptr
      && Y_ != nullptr && addend_ != nullptr
      // && dY_ != nullptr && dscale_ != nullptr && dbias_ != nullptr
      && partial_sums_ != nullptr && partial_counts_ != nullptr;

  if (!ptrs_are_set)
    die();

  dim3 grid_dim;
  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE);
  grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);

  // @todo: maybe just move this inside initialize routine?
  NhwcBatchNormFwdInferenceParams params;
  _setFwdInferenceParams(&params);

  nhwc_batch_norm_fwd_inference
      <StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, true>
      <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
  checkCudaStatus(name_ + " fwd_inference-relu kernel");
}

dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int *loop, const int grid_dim_x) {
  dim3 grid_dim;
  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);
  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
  unsigned int max_grid_x = grid_dim_x;
  if (grid_dim.x <= max_grid_x) {
    *loop = 1;
    if (max_grid_x / grid_dim.x > 1) {
      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
      assert(grid_dim.y < MAX_GBN_BLOCK_Y);  //FIXME: turn into a loop
    } else {
      grid_dim.y = 1;
    }
  } else {
    grid_dim.x = max_grid_x;
    grid_dim.y = 1;
    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD * PIXELS_PER_LDG * grid_dim.x;
    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD * PIXELS_PER_LDG * grid_dim.x;
    *loop = div_up(nhw_in_regs, pixels_per_iteration);
  }
  return grid_dim;
}

dim3 NhwcBatchNormAddRelu::calc_bwd_grid(int *loop, const int grid_dim_x) {
  dim3 grid_dim;
  grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);
  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
  unsigned int max_grid_x = grid_dim_x;
  if (grid_dim.x <= max_grid_x) {
    *loop = 1;
    if (max_grid_x / grid_dim.x > 1) {
      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
      assert(grid_dim.y < MAX_GBN_BLOCK_Y);  //FIXME: turn into a loop
    } else {
      grid_dim.y = 1;
    }
  } else {
    grid_dim.x = max_grid_x;
    grid_dim.y = 1;
    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD * PIXELS_PER_LDG * grid_dim.x;
    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD * PIXELS_PER_LDG * grid_dim.x;
    *loop = div_up(nhw_in_regs, pixels_per_iteration);
  }
  return grid_dim;
}

void NhwcBatchNormAddRelu::fwd(cudaStream_t stream, void* my_data, void* pair_data,
                               void* pair_data2, void* pair_data3,
                               const int bn_group, const int magic, const int occupancy,
                               const int grid_dim_x, const bool coop) {
  bool ptrs_are_set =
      X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr
      && scale_ != nullptr && bias_ != nullptr
      && minibatch_mean_ != nullptr && minibatch_variance_ != nullptr
      && relu_bitmask_ != nullptr
      && population_mean_ != nullptr && population_variance_ != nullptr
      && X_ != nullptr  // && dX_ != nullptr
      && Y_ != nullptr && addend_ != nullptr
      // && dY_ != nullptr && dscale_ != nullptr && dbias_ != nullptr
      && partial_sums_ != nullptr && partial_counts_ != nullptr
      && retired_ctas_ != nullptr;

  if (!ptrs_are_set)
    die();

  // reset of retired_cta_count no longer needed

  NhwcBatchNormFwdParams params;
  _setFwdParams(&params);
  params.my_data = my_data;
  params.pair_datas[0] = pair_data;
  params.pair_datas[1] = pair_data2;
  params.pair_datas[2] = pair_data3;
  params.magic = magic;
  params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1);

  dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);
  _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);
}

void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, void* my_data, void* pair_data,
                                 void* pair_data2, void* pair_data3,
                                 const int bn_group, const int magic, const int occupancy,
                                 const int grid_dim_x, const bool coop) {
  bool ptrs_are_set =
      X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr
      && scale_ != nullptr && bias_ != nullptr
      && minibatch_mean_ != nullptr && minibatch_variance_ != nullptr
      && relu_bitmask_ != nullptr
      // && population_mean_ != nullptr && population_variance_ != nullptr
      && X_ != nullptr && dX_ != nullptr
      // && Y_ != nullptr
      && dY_ != nullptr && dAddend_ != nullptr
      && dscale_ != nullptr && dbias_ != nullptr
      && retired_ctas_ != nullptr;

  if (!ptrs_are_set)
    die();

  // reset of retired_cta_count no longer needed

  NhwcBatchNormBwdParams params;
  _setBwdParams(&params);
  params.my_data = my_data;
  params.pair_datas[0] = pair_data;
  params.pair_datas[1] = pair_data2;
  params.pair_datas[2] = pair_data3;
  params.magic = magic;
  params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1);
  params.wgrad_coeff = 1.0 / bn_group;

  dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);
  _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
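For reference, the ReLU-bitmask workspace computed by numWorkspaceBytes() above scales with the pixel count rounded up to a multiple of 32 and doubled, times the number of 64-channel groups, times 4 bytes per word. A standalone sketch of just that arithmetic, with assumed tensor sizes:

#include <cstdio>

int main() {
  int m = 64 * 14 * 14;                        // N*H*W, example only
  int c = 256;                                 // channels, example only
  int elems_per_group = ((m + 31) & ~31) * 2;  // pixels rounded up to a multiple of 32, doubled
  int group_count = (c + 63) / 64;             // div_up(c, C_ELEMENTS_PER_CTA)
  size_t bitmask_bytes = (size_t)elems_per_group * group_count * sizeof(unsigned int);
  printf("bitmask workspace: %zu bytes\n", bitmask_bytes);  // 401408 for these sizes
  return 0;
}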
apex/contrib/csrc/groupbn/cuda_utils.h  (new file, mode 100644)
#include <ATen/cuda/CUDAContext.h>
#ifndef CUDA_UTILS_H
#define CUDA_UTILS_H
namespace at {
namespace cuda {
namespace utils {

static inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
  return getDeviceProperties(device_id)->sharedMemPerMultiprocessor;
}

}
}
}
#endif
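The helper above is what the smem_driven_*_occupancy() functions divide into: shared memory per SM over the kernel's per-CTA footprint, capped at 2. The sketch below shows the same calculation directly against the CUDA runtime; the 30 KiB footprint is an arbitrary example, not the kernels' real usage.

#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int device_id = 0;
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);
  size_t smem_per_sm = prop.sharedMemPerMultiprocessor;  // what MaxSharedMemoryPerMultiprocessor reads
  size_t kernel_smem_bytes = 30 * 1024;                  // hypothetical per-CTA shared memory footprint
  int max_cta_per_sm = 2;                                // cap used by the occupancy callers
  int occupancy = std::min<int>(max_cta_per_sm, (int)(smem_per_sm / kernel_smem_bytes));
  printf("smem per SM = %zu bytes -> occupancy %d\n", smem_per_sm, occupancy);
  return 0;
}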
apex/contrib/csrc/groupbn/interface.cpp  (new file, mode 100644)
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/ArrayRef.h>
#include <ATen/ScalarType.h>
#include "ATen/Scalar.h"
#ifndef VERSION_GE_1_1
#include "ATen/Type.h"
#endif
#include "ATen/Tensor.h"
#include "ATen/Storage.h"
#include "ATen/Generator.h"
namespace py = pybind11;

int64_t get_buffer_size(const int bn_sync_steps);

void* get_data_ptr(const at::Tensor& data);

void* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset);

void close_remote_data(const at::Tensor& handle);

at::Tensor nhwc_bn_fwd_train(
                       const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias,
                       const at::Tensor& running_mean, const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var,
                       const at::Tensor& ret_cta,
                       const float momentum, const float epsilon, const bool fuse_relu,
                       void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
                       const int bn_group, const at::Tensor& magic_tensor,
                       const int occupancy, const int grid_dim_x, const bool coop);

at::Tensor nhwc_bn_fwd_eval(
                       const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias,
                       const at::Tensor& running_mean, const at::Tensor& running_inv_var,
                       const at::Tensor& ret_cta,
                       const int bn_group, const float momentum, const float epsilon,
                       const bool fuse_relu);

std::vector<at::Tensor> nhwc_bn_bwd(
                       const at::Tensor& x, const at::Tensor& dy,
                       const at::Tensor& scale, const at::Tensor& bias,
                       const at::Tensor& running_mean, const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var,
                       const at::Tensor& ret_cta,
                       const float momentum, const float epsilon, const bool fuse_relu,
                       void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
                       const int bn_group, const at::Tensor& magic_tensor,
                       const int occupancy, const int grid_dim_x, const bool coop);

at::Tensor nhwc_bn_addrelu_fwd_train(
                       const at::Tensor& x, const at::Tensor& z,
                       const at::Tensor& scale, const at::Tensor& bias,
                       const at::Tensor& running_mean, const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var,
                       const at::Tensor& bitmask, const at::Tensor& ret_cta,
                       const float momentum, const float epsilon,
                       void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
                       const int bn_group, const at::Tensor& magic_tensor,
                       const int occupancy, const int grid_dim_x, const bool coop);

at::Tensor nhwc_bn_addrelu_fwd_eval(
                       const at::Tensor& x, const at::Tensor& z,
                       const at::Tensor& scale, const at::Tensor& bias,
                       const at::Tensor& running_mean, const at::Tensor& running_inv_var,
                       const at::Tensor& ret_cta,
                       const int bn_group, const float momentum, const float epsilon);

std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
                       const at::Tensor& x, const at::Tensor& dy,
                       const at::Tensor& scale, const at::Tensor& bias,
                       const at::Tensor& running_mean, const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var,
                       const at::Tensor& bitmask, const at::Tensor& ret_cta,
                       const float momentum, const float epsilon,
                       void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
                       const int bn_group, const at::Tensor& magic_tensor,
                       const int occupancy, const int grid_dim_x, const bool coop);

int nhwc_bn_fwd_occupancy();
int nhwc_bn_bwd_occupancy();
int nhwc_bn_addrelu_fwd_occupancy();
int nhwc_bn_addrelu_bwd_occupancy();

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("get_buffer_size", &get_buffer_size, "get_buffer_size");
  m.def("get_data_ptr", &get_data_ptr, "get_data_ptr");
  m.def("get_remote_data_ptr", &get_remote_data_ptr, "get_remote_data_ptr");
  m.def("close_remote_data", &close_remote_data, "close_remote_data");

  m.def("bn_fwd_nhwc", &nhwc_bn_fwd_train, "bn_fwd_nhwc");
  m.def("bn_fwd_eval_nhwc", &nhwc_bn_fwd_eval, "bn_fwd_eval_nhwc");
  m.def("bn_bwd_nhwc", &nhwc_bn_bwd, "bn_bwd_nhwc");

  m.def("bn_fwd_nhwc_occupancy", &nhwc_bn_fwd_occupancy, "bn_fwd_nhwc_occupancy");
  m.def("bn_bwd_nhwc_occupancy", &nhwc_bn_bwd_occupancy, "bn_bwd_nhwc_occupancy");

  m.def("bn_addrelu_fwd_nhwc", &nhwc_bn_addrelu_fwd_train, "bn_addrelu_fwd_nhwc");
  m.def("bn_addrelu_fwd_eval_nhwc", &nhwc_bn_addrelu_fwd_eval, "bn_addrelu_fwd_eval_nhwc");
  m.def("bn_addrelu_bwd_nhwc", &nhwc_bn_addrelu_bwd, "bn_addrelu_bwd_nhwc");

  m.def("bn_addrelu_fwd_nhwc_occupancy", &nhwc_bn_addrelu_fwd_occupancy, "bn_addrelu_fwd_nhwc_occupancy");
  m.def("bn_addrelu_bwd_nhwc_occupancy", &nhwc_bn_addrelu_bwd_occupancy, "bn_addrelu_bwd_nhwc_occupancy");
}
apex/contrib/csrc/groupbn/ipc.cu  (new file, mode 100644)
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include <cuda.h>
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
template <>
struct std::hash<cudaIpcMemHandle_t> {
  size_t operator() (const cudaIpcMemHandle_t& handle) const {
    size_t hash = 0;
    uint8_t* ptr = (uint8_t*)&handle;
    assert(sizeof(uint8_t) == 1);
    for (int i = 0; i < sizeof(cudaIpcMemHandle_t); i++) {
      hash += *ptr;
      ptr++;
    }
    return hash;
  }
};

template <>
struct std::equal_to<cudaIpcMemHandle_t> {
  bool operator() (const cudaIpcMemHandle_t &lhs,
                   const cudaIpcMemHandle_t &rhs) const {
    return (std::memcmp((void*) &lhs, (void*) &rhs, sizeof(cudaIpcMemHandle_t)) == 0);
  }
};

namespace {

namespace gpuipc {
//from: src/operator/nn/cudnn/nhwc_batch_norm_kernel.h
// The number of threads per pixel.
const int THREADS_PER_PIXEL = 16;
// The number of elements per ldg.
const int ELEMENTS_PER_LDG = 4;
// The number of reducing ops, each uses its own space : mean, var, dscale, dbias
const int REDUCE_OPS = 4;
// Maximum block.y supported - limited due to buffer allocation
const int MAX_BLOCK_Y = 256;
const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;
const int BYTES_PER_ELEM = 4;
// Buffer size per sync step
const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET*THREADS_PER_PIXEL*2*ELEMENTS_PER_LDG*BYTES_PER_ELEM;
};

class IpcMemHandleRegistry {
public:
  void* getPtr(const cudaIpcMemHandle_t& handle, int64_t offset) {
    if (registry_.count(handle) == 0) {
      registry_.insert(std::make_pair(handle, RegistryEntry()));
      registry_[handle].dev_ptr = ipcOpenMem(handle);
    }
    registry_[handle].ref_count++;
    return (((uint8_t*)registry_[handle].dev_ptr) + offset);
  }

  void releasePtr(const cudaIpcMemHandle_t& handle) {
    if (registry_.count(handle) == 0) {
    }
    if (--registry_[handle].ref_count == 0) {
      ipcCloseMem(registry_[handle].dev_ptr);
      registry_.erase(handle);
    }
  }

  struct RegistryEntry {
    void* dev_ptr;
    int   ref_count;
    RegistryEntry() : dev_ptr(NULL), ref_count(0) {}
  };

protected:
  std::unordered_map<cudaIpcMemHandle_t, RegistryEntry> registry_;

  void* ipcOpenMem(const cudaIpcMemHandle_t& handle) {
    void *data;
    cudaIpcOpenMemHandle(&data, handle, cudaIpcMemLazyEnablePeerAccess);
    cudaCheckErrors("ipc init");
    return data;
  }

  void ipcCloseMem(void* dev_ptr) {
    cudaIpcCloseMemHandle(dev_ptr);
    cudaCheckErrors("ipc close");
  }
};

}

static IpcMemHandleRegistry ipc_mem_registry;

int64_t get_buffer_size(const int bn_sync_steps) {
  return bn_sync_steps * gpuipc::SINGLE_SYNC_BUFFER_BYTES;
}

void* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset) {
  cudaIpcMemHandle_t my_handle;
  memcpy((unsigned char *)(&my_handle), handle.data<uint8_t>(), sizeof(my_handle));
  return ipc_mem_registry.getPtr(my_handle, offset);
}

void close_remote_data(const at::Tensor& handle) {
  cudaIpcMemHandle_t my_handle;
  memcpy((unsigned char *)(&my_handle), handle.data<uint8_t>(), sizeof(my_handle));
  ipc_mem_registry.releasePtr(my_handle);
}

void* get_data_ptr(const at::Tensor& data) {
  return data.data<uint8_t>();
}
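A minimal sketch of how these helpers could be exercised from Python to size and exchange the inter-GPU sync buffer. The module name `bnp` and the hand-rolled handle exchange are assumptions for illustration; the real orchestration lives in apex/contrib/groupbn/batch_norm.py:

import torch
import bnp  # assumed name of the compiled extension

sync_steps = 2
nbytes = bnp.get_buffer_size(sync_steps)          # SINGLE_SYNC_BUFFER_BYTES per sync step
buf = torch.zeros(nbytes, dtype=torch.uint8, device="cuda")
buf_ptr = bnp.get_data_ptr(buf)                   # raw device pointer handed to the kernels

# A peer process would receive this buffer's cudaIpcMemHandle_t serialized into a
# uint8 tensor, map it with bnp.get_remote_data_ptr(handle_tensor, 0), and later
# release the mapping with bnp.close_remote_data(handle_tensor).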
apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h
0 → 100644
View file @
4d6ed501
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm_kernel.h
* \brief CUDA NHWC Batch Normalization code
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#include <stdint.h>
#include <algorithm>
#define DEVICE_FUNCTION static inline __device__
// CTA margin used by cooperative launch. Can be overridden by env var NHWC_BATCHNORM_LAUNCH_MARGIN.
#define NHWC_BATCHNORM_LAUNCH_MARGIN_MIN 3
#define NHWC_BATCHNORM_LAUNCH_MARGIN_DEFAULT NHWC_BATCHNORM_LAUNCH_MARGIN_MIN
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, int ELEMENTS_PER_LDG>
struct PackedStorage {
  enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG };
  typedef T Type;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int ELEMENTS_PER_LDG>
struct PackedStorage<uint16_t, ELEMENTS_PER_LDG> {
  enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG/2 };
  typedef int Type;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void from_float(int (&dst)[N], const float (&src)[2*N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    uint16_t lo, hi;
    asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(lo) : "f"(src[2*i+0]));
    asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(hi) : "f"(src[2*i+1]));
    asm volatile("mov.b32 %0, {%1, %2};" : "=r"(dst[i]) : "h"(lo), "h"(hi));
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void from_float(float (&dst)[N], const float (&src)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    dst[i] = src[i];
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void to_float(float (&dst)[2*N], int (&src)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    uint16_t lo, hi;
    asm volatile("mov.b32 {%0, %1}, %2;" : "=h"(lo), "=h"(hi) : "r"(src[i]));
    asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+0]) : "h"(lo));
    asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+1]) : "h"(hi));
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void to_float(float (&dst)[N], float (&src)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    dst[i] = src[i];
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void ldg(int (&dst)[1], const uint16_t *gmem) {
  dst[0] = __ldg((const int*) gmem);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void ldg_stream(int (&dst)[1], const uint16_t *gmem) {
  unsigned int tmp;
  asm volatile ("ld.global.cs.nc.s32 %0, [%1];" : "=r"(tmp) : "l"((const uint *)gmem));
  dst[0] = tmp;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void ldg(int (&dst)[2], const uint16_t *gmem) {
  int2 tmp = __ldg((const int2*) gmem);
  dst[0] = tmp.x;
  dst[1] = tmp.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void ldg_stream(int (&dst)[2], const uint16_t *gmem) {
  int2 tmp;
  asm volatile ("ld.global.cs.nc.v2.s32 {%0,%1}, [%2];"
                : "=r"(tmp.x), "=r"(tmp.y) : "l"((const int2 *)gmem));
  dst[0] = tmp.x;
  dst[1] = tmp.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void ldg(float (&dst)[N], const uint16_t *gmem) {
  int tmp[N/2];
  ldg(tmp, gmem);
  to_float(dst, tmp);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void ldg_stream(float (&dst)[N], const uint16_t *gmem) {
  int tmp[N/2];
  ldg_stream(tmp, gmem);
  to_float(dst, tmp);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[1]) {
  reinterpret_cast<int*>(gmem)[0] = src[0];
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[1]) {
  unsigned int tmp = src[0];
  asm volatile ("st.global.cs.s32 [%0], %1;" :: "l"((uint *)gmem), "r"(tmp));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[2]) {
  reinterpret_cast<int2*>(gmem)[0] = make_int2(src[0], src[1]);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[2]) {
  asm volatile ("st.global.cs.v2.s32 [%0], {%1,%2};"
                :: "l"((uint *)gmem), "r"(src[0]), "r"(src[1]));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void stg(uint16_t *gmem, float (&src)[N]) {
  int tmp[N/2];
  from_float(tmp, src);
  stg(gmem, tmp);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void stg_stream(uint16_t *gmem, float (&src)[N]) {
  int tmp[N/2];
  from_float(tmp, src);
  stg_stream(gmem, tmp);
}
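The from_float/to_float and ldg/stg helpers above pack two fp16 values into each 32-bit register so a 4-wide channel vector travels as two ints. A small NumPy sketch of the same packing, purely illustrative (NumPy and little-endian layout are assumptions; the kernels do this with the cvt/mov PTX shown above):

import numpy as np

pair = np.array([1.5, -2.25], dtype=np.float32)
packed = pair.astype(np.float16).view(np.uint32)[0]     # two halves in one 32-bit word
lo = np.uint16(packed & 0xFFFF)
hi = np.uint16(packed >> 16)
unpacked = np.array([lo, hi], dtype=np.uint16).view(np.float16).astype(np.float32)
print(hex(int(packed)), unpacked)                       # round-trips to [ 1.5  -2.25]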
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_gmem(float (&dst)[2], const float *gmem, int idx) {
  float2 tmp = __ldg(reinterpret_cast<const float2*>(&gmem[2*idx]));
  dst[0] = tmp.x;
  dst[1] = tmp.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void read_from_gmem(float (&dst)[4], const float *gmem, int idx) {
  float4 tmp = __ldg(reinterpret_cast<const float4*>(&gmem[4*idx]));
  dst[0] = tmp.x;
  dst[1] = tmp.y;
  dst[2] = tmp.z;
  dst[3] = tmp.w;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void read_from_smem(float (&x)[2], const float *smem, int idx) {
  float2 tmp = *(const float2*) &smem[2*idx];
  x[0] = tmp.x;
  x[1] = tmp.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void read_from_smem(int (&x)[1], const int *smem, int idx) {
  x[0] = smem[idx];
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void read_from_smem(float (&x)[4], const float *smem, int idx) {
  float4 tmp = *(const float4*) &smem[4*idx];
  x[0] = tmp.x;
  x[1] = tmp.y;
  x[2] = tmp.z;
  x[3] = tmp.w;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void read_from_smem(int (&x)[2], const int *smem, int idx) {
  int2 tmp = *(const int2*) &smem[2*idx];
  x[0] = tmp.x;
  x[1] = tmp.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[2]) {
  reinterpret_cast<float2*>(&gmem[2*idx])[0] = make_float2(src[0], src[1]);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[4]) {
  reinterpret_cast<float4*>(&gmem[4*idx])[0] = make_float4(src[0], src[1], src[2], src[3]);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void scaled_write_to_gmem(float *gmem, int idx, const float (&src)[4], const float coeff) {
  reinterpret_cast<float4*>(&gmem[4*idx])[0] =
      make_float4(src[0]*coeff, src[1]*coeff, src[2]*coeff, src[3]*coeff);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[2]) {
  reinterpret_cast<float2*>(&smem[2*idx])[0] = make_float2(x[0], x[1]);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[1]) {
  smem[idx] = x[0];
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[4]) {
  reinterpret_cast<float4*>(&smem[4*idx])[0] = make_float4(x[0], x[1], x[2], x[3]);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[2]) {
  reinterpret_cast<int2*>(&smem[2*idx])[0] = make_int2(x[0], x[1]);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void zero_array(int (&dst)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    dst[i] = 0;
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void zero_array(float (&dst)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    dst[i] = 0.f;
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void add(float (&x)[N], const float (&y)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    x[i] += y[i];
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void multiply(float (&x)[N], const float (&y)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    x[i] *= y[i];
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void scale_(float (&x)[N], float scalar) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    x[i] *= scalar;
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void normalize(float (&x)[N], const float (&bias)[N],
                               const float (&scale)[N], const float (&m1)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    x[i] = bias[i] + scale[i] * (x[i] - m1[i]);
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Storage>
DEVICE_FUNCTION Storage relu(Storage in) {
  Storage zero = (Storage)0.f;
  return (in < zero) ? zero : in;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void relu_activation(float (&x)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    x[i] = relu(x[i]);
  }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int THREADS_PER_CTA>
DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
                                        void* params_my_data, void** params_pair_datas, int off,
                                        const int magic, const int sync_iters) {
  // The size of a warp.
  const int THREADS_PER_WARP = 32;
  // The number of warps in a CTA.
  const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
  // The number of threads per pixel.
  const int THREADS_PER_PIXEL = 16;
  // The number of elements per ldg.
  const int ELEMENTS_PER_LDG = 4;
  // The number of reducing ops, each uses its own space : mean, var, dscale, dbias
  const int REDUCE_OPS = 4;
  // Maximum block.y supported - limited due to buffer allocation
  const int MAX_BLOCK_Y = 256;
  const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;
  // The warp decomposition.
  const int warp_id = threadIdx.x / THREADS_PER_WARP;
  const int lane_id = threadIdx.x % THREADS_PER_WARP;
  // total size of data per sync iter
  const int data_total = MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG*2;

#pragma unroll
  for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
    x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
  }
  // The warp leaders, write to SMEM.
  if (lane_id < THREADS_PER_PIXEL) {
    write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);
  }
  // The data is in SMEM. Do the final reduction.
  __syncthreads();

  // The 1st warp does all the work.
  // We do the final reduction each half-warp sequentially reduces the final values.
  if (warp_id == 0) {
    read_from_smem(x, smem, threadIdx.x);
#pragma unroll
    for (int offset = 1; offset < WARPS_PER_CTA/(THREADS_PER_WARP/THREADS_PER_PIXEL); ++offset) {
      float y[ELEMENTS_PER_LDG];
      // Read the mean and variance from the other pixel.
      read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
      // Compute the updated sum.
      add(x, y);
    }
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
    }
    // Make sure the data was read from SMEM.
    __syncwarp();

    // Store the final values.
    if (threadIdx.x < THREADS_PER_PIXEL) {
      // probably could do it earlier, before sync
      for (int sync_iter = 0; sync_iter < sync_iters; ++sync_iter) {
        //float* params_pair_data = (reinterpret_cast<float**>(params_pair_datas))[sync_iter];
        void* params_pair_data = params_pair_datas[sync_iter];
        // skip the space consumed by previous sync iterations
        const int xbuf_offset = sync_iter*data_total;
        // data starts after flags, but have to skip previous
        const int data_offset = xbuf_offset
                                + off*ELEMENTS_PER_LDG*THREADS_PER_PIXEL*2
                                + ELEMENTS_PER_LDG*threadIdx.x*2;

        // after sums for this GPU were computed, let CTA0 broadcast the sum to over GPU
        if (blockIdx.x == 0) {
          volatile float* write_data = &((reinterpret_cast<float*>(params_pair_data))[data_offset]);
          // write the data to memory region to be reflected to other GPU
          asm volatile ("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};"
                        :: "l"(write_data), "f"(x[0]), "r"(magic), "f"(x[2]), "r"(magic));
          asm volatile ("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};"
                        :: "l"(write_data+4), "f"(x[1]), "r"(magic), "f"(x[3]), "r"(magic));
        }

        // now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU
        volatile float* read_data = &((reinterpret_cast<float*>(params_my_data))[data_offset]);
        float other[4];
        uint32_t other_flag_a, other_flag_b;
        do {
          asm volatile ("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
                        : "=f"(other[0]), "=r"(other_flag_a), "=f"(other[2]), "=r"(other_flag_b)
                        : "l"(read_data));
        } while ((other_flag_a != magic) || (other_flag_b != magic));
        do {
          asm volatile ("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
                        : "=f"(other[1]), "=r"(other_flag_a), "=f"(other[3]), "=r"(other_flag_b)
                        : "l"(read_data+4));
        } while ((other_flag_a != magic) || (other_flag_b != magic));

        add(x, other);
      }
      // finally, after syncing up and accounting for partial sums from
      // other GPUs as required, write the result
      write_to_smem(smem, threadIdx.x, x);
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int THREADS_PER_CTA>
DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {
  // The size of a warp.
  const int THREADS_PER_WARP = 32;
  // The number of warps in a CTA.
  const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
  // The number of threads per pixel.
  const int THREADS_PER_PIXEL = 8;
  // The number of elements per ldg.
  const int ELEMENTS_PER_LDG = 4;
  // The warp decomposition.
  const int warp_id = threadIdx.x / THREADS_PER_WARP;
  const int lane_id = threadIdx.x % THREADS_PER_WARP;

#pragma unroll
  for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
    x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
    x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id);
  }
  // The warp leaders, write to SMEM.
  if (lane_id < THREADS_PER_PIXEL) {
    write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);
  }
  // The data is in SMEM. Do the final reduction.
  __syncthreads();

  // The 1st warp does all the work.
  // We do the final reduction each half-warp sequentially reduces the final values.
  if (warp_id == 0) {
    read_from_smem(x, smem, threadIdx.x);
#pragma unroll
    for (int offset = 1; offset < WARPS_PER_CTA/(THREADS_PER_WARP/THREADS_PER_PIXEL); ++offset) {
      float y[ELEMENTS_PER_LDG];
      // Read the mean and variance from the other pixel.
      read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
      // Compute the updated sum.
      add(x, y);
    }
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
      x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id);
    }
    // Make sure the data was read from SMEM.
    __syncwarp();
    // Store the final values.
    if (threadIdx.x < THREADS_PER_PIXEL) {
      write_to_smem(smem, threadIdx.x, x);
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG>
DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
  // The size of a warp.
  const int THREADS_PER_WARP = 32;
  // The number of warps in a CTA.
  const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
  // The number of pixels computed by a single warp.
  const int PIXELS_PER_WARP = THREADS_PER_WARP / THREADS_PER_PIXEL;
  // The position in the warp.
  const int nhw_in_warp = nhw % PIXELS_PER_WARP;
  // The C in the warp.
  const int c_in_warp = threadIdx.x % THREADS_PER_PIXEL;

  // Store the values to shared memory.
  write_to_smem(smem, threadIdx.x, x);

  // Compute the parallel sums.
  for (int offset = PIXELS_PER_WARP/2; offset > 0; offset /= 2) {
    // NOP.
    __syncwarp();
    // Read the running sum from the other thread.
    float y[ELEMENTS_PER_LDG];
    if (nhw_in_warp < offset) {
      read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL);
    }
    // Compute the updated sum.
    add(x, y);
    // NOP.
    __syncwarp();
    // Update the sum in SMEM.
    if (offset > 1 && nhw_in_warp < offset) {
      write_to_smem(smem, threadIdx.x, x);
    }
  }

  // The warps are done. Do the final reduction at the CTA level.
  __syncthreads();

  // The warp leaders, write to SMEM.
  const int idx = (threadIdx.x / THREADS_PER_WARP) * THREADS_PER_PIXEL + c_in_warp;
  if (nhw_in_warp == 0) {
    write_to_smem(smem, idx, x);
  }

  // The data is in SMEM. Do the final reduction.
  __syncthreads();

  // Read the 1st element to prepare the work.
  if (nhw < WARPS_PER_CTA/2) {
    read_from_smem(x, smem, threadIdx.x);
  }

  // We have the running mean and running m2. Let's build the mean/var of the CTA.
  for (int offset = WARPS_PER_CTA/2; offset > 0; offset /= 2) {
    // NOP.
    __syncwarp();
    // Read the mean and variance from the other pixel.
    float y[ELEMENTS_PER_LDG];
    if (nhw < offset) {
      read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL);
    }
    // Compute the updated sum.
    add(x, y);
    // NOP.
    __syncwarp();
    // Store the mean/var for the different pixels.
    if (nhw < offset) {
      write_to_smem(smem, threadIdx.x, x);
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG>
struct ParallelSums {
  template <int THREADS_PER_CTA>
  DEVICE_FUNCTION void dispatch(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
    parallel_sums<THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG>(smem, x, nhw);
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct ParallelSums<16, 4> {
  template <int THREADS_PER_CTA>
  DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {
    parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, 0, 0, 0, 0, 0);
  }

  template <int THREADS_PER_CTA>
  DEVICE_FUNCTION void dispatchX(float *smem, float (&x)[4], int nhw,
                                 void* params_my_data, void** params_pair_datas, int off,
                                 const int magic, const unsigned int& sync_iters) {
    parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, params_my_data, params_pair_datas,
                                        off, magic, sync_iters);
  }
};

template <>
struct ParallelSums<8, 4> {
  template <int THREADS_PER_CTA>
  DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {
    parallel_sums_8x4<THREADS_PER_CTA>(smem, x, nhw);
  }
};
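ParallelSums reduces each thread's 4-vector first within a warp via __shfl_sync, then across warps through shared memory by halving the active stride each step. A toy NumPy sketch of that stride-halving shape only (it does not model warps or shared memory; purely an illustration):

import numpy as np

vals = np.arange(32, dtype=np.float32)        # one partial sum per "thread"
offset = len(vals) // 2
while offset > 0:
    vals[:offset] += vals[offset:2 * offset]  # fold the upper half onto the lower half
    offset //= 2
print(vals[0], np.arange(32).sum())           # both 496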
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline int div_up(int m, int n) {
  return (m + n - 1) / n;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// It is expected that all threads in the CTA enter this function!
DEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expected_count, bool master) {
  // Register the CTA.
  if (threadIdx.x == 0) {
    // Issue the membar.
    __threadfence();
    // Notify that the CTA is done.
    int val_to_add = 1;
    if (master) {
      val_to_add = -(expected_count - 1);
    }
    atomicAdd(gmem_retired_ctas, val_to_add);
  }

  // Are all CTAs done?
  if (threadIdx.x == 0) {
    int retired_ctas = -1;
    do {
      __threadfence();
      asm volatile ("ld.global.cg.b32 %0, [%1];"
                    : "=r"(retired_ctas) : "l"(gmem_retired_ctas));
    } while (retired_ctas != 0);
  }
  __syncthreads();
}

////////////////////////////////////////////////////////////////////////////////////////////////////

struct NhwcBatchNormFwdInferenceParams {
  // The input/output tensors.
  uint16_t *gmem_src, *gmem_dst, *gmem_src1;
  // the final mean and variance as calculated during the training process
  float *gmem_mean, *gmem_var;
  // The bias/scale.
  float *gmem_bias, *gmem_scale;
  // The dimensions.
  int nhw, c;
  // epsilon
  float var_eps;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// No DESIRED_OCCUPANCY launch bounds needed, as this is not launched cooperatively
template <typename Storage,
          int THREADS_PER_CTA,
          int THREADS_PER_PIXEL,
          int ELEMENTS_PER_LDG,
          bool USE_RELU,
          bool USE_ADD_RELU>
__global__ __launch_bounds__(THREADS_PER_CTA)
void nhwc_batch_norm_fwd_inference(NhwcBatchNormFwdInferenceParams params) {
  // The number of pixels loaded in a single LDG.
  const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
  // The number of C elements per CTA.
  const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;

  // The start position in the NHW dimension where the CTA starts.
  const int cta_nhw_stride = gridDim.x * PIXELS_PER_LDG;
  // Compute the NHW coordinate of the thread in the CTA.
  const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
  // thread's starting point in NHW
  const int thread_nhw = thread_in_cta_nhw + blockIdx.x * PIXELS_PER_LDG;

  // The position in the C dimension where the CTA starts.
  const int cta_c = blockIdx.y * C_ELEMENTS_PER_CTA;
  // Compute the C coordinate of the thread in the CTA.
  const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
  // Compute the C coordinate of the thread.
  const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;

  // Is the thread working on a valid C dimension?
  const int is_valid_c = thread_c < params.c;

  float mean[ELEMENTS_PER_LDG], var[ELEMENTS_PER_LDG];
  float scale[ELEMENTS_PER_LDG], bias[ELEMENTS_PER_LDG];
  zero_array(mean);
  zero_array(var);
  zero_array(scale);
  zero_array(bias);
  if (is_valid_c) {
    read_from_gmem(var, &params.gmem_var[cta_c], thread_in_cta_c);
    read_from_gmem(scale, &params.gmem_scale[cta_c], thread_in_cta_c);
    read_from_gmem(mean, &params.gmem_mean[cta_c], thread_in_cta_c);
    read_from_gmem(bias, &params.gmem_bias[cta_c], thread_in_cta_c);
  }

  // Update the scale with the stddev and eps.
#pragma unroll
  for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
    scale[i] *= rsqrtf(var[i] + params.var_eps);
  }

  // The base pointers for reading/writing
  uint16_t *const gmem_src = &params.gmem_src[thread_c];
  uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
  const uint16_t *gmem_src1 = nullptr;
  if (USE_ADD_RELU) {
    gmem_src1 = &params.gmem_src1[thread_c];
  }

  // apply BN
  for (int nhw = thread_nhw; nhw < params.nhw; nhw += cta_nhw_stride) {
    float x_math[ELEMENTS_PER_LDG];
    zero_array(x_math);
    if (is_valid_c) {
      ldg(x_math, &gmem_src[nhw*params.c]);
    }

    // Normalize and apply activation function
    normalize(x_math, bias, scale, mean);
    if (USE_ADD_RELU) {
      float x1_math[ELEMENTS_PER_LDG];
      ldg(x1_math, &gmem_src1[nhw*params.c]);
      add(x_math, x1_math);
      relu_activation(x_math);
    } else if (USE_RELU) {
      relu_activation(x_math);
    }

    if (is_valid_c) {
      stg(&gmem_dst[nhw*params.c], x_math);
    }
  }
}
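For reference, the per-channel arithmetic of the inference kernel above is the textbook batch-norm evaluation with an optional residual add and ReLU. A NumPy restatement (illustrative only; layouts and dtypes are simplified and the function name is made up here):

import numpy as np

def bn_fwd_inference_ref(x, z, mean, var, scale, bias, eps, add_relu=True):
    # x, z: [NHW, C]; mean, var, scale, bias: [C]
    y = bias + (scale / np.sqrt(var + eps)) * (x - mean)
    if add_relu:
        y = np.maximum(y + z, 0.0)   # fused residual add + ReLU branch
    return y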
////////////////////////////////////////////////////////////////////////////////////////////////////
struct NhwcBatchNormFwdParams {
  // The input/output tensors.
  uint16_t *gmem_src, *gmem_dst, *gmem_src1;
  // The bias/scale.
  float *gmem_bias, *gmem_scale;
  // running mean/var (refer BN API from cudnn doc)
  float *gmem_running_mean, *gmem_running_var;
  // saved mean/var (refer BN API from cudnn doc)
  float *gmem_saved_mean, *gmem_saved_var;
  // ReLU bitmask
  unsigned int *gmem_relu_bitmask;
  // The dimensions.
  int nhw, c;
  // factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
  float svar_inv_count;
  // factor to scale sum of squared errors to get running variance. Should be 1/nhw or 1/(nhw-1).
  float rvar_inv_count;
  // The buffer to do the reduction for mean, stddev and count.
  float *gmem_sums;
  // The buffer to count items in the different CTAs.
  int *gmem_counts;
  // The counters of retired CTAs.
  int *gmem_retired_ctas;
  // The epsilon to apply to the computation of the variance.
  float var_eps;
  // outer loop count
  int outer_loops;
  // exponential average factor
  float exp_avg_factor;
  // number of CTAs along .x dimension
  int c_blks;

  void* my_data;
  void* pair_datas[4];
  int magic;
  int sync_iters;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Storage,
          int THREADS_PER_CTA,
          int THREADS_PER_PIXEL,
          int PIXELS_PER_THREAD_IN_REGISTERS,
          int PIXELS_PER_THREAD_IN_SMEM,
          int ELEMENTS_PER_LDG,
          int USE_ONLINE_APPROACH,
          int OUTER_LOOPS_,
          bool USE_RELU,
          bool USE_ADD_RELU,
          int DESIRED_OCCUPANCY>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
void nhwc_batch_norm_fwd(NhwcBatchNormFwdParams params) {
  // The number of pixels loaded in a single LDG.
  const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
  // The number of pixels computed per CTA stored in registers.
  const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
  // The number of pixels computed per CTA stored in SMEM.
  const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;
  // The number of C elements per CTA.
  const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;

  // Shared memory to do CTA-wide parallel sums.
  __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];

  // Compute the NHW coordinate of the thread in the CTA.
  const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;

  // The adapter for the storage.
  typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
  // The data type for packed storage in SMEM.
  typedef typename PackedStorage_::Type PackedStorageType;
  // The number of elements in the packed storage.
  const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
  // Registers to keep the data live for the persistent approach.
  PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];

  // Shared memory buffer to store the extra pixels.
  extern __shared__ PackedStorageType smem_storage_packed[];

  for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
    // The position in the NHW dimension where the CTA starts.
    int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
    // The position in the NHW dimension where the CTA starts for the portion in SMEM.
    int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
    // The position in the C dimension where the CTA starts.
    const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
    // Compute the C coordinate of the thread in the CTA.
    const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
    // Compute the C coordinate of the thread.
    int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
    // Is the thread working on a valid C dimension?
    const int is_valid_c = thread_c < params.c;
    // Clamp thread_c so that we load from valid locations even if we don't use the value
    if (!is_valid_c)
      thread_c = params.c - 4;

    // Single pass numerically stable algorithm, see:
    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
    //
    // n = 0, mean = 0.0, M2 = 0.0
    //
    // for x in data:
    //     n += 1
    //     delta = x - mean
    //     mean += delta/n
    //     delta2 = x - mean
    //     M2 += delta*delta2
    //
    // if n < 2:
    //     return float('nan')
    // else:
    //     return M2 / (n - 1)

    // Register to store the number of elements read so far.
    float count = 0.f, mean[ELEMENTS_PER_LDG], m2[ELEMENTS_PER_LDG];
#pragma unroll
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      mean[i] = 0.f;
      m2[i] = 0.f;
    }

    // The number of elements loaded by this CTA.
    int cta_count = 0;
    // The base pointer to load from.
    const uint16_t *gmem_src = &params.gmem_src[thread_c];

    // outer loops
    int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops;
    // Load the batch of elements. Compute the mean/var across those elements.
    const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;

    if (OUTER_LOOPS_ != 1) {
      // We cannot load everything to store persistently, so let's makes sure registers and
      // smem are fully utilized, offset is evenly divisible by 32
      int offset = (pixels_per_iteration * OUTER_LOOPS
                    + PIXELS_PER_CTA_IN_SMEM * gridDim.x - params.nhw) & ~31;
      cta_nhw_regs -= offset;
      cta_nhw_smem -= offset;
    }

#pragma unroll 1
    for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
      // The nhw position.
      int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;
      // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
      cta_count += max(min(nhw_regs + PIXELS_PER_CTA_IN_REGISTERS, params.nhw) - max(nhw_regs, 0), 0);

      // Load the data and compute the local mean/sum and the variance.
      if (USE_ONLINE_APPROACH) {
        // Read the elements from memory.
        float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
        for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
          const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
          zero_array(x_storage[i]);
          is_valid[i] = 0.f;
          if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
            if (loop_i == OUTER_LOOPS - 1) {
              ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
            } else {
              ldg(x_storage[i], &gmem_src[idx*params.c]);
            }
            is_valid[i] = 1.f;
          }
        }

        // Do the math.
#pragma unroll
        for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
          // Convert to float.
          float x_math[ELEMENTS_PER_LDG];
          to_float(x_math, x_storage[i]);

          // Update the count.
          count += is_valid[i];
          // Invert the count.
          float inv_count = is_valid[i] ? 1.f / count : 0.f;

          // Update the mean and m2 using deltas.
#pragma unroll
          for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
            float delta0 = x_math[j] - mean[j];
            mean[j] += delta0 * inv_count;
            float delta1 = x_math[j] - mean[j];
            m2[j] += delta0 * delta1 * is_valid[i];
          }
        }
      } else {
        // Read the elements from memory.
#pragma unroll
        for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
          const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
          zero_array(x_storage[i]);
          if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
            if (loop_i == OUTER_LOOPS - 1) {
              ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
            } else {
              ldg(x_storage[i], &gmem_src[idx*params.c]);
            }
            count += 1.f;
          }
        }

        // Sum the elements in registers.
#pragma unroll
        for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
          // Convert to float.
          float x_math[ELEMENTS_PER_LDG];
          to_float(x_math, x_storage[i]);

          // Update the mean and m2 using deltas.
#pragma unroll
          for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
            mean[j] += x_math[j];
          }
        }

        // Compute the mean.
        float inv_count = 1.f / count;
#pragma unroll
        for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
          mean[j] *= inv_count;
        }

        // Compute the variance.
#pragma unroll
        for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
          // Convert to float.
          float x_math[ELEMENTS_PER_LDG];
          to_float(x_math, x_storage[i]);

          // Is it a valid pixel?
          float is_valid = i < static_cast<int>(count) ? 1.f : 0.f;
          // Update the mean and m2 using deltas.
#pragma unroll
          for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
            m2[j] += (x_math[j] - mean[j]) * (x_math[j] - mean[j]) * is_valid;
          }
        }
      }
    }

    // The elements to load and store in SMEM.
    int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;
    // Load elements from SMEM, update the CTA count.
    int pixels_in_smem = min(smem_nhw + PIXELS_PER_CTA_IN_SMEM, params.nhw) - max(smem_nhw, 0);
    if (pixels_in_smem > 0) {
      cta_count += pixels_in_smem;
      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
        const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
        float is_pixel_valid =
            (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) ? 1.f : 0.f;

        PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG];
        ldg_stream(x_storage_local, &gmem_src[(is_pixel_valid ? idx : 0)*params.c]);

        // The offset to store in SMEM.
        const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
        // Store in SMEM.
        write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
        // Update the count.
        count += is_pixel_valid;
        // Invert the count.
        float inv_count = is_pixel_valid ? 1.f / count : 0.f;

        float x_math[ELEMENTS_PER_LDG];
        to_float(x_math, x_storage_local);
        // Update the mean and m2 using deltas.
#pragma unroll
        for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
          float delta0 = x_math[j] - mean[j];
          mean[j] += delta0 * inv_count;
          float delta1 = x_math[j] - mean[j];
          m2[j] += delta0 * delta1 * is_pixel_valid;
        }
      }
    }

    // We scale the mean by the number of elements. It brings more stability.
    float m1[ELEMENTS_PER_LDG];
#pragma unroll
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      m1[i] = mean[i] * count;
    }

    // Run the parallel sum accross the CTA to get the local sum.
    ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
      smem, m1, thread_in_cta_nhw);
    __syncthreads();

    // The values in shared memory correspond to the CTA-wide sums.
    read_from_smem(m1, smem, thread_in_cta_c);
    __syncthreads();

    // Adjust the variance.
    float inv_cta_count = 1.f / static_cast<float>(cta_count);
#pragma unroll
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      float mean_diff = m1[i]*inv_cta_count - mean[i];
      m2[i] = m2[i] + mean_diff * mean_diff * count;
    }

    // Run the parallel sum accross the CTA to get the local adjusted variance.
    ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
      smem, m2, thread_in_cta_nhw);

    // The workspace in global memory is distributed across the different CTA.
    int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;

    // Write the data for the CTA to global memory.
    float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
    if (threadIdx.x < THREADS_PER_PIXEL) {
      const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;
      write_to_gmem(&gmem_sums[0], idx, m1);
      write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, m2);
    }

    // The memory location to store the number of pixels per CTA.
    int *gmem_counts = &params.gmem_counts[c_blk_index*gridDim.x];
    if (threadIdx.x == 0) {
      gmem_counts[blockIdx.x] = cta_count;
    }

    // Read the bias and scale.
    float bias[ELEMENTS_PER_LDG], scale[ELEMENTS_PER_LDG];
    if (is_valid_c) {
      read_from_gmem(bias, &params.gmem_bias[cta_c], thread_in_cta_c);
      read_from_gmem(scale, &params.gmem_scale[cta_c], thread_in_cta_c);
    }

    // The counters to count how many CTAs have retired at this point.
    // A given cta uses the same counter every other time through the outer loop.
    int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
    inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);

    // Reset the mean to compute the global mean.
#pragma unroll
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      m1[i] = 0.f;
    }

    // Build the global mean.
#pragma unroll 1
    for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
      float tmp[ELEMENTS_PER_LDG];
      read_from_gmem(tmp, gmem_sums, idx);
      add(m1, tmp);
    }

    if (params.sync_iters > 0) {
      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
        smem, m1, thread_in_cta_nhw,
        params.my_data, params.pair_datas, 4*c_blk_index+3, params.magic, params.sync_iters);
    } else {
      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
        smem, m1, thread_in_cta_nhw);
    }
    __syncthreads();

    // The values in shared memory correspond to the CTA-wide sums.
    read_from_smem(m1, smem, thread_in_cta_c);
    __syncthreads();

    // Normalize the mean.
#pragma unroll
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      m1[i] = m1[i] * params.svar_inv_count;
    }

    // Reset the variance.
#pragma unroll
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      m2[i] = 0.f;
    }

    // for add+relu fusion
    const uint16_t *gmem_src1 = nullptr;
    if (USE_ADD_RELU) {
      gmem_src1 = &params.gmem_src1[thread_c];
    }

    // Build the global variance.
#pragma unroll 1
    for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
      // Read the means computed by different CTAs (again). Reuse tmp if we have 1 iteration.
      float tmp_mean[ELEMENTS_PER_LDG], tmp_var[ELEMENTS_PER_LDG];
      read_from_gmem(tmp_mean, &gmem_sums[0], idx);
      read_from_gmem(tmp_var, &gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx);

      // Read the number of pixels visited by a given CTA.
      cta_count = __ldg(&gmem_counts[idx / THREADS_PER_PIXEL]);

      // Compute the diff to update the variance.
      float mean_diff[ELEMENTS_PER_LDG], inv_cta_count = 1.f / static_cast<float>(cta_count);
#pragma unroll
      for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
        mean_diff[i] = m1[i] - tmp_mean[i]*inv_cta_count;
      }

      // Update the variance.
#pragma unroll
      for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
        m2[i] += tmp_var[i] + mean_diff[i]*mean_diff[i]*static_cast<float>(cta_count);
      }
    }

    if (params.sync_iters > 0) {
      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
        smem, m2, thread_in_cta_nhw,
        params.my_data, params.pair_datas, 4*c_blk_index+2, params.magic, params.sync_iters);
    } else {
      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
        smem, m2, thread_in_cta_nhw);
    }
    __syncthreads();

    read_from_smem(m2, smem, thread_in_cta_c);

    // Finalize the stddev.
    // becasue saved var and running var may have different denominator, we don't do it here
    // scale_(m2, inv_count);

    // store the saved mean/var
    float svarinv[ELEMENTS_PER_LDG];
    bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
#pragma unroll
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      svarinv[i] = rsqrtf(m2[i]*params.svar_inv_count + params.var_eps);
    }
    if (is_valid_for_saving) {
      write_to_gmem(params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG, m1);
      write_to_gmem(params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG, svarinv);
    }

    // store the running mean/var
    float rmean[ELEMENTS_PER_LDG], rvar[ELEMENTS_PER_LDG];
    zero_array(rmean);
    zero_array(rvar);
    if (params.exp_avg_factor != 1.f && is_valid_for_saving) {
      read_from_gmem(rmean, params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG);
      read_from_gmem(rvar, params.gmem_running_var, thread_c/ELEMENTS_PER_LDG);
    }
#pragma unroll
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      rmean[i] = (1.f - params.exp_avg_factor) * rmean[i] +
                 params.exp_avg_factor * m1[i];
      rvar[i] = (1.f - params.exp_avg_factor) * rvar[i] +
                params.exp_avg_factor * (m2[i] * params.rvar_inv_count);
    }
    if (is_valid_for_saving) {
      write_to_gmem(params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG, rmean);
      write_to_gmem(params.gmem_running_var, thread_c/ELEMENTS_PER_LDG, rvar);
    }

    // Update the scale with the stddev and eps.
    multiply(scale, svarinv);

    // The base pointer to write to.
    uint16_t *const gmem_dst = &params.gmem_dst[thread_c];

    unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask +
        ((params.nhw + 31) & ~31) * 2 * c_blk_index;

    // Store the elements in registers.
#pragma unroll 1
    for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) {
      // The value for nhw.
      int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;

      // Normalize the elements and write to memory.
#pragma unroll
      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
        const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
        const bool is_valid_nhw =
            static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
        const bool is_valid = is_valid_nhw && is_valid_c;
        // Convert to float.
        float x_math[ELEMENTS_PER_LDG];
        to_float(x_math, x_storage[i]);

        // Normalize and apply activation function
        normalize(x_math, bias, scale, m1);
        if (USE_ADD_RELU) {
          float x1_math[ELEMENTS_PER_LDG];
          ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]);
          add(x_math, x1_math);
          unsigned int relu_mask;
          int lane_id = threadIdx.x & 31;
#pragma unroll
          for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
            bool rectified = x_math[i] < 0.0F;
            unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified);
            if (lane_id == i) {
              // Thread 0 remembers the relu_mask from the first time through this
              // loop, Thread 1 the next, Thread 2 the next, and Thread 3 the last.
              relu_mask = local_relu_mask;
            }
            if (rectified) {
              x_math[i] = 0.0F;
            }
          }
          if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {
            gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask;
          }
        } else if (USE_RELU) {
          relu_activation(x_math);
        }

        // Write back.
        if (is_valid) {
          stg_stream(&gmem_dst[idx*params.c], x_math);
        }
      }

      // The next value of nhw.
      out_nhw -= pixels_per_iteration;

      // Read the next elements from memory.
#pragma unroll
      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
        const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
          ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
        }
      }
    }

    // Normalize the elements from SMEM and write them out.
    if (pixels_in_smem > 0) {
#pragma unroll 2
      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
        const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
        const bool is_valid_nhw =
            static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
        const bool is_valid = is_valid_nhw && is_valid_c;

        // Read from SMEM.
        const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
        PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG];
        read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
        float x_math[ELEMENTS_PER_LDG];
        to_float(x_math, x_storage_local);

        // Normalize and apply activation function
        normalize(x_math, bias, scale, m1);
        if (USE_ADD_RELU) {
          float x1_math[ELEMENTS_PER_LDG];
          ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]);
          add(x_math, x1_math);
          unsigned int relu_mask;
          int lane_id = threadIdx.x & 31;
#pragma unroll
          for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
            bool rectified = x_math[i] < 0.0F;
            unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified);
            if (lane_id == i) {
              relu_mask = local_relu_mask;
            }
            if (rectified) {
              x_math[i] = 0.0F;
            }
          }
          if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {
            gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask;
          }
        } else if (USE_RELU) {
          relu_activation(x_math);
        }

        // Write back.
        if (is_valid) {
          stg_stream(&gmem_dst[idx*params.c], x_math);
        }
      }
    }
    // We're about to start on the next c-blk. Needed?
    __syncthreads();
  }
}
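The forward kernel accumulates mean and M2 with the online (Welford) update quoted in its comment, then merges per-CTA partials using each CTA's element count. A compact Python sketch of those two stages (illustrative; the kernel actually exchanges count-scaled sums through gmem_sums and per-CTA counts through gmem_counts rather than explicit (count, mean, M2) triples):

import numpy as np

def welford(chunk):
    count, mean, m2 = 0, 0.0, 0.0
    for x in chunk:
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)
    return count, mean, m2

def merge(a, b):
    # Chan et al. combination of two partial (count, mean, M2) results.
    (na, ma, m2a), (nb, mb, m2b) = a, b
    n = na + nb
    d = mb - ma
    return n, ma + d * nb / n, m2a + m2b + d * d * na * nb / n

data = np.random.randn(1000)
n, mean, m2 = merge(welford(data[:400]), welford(data[400:]))
print(mean, data.mean())          # means agree
print(m2 / n, data.var())         # biased variance agrees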
////////////////////////////////////////////////////////////////////////////////////////////////////
struct NhwcBatchNormBwdParams {
  // The input/output tensors.
  uint16_t *gmem_src, *gmem_dy, *gmem_dst, *gmem_dst1;
  // dscale/dbias
  float *gmem_dscale, *gmem_dbias;
  // The scale and bias.
  float *gmem_scale, *gmem_bias;
  // The mean/inv-var saved from fwd pass
  float *gmem_saved_mean, *gmem_saved_var;
  // ReLU bitmask
  unsigned int *gmem_relu_bitmask;
  // The dimensions.
  int nhw, c;
  // factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
  float svar_inv_count;
  // The buffer to do the reduction for dscale and dbias
  float *gmem_sums;
  // The counters of retired CTAs.
  int *gmem_retired_ctas;
  // outer loop count
  int outer_loops;
  // number of CTAs along .x dimension
  int c_blks;

  void* my_data;
  void* pair_datas[4];
  int magic;
  int sync_iters;
  float wgrad_coeff;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&x)[N],
                              const float (&mean_var_scale_bias)[N],
                              const float (&var_scale)[N], bool valid_data) {
#pragma unroll
  for (int j = 0; j < N; ++j) {
    float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];
    if ((y <= 0.f) && valid_data) {
      dy[j] = 0.f;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&y)[N], bool valid_data) {
#pragma unroll
  for (int j = 0; j < N; ++j) {
    if ((y[j] <= 0.f) && valid_data) {
      dy[j] = 0.f;
    }
  }
}

template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const bool (&rectified)[N], bool valid_data) {
#pragma unroll
  for (int j = 0; j < N; ++j) {
    if (rectified[j] && valid_data) {
      dy[j] = 0.f;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&x)[N],
                                     const float (&mean_var_scale_bias)[N],
                                     const float (&var_scale)[N]) {
#pragma unroll
  for (int j = 0; j < N; ++j) {
    float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];
    if (y <= 0.f) {
      dy[j] = 0.f;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&y)[N]) {
#pragma unroll
  for (int j = 0; j < N; ++j) {
    if (y[j] <= 0.f) {
      dy[j] = 0.f;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void bwd_update(float (&dscale)[N], float (&dbias)[N],
                                const float (&dy)[N], const float (&x)[N],
                                const float (&mean)[N], float inv_count) {
#pragma unroll
  for (int j = 0; j < N; ++j) {
    float delta0 = dy[j] - dbias[j];
    dbias[j] += delta0 * inv_count;
    delta0 = (dy[j] * (x[j] - mean[j])) - dscale[j];
    dscale[j] += delta0 * inv_count;
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void bwd_dx(float (&dx)[N], const float (&dy)[N],
                            const float (&var)[N], const float (&x)[N], const float (&mean)[N],
                            const float (&dscale)[N], const float (&dbias)[N], float inv_count) {
#pragma unroll
  for (int j = 0; j < N; ++j) {
    float tmp1 = dy[j] - (dbias[j] * inv_count);
    float tmp2 = dscale[j] * inv_count;
    float tmp3 = x[j] - mean[j];
    dx[j] = var[j] * (tmp1 - (tmp2 * tmp3));
  }
}
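For orientation, dbias accumulates dy, dscale accumulates dy·(x − mean), and bwd_dx subtracts the mean gradient plus the projection onto (x − mean) before rescaling. The textbook closed form that this builds up to, restated in Python (a reference sketch, not a line-for-line transcription of the kernel's exact scaling of dscale and var):

import numpy as np

def bn_bwd_ref(dy, x, mean, inv_std, scale):
    # dy, x: [NHW, C]; mean, inv_std, scale: [C]
    n = x.shape[0]
    x_hat = (x - mean) * inv_std
    dbias = dy.sum(axis=0)
    dscale = (dy * x_hat).sum(axis=0)
    dx = (scale * inv_std) * (dy - dbias / n - x_hat * dscale / n)
    return dx, dscale, dbias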
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
  typename Storage,
  int THREADS_PER_CTA,
  int THREADS_PER_PIXEL,
  int PIXELS_PER_THREAD_IN_REGISTERS,
  int PIXELS_PER_THREAD_IN_SMEM,
  int ELEMENTS_PER_LDG,
  int USE_ONLINE_APPROACH,
  int OUTER_LOOPS_,
  int DESIRED_OCCUPANCY
>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
void nhwc_batch_norm_bwd(NhwcBatchNormBwdParams params) {
  // The number of pixels loaded in a single LDG.
  const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
  // The number of pixels computed per CTA stored in registers.
  const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
  // The number of pixels computed per CTA stored in SMEM.
  const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM * PIXELS_PER_LDG;
  // The number of C elements per CTA.
  const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG;

  // Shared memory to do CTA-wide parallel sums.
  __shared__ float smem[THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG];

  // The adapter for the storage.
  typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
  // The data type for packed storage in SMEM.
  typedef typename PackedStorage_::Type PackedStorageType;
  // The number of elements in the packed storage.
  const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
  // Registers to keep the data live for the persistent approach.
  PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
  PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
  // Shared memory buffer to store the extra pixels.
  extern __shared__ PackedStorageType smem_storage_packed[];

  for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
    // The position in the NHW dimension where the CTA starts.
    int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
    // The position in the NHW dimension where the CTA starts for the portion in SMEM.
    int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
    // Compute the NHW coordinate of the thread in the CTA.
    const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
    // The position in the C dimension where the CTA starts.
    const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
    // Compute the C coordinate of the thread in the CTA.
    const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
    // Compute the C coordinate of the thread.
    const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG;
    // Is the thread working on a valid C dimension?
    const int is_valid_c = thread_c < params.c;

    // Registers to store the mean used for the entire duration
    float mean[ELEMENTS_PER_LDG];
    zero_array(mean);
    if (is_valid_c) {
      read_from_gmem(mean, params.gmem_saved_mean, thread_c / ELEMENTS_PER_LDG);
    }

    // accumulation related registers
    float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];
    zero_array(dscale);
    zero_array(dbias);

    // The number of elements loaded by this CTA.
    int cta_count = 0;
    // The base pointers to load from.
    const uint16_t *gmem_src = &params.gmem_src[thread_c];
    const uint16_t *gmem_dy = &params.gmem_dy[thread_c];

    // outer loops
    int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops;
    // Load the batch of elements. Compute sum across them
    const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS * gridDim.x;

    if (OUTER_LOOPS_ != 1) {
      // We cannot load everything to store persistently, so let's make sure registers and
      // smem are fully utilized
      int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS - PIXELS_PER_CTA_IN_SMEM * gridDim.x;
      cta_nhw_regs += offset;
      cta_nhw_smem += offset;
    }

    #pragma unroll 1
    for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
      // The nhw position.
      int nhw_regs = cta_nhw_regs + loop_i * pixels_per_iteration;
      // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
      cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw - nhw_regs));

      // Read the elements from memory.
      float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
      #pragma unroll
      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
        const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG;
        zero_array(x_storage[i]);
        zero_array(dy_storage[i]);
        is_valid[i] = 0.f;
        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
          if (loop_i == OUTER_LOOPS - 1) {
            ldg_stream(x_storage[i], &gmem_src[idx * params.c]);
            ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]);
          } else {
            ldg(x_storage[i], &gmem_src[idx * params.c]);
            ldg(dy_storage[i], &gmem_dy[idx * params.c]);
          }
          is_valid[i] = 1.f;
        }
      }

      // Do the math.
      #pragma unroll
      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
        // Convert to float and update
        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
        to_float(x_math, x_storage[i]);
        to_float(dy_math, dy_storage[i]);
        // Update the count.
        count += is_valid[i];
        // Invert the count.
        float inv_count = is_valid[i] ? 1.f / count : 0.f;
        bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
      }
    }

    // The elements to load and store in SMEM.
    int smem_nhw = OUTER_LOOPS * pixels_per_iteration + cta_nhw_smem;
    // Load elements from SMEM, update the CTA count.
    int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw - smem_nhw);
    if (pixels_in_smem > 0) {
      cta_count += pixels_in_smem;
      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
        const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
        bool is_pixel_valid = (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c);
        PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG];
        zero_array(x_storage_local);
        zero_array(dy_storage_local);
        if (is_pixel_valid) {
          ldg_stream(x_storage_local, &gmem_src[idx * params.c]);
          ldg_stream(dy_storage_local, &gmem_dy[idx * params.c]);
        }
        // The offset to store in SMEM.
        int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
        // Store in SMEM.
        write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
        offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
        write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);
        // Update the count.
        count += is_pixel_valid;
        // Invert the count.
        float inv_count = is_pixel_valid ? 1.f / count : 0.f;
        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
        to_float(x_math, x_storage_local);
        to_float(dy_math, dy_storage_local);
        bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
      }
    }

    // We scale the mean by the number of elements. It brings more stability.
    #pragma unroll
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      dbias[i] *= count;
      dscale[i] *= count;
    }

    // dscale parallel sum
    ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
      smem, dscale, thread_in_cta_nhw);
    __syncthreads();
    // The values in shared memory correspond to the CTA-wide sums.
    read_from_smem(dscale, smem, thread_in_cta_c);
    __syncthreads();
    // dbias parallel sum
    ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
      smem, dbias, thread_in_cta_nhw);
    __syncthreads();
    // The values in shared memory correspond to the CTA-wide sums.
    read_from_smem(dbias, smem, thread_in_cta_c);
    __syncthreads();

    // The workspace in global memory is distributed across the different CTA.
    int gmem_sums_offset = c_blk_index * gridDim.x * C_ELEMENTS_PER_CTA * 2;
    // Write the data for the CTA to global memory.
    float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
    if (threadIdx.x < THREADS_PER_PIXEL) {
      const int idx = blockIdx.x * THREADS_PER_PIXEL + threadIdx.x;
      write_to_gmem(&gmem_sums[0], idx, dscale);
      write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx, dbias);
    }

    // The counters to count how many CTAs have retired at this point.
    // A given cta uses the same counter every other time through the outer loop.
    int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
    inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);

    // Reset the accumulators for global summation
    zero_array(dscale);
    zero_array(dbias);

    // Build the global accumulation
    #pragma unroll 1
    for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) {
      float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];
      read_from_gmem(tmp1, gmem_sums, idx);
      read_from_gmem(tmp2, gmem_sums + C_ELEMENTS_PER_CTA * gridDim.x, idx);
      #pragma unroll
      for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
        dscale[i] += tmp1[i];
        dbias[i] += tmp2[i];
      }
    }

    // dscale parallel sum
    if (params.sync_iters > 0) {
      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
        smem, dscale, thread_in_cta_nhw,
        params.my_data, params.pair_datas, 4 * c_blk_index + 1, params.magic, params.sync_iters);
    } else {
      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
        smem, dscale, thread_in_cta_nhw);
    }
    __syncthreads();
    // The values in shared memory correspond to the CTA-wide sums.
    read_from_smem(dscale, smem, thread_in_cta_c);
    __syncthreads();
    // dbias parallel sum
    if (params.sync_iters > 0) {
      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
        smem, dbias, thread_in_cta_nhw,
        params.my_data, params.pair_datas, 4 * c_blk_index + 0, params.magic, params.sync_iters);
    } else {
      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
        smem, dbias, thread_in_cta_nhw);
    }
    __syncthreads();
    // The values in shared memory correspond to the CTA-wide sums.
    read_from_smem(dbias, smem, thread_in_cta_c);

    // inv-var
    float var[ELEMENTS_PER_LDG];
    zero_array(var);
    if (is_valid_c) {
      read_from_gmem(var, params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG);
    }
    // Normalize the dscale.
    multiply(dscale, var);

    // store dscale/dbias
    bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
    if (is_valid_for_saving) {
      if (params.sync_iters > 0) {
        scaled_write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);
        scaled_write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);
      } else {
        write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale);
        write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias);
      }
    }

    // scale
    float scale[ELEMENTS_PER_LDG];
    zero_array(scale);
    if (is_valid_c) {
      read_from_gmem(scale, params.gmem_scale, thread_c / ELEMENTS_PER_LDG);
    }
    // Further normalize the dscale to be used in dx calculation
    multiply(dscale, var);
    // scale the inv-var as well, afterwards
    multiply(var, scale);

    // inverse count
    float inv_count = params.svar_inv_count;

    // The base pointer to write to.
    uint16_t *const gmem_dst = &params.gmem_dst[thread_c];

    // Store the elements in registers.
    #pragma unroll 1
    for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) {
      // The value for nhw.
      int out_nhw = cta_nhw_regs + loop_i * pixels_per_iteration;

      // Normalize the elements and write to memory.
      #pragma unroll
      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
        // Convert to float.
        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
        to_float(x_math, x_storage[i]);
        to_float(dy_math, dy_storage[i]);
        float dx[ELEMENTS_PER_LDG];
        bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
        // Write back.
        const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
          stg_stream(&gmem_dst[idx * params.c], dx);
        }
      }

      // The next value of nhw.
      out_nhw -= pixels_per_iteration;

      // Read the next elements from memory.
      #pragma unroll
      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
        const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
          ldg_stream(x_storage[i], &gmem_src[idx * params.c]);
          ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]);
        }
      }
    }

    // Normalize the elements from SMEM and write them out.
    if (pixels_in_smem > 0) {
      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
        const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
        const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
        if (is_valid) {
          // Read from SMEM.
          int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
          PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG];
          read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
          offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
          read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);
          float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
          to_float(x_math, x_storage_local);
          to_float(dy_math, dy_storage_local);
          float dx[ELEMENTS_PER_LDG];
          bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
          // Write back.
          stg_stream(&gmem_dst[idx * params.c], dx);
        }
      }
    }
    // We're about to start on the next c-blk. Needed?
    __syncthreads();
  }
}
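The dscale/dbias reduction above is a two-stage, grid-wide sum: each CTA reduces its slice in shared memory, publishes a partial result to a global-memory workspace, and `inter_block_sync` gates the final accumulation. Below is a minimal standalone sketch of the same pattern; all names are hypothetical, it assumes a power-of-two block size and a zero-initialized retirement counter, and it is not code from this header.

// grid_sum.cu -- illustrative only.
#include <cuda_runtime.h>

// Retirement counter, analogous in spirit to gmem_retired_ctas above.
// Starts at zero; atomicInc below wraps it back to zero so it can be reused.
__device__ unsigned int retired_ctas = 0;

// Sums `n` floats across the whole grid. Assumes blockDim.x is a power of two
// and `block_sums` has one slot per block.
__global__ void grid_sum(const float* in, float* block_sums, float* out, int n) {
  extern __shared__ float smem[];

  // 1) Per-thread strided accumulation over the input.
  float v = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x)
    v += in[i];

  // 2) CTA-wide tree reduction in shared memory.
  smem[threadIdx.x] = v;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
    __syncthreads();
  }

  // 3) Publish the CTA partial sum; the last CTA to retire folds them together.
  __shared__ bool is_last;
  if (threadIdx.x == 0) {
    block_sums[blockIdx.x] = smem[0];
    __threadfence();  // make the partial sum visible to the whole grid
    unsigned int prev = atomicInc(&retired_ctas, gridDim.x);
    is_last = (prev == gridDim.x - 1);
  }
  __syncthreads();

  if (is_last) {
    float total = 0.f;
    for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x)
      total += block_sums[i];
    smem[threadIdx.x] = total;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
      if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
      __syncthreads();
    }
    if (threadIdx.x == 0) *out = smem[0];
  }
}

A host would launch it as `grid_sum<<<blocks, threads, threads * sizeof(float)>>>(d_in, d_partials, d_out, n)` with `d_partials` sized to `blocks`.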
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Storage
,
int
THREADS_PER_CTA
,
int
THREADS_PER_PIXEL
,
int
PIXELS_PER_THREAD_IN_REGISTERS
,
int
PIXELS_PER_THREAD_IN_SMEM
,
int
ELEMENTS_PER_LDG
,
int
USE_ONLINE_APPROACH
,
int
OUTER_LOOPS_
,
int
DESIRED_OCCUPANCY
>
__global__
__launch_bounds__
(
THREADS_PER_CTA
,
DESIRED_OCCUPANCY
)
void
nhwc_batch_norm_bwd_relu
(
NhwcBatchNormBwdParams
params
)
{
// The number of pixels loaded in a single LDG.
const
int
PIXELS_PER_LDG
=
THREADS_PER_CTA
/
THREADS_PER_PIXEL
;
// The number of pixels computed per CTA stored in registers.
const
int
PIXELS_PER_CTA_IN_REGISTERS
=
PIXELS_PER_THREAD_IN_REGISTERS
*
PIXELS_PER_LDG
;
// The number of pixels computed per CTA stored in SMEM.
const
int
PIXELS_PER_CTA_IN_SMEM
=
PIXELS_PER_THREAD_IN_SMEM
*
PIXELS_PER_LDG
;
// The number of C elements per CTA.
const
int
C_ELEMENTS_PER_CTA
=
THREADS_PER_PIXEL
*
ELEMENTS_PER_LDG
;
// Shared memory to do CTA-wide parallel sums.
__shared__
float
smem
[
THREADS_PER_PIXEL
*
(
THREADS_PER_CTA
/
32
)
*
ELEMENTS_PER_LDG
];
// The adapter for the storage.
typedef
PackedStorage
<
Storage
,
ELEMENTS_PER_LDG
>
PackedStorage_
;
// The data type for packed storage in SMEM.
typedef
typename
PackedStorage_
::
Type
PackedStorageType
;
// The number of elements in the packed storage.
const
int
PACKED_ELEMENTS_PER_LDG
=
PackedStorage_
::
PACKED_ELEMENTS_PER_LDG
;
// Registers to keep the data live for the persistent approach.
PackedStorageType
x_storage
[
PIXELS_PER_THREAD_IN_REGISTERS
][
PACKED_ELEMENTS_PER_LDG
];
PackedStorageType
dy_storage
[
PIXELS_PER_THREAD_IN_REGISTERS
][
PACKED_ELEMENTS_PER_LDG
];
// Shared memory buffer to store the extra pixels.
extern
__shared__
PackedStorageType
smem_storage_packed
[];
for
(
int
c_blk_index
=
blockIdx
.
y
;
c_blk_index
<
params
.
c_blks
;
c_blk_index
+=
gridDim
.
y
)
{
// The position in the NHW dimension where the CTA starts.
int
cta_nhw_regs
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_REGISTERS
;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int
cta_nhw_smem
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_SMEM
;
// Compute the NHW coordinate of the thread in the CTA.
const
int
thread_in_cta_nhw
=
threadIdx
.
x
/
THREADS_PER_PIXEL
;
// The position in the C dimension where the CTA starts.
const
int
cta_c
=
c_blk_index
*
C_ELEMENTS_PER_CTA
;
// Compute the C coordinate of the thread in the CTA.
const
int
thread_in_cta_c
=
threadIdx
.
x
%
THREADS_PER_PIXEL
;
// Compute the C coordinate of the thread.
const
int
thread_c
=
cta_c
+
thread_in_cta_c
*
ELEMENTS_PER_LDG
;
// Is the thread working on a valid C dimension?
const
int
is_valid_c
=
thread_c
<
params
.
c
;
// Registers to store the mean/var/scale/bias used for the entire duration
// Register usage optimizations:
// 1. Can combine bias - (mean * var * scale) into a single register
// 2. Can combine var * scale into a single register
float
varscale
[
ELEMENTS_PER_LDG
];
zero_array
(
varscale
);
if
(
is_valid_c
)
{
read_from_gmem
(
varscale
,
params
.
gmem_saved_var
,
thread_c
/
ELEMENTS_PER_LDG
);
}
float
tmp
[
ELEMENTS_PER_LDG
];
zero_array
(
tmp
);
if
(
is_valid_c
)
{
read_from_gmem
(
tmp
,
params
.
gmem_scale
,
thread_c
/
ELEMENTS_PER_LDG
);
}
multiply
(
varscale
,
tmp
);
float
mean
[
ELEMENTS_PER_LDG
];
zero_array
(
mean
);
if
(
is_valid_c
)
{
read_from_gmem
(
mean
,
params
.
gmem_saved_mean
,
thread_c
/
ELEMENTS_PER_LDG
);
}
zero_array
(
tmp
);
if
(
is_valid_c
)
{
read_from_gmem
(
tmp
,
params
.
gmem_bias
,
thread_c
/
ELEMENTS_PER_LDG
);
}
float
mean_var_scale_bias
[
ELEMENTS_PER_LDG
];
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
mean_var_scale_bias
[
i
]
=
tmp
[
i
]
-
(
mean
[
i
]
*
varscale
[
i
]);
}
// accumulation related registers
float
count
=
0.
f
,
dscale
[
ELEMENTS_PER_LDG
],
dbias
[
ELEMENTS_PER_LDG
];
zero_array
(
dscale
);
zero_array
(
dbias
);
// The number of elements loaded by this CTA.
int
cta_count
=
0
;
// The base pointers to load from.
const
uint16_t
*
gmem_src
=
&
params
.
gmem_src
[
thread_c
];
const
uint16_t
*
gmem_dy
=
&
params
.
gmem_dy
[
thread_c
];
// outer loops
int
OUTER_LOOPS
=
OUTER_LOOPS_
==
1
?
1
:
params
.
outer_loops
;
// Load the batch of elements. Compute sum across them
const
int
pixels_per_iteration
=
PIXELS_PER_CTA_IN_REGISTERS
*
gridDim
.
x
;
if
(
OUTER_LOOPS_
!=
1
)
{
// We cannot load everything to store persistently, so let's makes sure registers and
// smem are fully utilized
int
offset
=
params
.
nhw
-
pixels_per_iteration
*
OUTER_LOOPS
-
PIXELS_PER_CTA_IN_SMEM
*
gridDim
.
x
;
cta_nhw_regs
+=
offset
;
cta_nhw_smem
+=
offset
;
}
#pragma unroll 1
for
(
int
loop_i
=
0
;
loop_i
<
OUTER_LOOPS
;
++
loop_i
)
{
// The nhw position.
int
nhw_regs
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count
+=
max
(
0
,
min
(
PIXELS_PER_CTA_IN_REGISTERS
,
params
.
nhw
-
nhw_regs
));
// Read the elements from memory.
float
is_valid
[
PIXELS_PER_THREAD_IN_REGISTERS
];
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
nhw_regs
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
zero_array
(
x_storage
[
i
]);
zero_array
(
dy_storage
[
i
]);
is_valid
[
i
]
=
0.
f
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
if
(
loop_i
==
OUTER_LOOPS
-
1
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
else
{
ldg
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
is_valid
[
i
]
=
1.
f
;
}
}
// Do the math.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
// Convert to float and update
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
to_float
(
dy_math
,
dy_storage
[
i
]);
// Update the count.
count
+=
is_valid
[
i
];
// Invert the count.
float
inv_count
=
is_valid
[
i
]
?
1.
f
/
count
:
0.
f
;
relu_bwd
(
dy_math
,
x_math
,
mean_var_scale_bias
,
varscale
,
is_valid
[
i
]);
bwd_update
(
dscale
,
dbias
,
dy_math
,
x_math
,
mean
,
inv_count
);
}
}
// The elements to load and store in SMEM.
int
smem_nhw
=
OUTER_LOOPS
*
pixels_per_iteration
+
cta_nhw_smem
;
// Load elements from SMEM, update the CTA count.
int
pixels_in_smem
=
min
(
PIXELS_PER_CTA_IN_SMEM
,
params
.
nhw
-
smem_nhw
);
if
(
pixels_in_smem
>
0
)
{
cta_count
+=
pixels_in_smem
;
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
bool
is_pixel_valid
=
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
);
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
],
dy_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
zero_array
(
x_storage_local
);
zero_array
(
dy_storage_local
);
if
(
is_pixel_valid
)
{
ldg_stream
(
x_storage_local
,
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage_local
,
&
gmem_dy
[
idx
*
params
.
c
]);
}
// The offset to store in SMEM.
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
// Store in SMEM.
write_to_smem
(
&
smem_storage_packed
[
offset
],
threadIdx
.
x
,
x_storage_local
);
offset
+=
PIXELS_PER_THREAD_IN_SMEM
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
write_to_smem
(
&
smem_storage_packed
[
offset
],
threadIdx
.
x
,
dy_storage_local
);
// Update the count.
count
+=
is_pixel_valid
;
// Invert the count.
float
inv_count
=
is_pixel_valid
?
1.
f
/
count
:
0.
f
;
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
to_float
(
dy_math
,
dy_storage_local
);
relu_bwd
(
dy_math
,
x_math
,
mean_var_scale_bias
,
varscale
,
is_pixel_valid
);
bwd_update
(
dscale
,
dbias
,
dy_math
,
x_math
,
mean
,
inv_count
);
}
}
// We scale the mean by the number of elements. It brings more stability.
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
dbias
[
i
]
*=
count
;
dscale
[
i
]
*=
count
;
}
// dscale parallel sum
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
);
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dscale
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// dbias parallel sum
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
);
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dbias
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// The workspace in global memory is distributed across the different CTA.
int
gmem_sums_offset
=
c_blk_index
*
gridDim
.
x
*
C_ELEMENTS_PER_CTA
*
2
;
// Write the data for the CTA to global memory.
float
*
gmem_sums
=
&
params
.
gmem_sums
[
gmem_sums_offset
];
if
(
threadIdx
.
x
<
THREADS_PER_PIXEL
)
{
const
int
idx
=
blockIdx
.
x
*
THREADS_PER_PIXEL
+
threadIdx
.
x
;
write_to_gmem
(
&
gmem_sums
[
0
],
idx
,
dscale
);
write_to_gmem
(
&
gmem_sums
[
C_ELEMENTS_PER_CTA
*
gridDim
.
x
],
idx
,
dbias
);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int
*
gmem_retired_ctas
=
&
params
.
gmem_retired_ctas
[
c_blk_index
%
(
2
*
gridDim
.
y
)];
inter_block_sync
(
gmem_retired_ctas
,
gridDim
.
x
,
blockIdx
.
x
==
0
);
// Reset the accumulators for global summation
zero_array
(
dscale
);
zero_array
(
dbias
);
// Build the global accumulation
#pragma unroll 1
for
(
int
idx
=
threadIdx
.
x
;
idx
<
THREADS_PER_PIXEL
*
gridDim
.
x
;
idx
+=
THREADS_PER_CTA
)
{
float
tmp1
[
ELEMENTS_PER_LDG
],
tmp2
[
ELEMENTS_PER_LDG
];
read_from_gmem
(
tmp1
,
gmem_sums
,
idx
);
read_from_gmem
(
tmp2
,
gmem_sums
+
C_ELEMENTS_PER_CTA
*
gridDim
.
x
,
idx
);
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
dscale
[
i
]
+=
tmp1
[
i
];
dbias
[
i
]
+=
tmp2
[
i
];
}
}
// dscale parallel sum
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
1
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
);
}
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dscale
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// dbias parallel sum
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
0
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
);
}
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dbias
,
smem
,
thread_in_cta_c
);
// Normalize the dscale.
float
var
[
ELEMENTS_PER_LDG
];
zero_array
(
var
);
if
(
is_valid_c
)
{
read_from_gmem
(
var
,
params
.
gmem_saved_var
,
thread_c
/
ELEMENTS_PER_LDG
);
}
multiply
(
dscale
,
var
);
// store dscale/dbias
bool
is_valid_for_saving
=
is_valid_c
&&
blockIdx
.
x
==
0
&&
thread_in_cta_nhw
==
0
;
if
(
is_valid_for_saving
)
{
if
(
params
.
sync_iters
>
0
)
{
scaled_write_to_gmem
(
params
.
gmem_dscale
,
thread_c
/
ELEMENTS_PER_LDG
,
dscale
,
params
.
wgrad_coeff
);
scaled_write_to_gmem
(
params
.
gmem_dbias
,
thread_c
/
ELEMENTS_PER_LDG
,
dbias
,
params
.
wgrad_coeff
);
}
else
{
write_to_gmem
(
params
.
gmem_dscale
,
thread_c
/
ELEMENTS_PER_LDG
,
dscale
);
write_to_gmem
(
params
.
gmem_dbias
,
thread_c
/
ELEMENTS_PER_LDG
,
dbias
);
}
}
// Further normalize the dscale to be used in dx calculation
float
scale
[
ELEMENTS_PER_LDG
];
zero_array
(
scale
);
if
(
is_valid_c
)
{
read_from_gmem
(
scale
,
params
.
gmem_scale
,
thread_c
/
ELEMENTS_PER_LDG
);
}
multiply
(
dscale
,
var
);
// scale the inv-var as well, afterwards
multiply
(
var
,
scale
);
// inverse count
float
inv_count
=
params
.
svar_inv_count
;
// The base pointer to write to.
uint16_t
*
const
gmem_dst
=
&
params
.
gmem_dst
[
thread_c
];
// Store the elements in registers.
#pragma unroll 1
for
(
int
loop_i
=
OUTER_LOOPS
-
1
;
loop_i
>=
0
;
--
loop_i
)
{
// The value for nhw.
int
out_nhw
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Normalize the elements and write to memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
// Convert to float.
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
to_float
(
dy_math
,
dy_storage
[
i
]);
relu_bwd_for_dx
(
dy_math
,
x_math
,
mean_var_scale_bias
,
var
);
float
dx
[
ELEMENTS_PER_LDG
];
bwd_dx
(
dx
,
dy_math
,
var
,
x_math
,
mean
,
dscale
,
dbias
,
inv_count
);
// Write back.
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
dx
);
}
}
// The next value of nhw.
out_nhw
-=
pixels_per_iteration
;
// Read the next elements from memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
}
}
// Normalize the elements from SMEM and write them out.
if
(
pixels_in_smem
>
0
)
{
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
const
bool
is_valid
=
((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
;
if
(
is_valid
)
{
// Read from SMEM.
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
],
dy_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
read_from_smem
(
x_storage_local
,
&
smem_storage_packed
[
offset
],
threadIdx
.
x
);
offset
+=
PIXELS_PER_THREAD_IN_SMEM
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
read_from_smem
(
dy_storage_local
,
&
smem_storage_packed
[
offset
],
threadIdx
.
x
);
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
to_float
(
dy_math
,
dy_storage_local
);
relu_bwd_for_dx
(
dy_math
,
x_math
,
mean_var_scale_bias
,
var
);
float
dx
[
ELEMENTS_PER_LDG
];
bwd_dx
(
dx
,
dy_math
,
var
,
x_math
,
mean
,
dscale
,
dbias
,
inv_count
);
// Write back.
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
dx
);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads
();
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Storage
,
int
THREADS_PER_CTA
,
int
THREADS_PER_PIXEL
,
int
PIXELS_PER_THREAD_IN_REGISTERS
,
int
PIXELS_PER_THREAD_IN_SMEM
,
int
ELEMENTS_PER_LDG
,
int
USE_ONLINE_APPROACH
,
int
OUTER_LOOPS_
,
int
DESIRED_OCCUPANCY
>
__global__
__launch_bounds__
(
THREADS_PER_CTA
,
DESIRED_OCCUPANCY
)
void
nhwc_batch_norm_bwd_add_relu
(
NhwcBatchNormBwdParams
params
)
{
// The number of pixels loaded in a single LDG.
const
int
PIXELS_PER_LDG
=
THREADS_PER_CTA
/
THREADS_PER_PIXEL
;
// The number of pixels computed per CTA stored in registers.
const
int
PIXELS_PER_CTA_IN_REGISTERS
=
PIXELS_PER_THREAD_IN_REGISTERS
*
PIXELS_PER_LDG
;
// The number of pixels computed per CTA stored in SMEM.
const
int
PIXELS_PER_CTA_IN_SMEM
=
PIXELS_PER_THREAD_IN_SMEM
*
PIXELS_PER_LDG
;
// The number of C elements per CTA.
const
int
C_ELEMENTS_PER_CTA
=
THREADS_PER_PIXEL
*
ELEMENTS_PER_LDG
;
// Shared memory to do CTA-wide parallel sums.
__shared__
float
smem
[
THREADS_PER_PIXEL
*
(
THREADS_PER_CTA
/
32
)
*
ELEMENTS_PER_LDG
];
// The adapter for the storage.
typedef
PackedStorage
<
Storage
,
ELEMENTS_PER_LDG
>
PackedStorage_
;
// The data type for packed storage in SMEM.
typedef
typename
PackedStorage_
::
Type
PackedStorageType
;
// The number of elements in the packed storage.
const
int
PACKED_ELEMENTS_PER_LDG
=
PackedStorage_
::
PACKED_ELEMENTS_PER_LDG
;
// Registers to keep the data live for the persistent approach.
PackedStorageType
x_storage
[
PIXELS_PER_THREAD_IN_REGISTERS
][
PACKED_ELEMENTS_PER_LDG
];
PackedStorageType
dy_storage
[
PIXELS_PER_THREAD_IN_REGISTERS
][
PACKED_ELEMENTS_PER_LDG
];
// Shared memory buffer to store the extra pixels.
extern
__shared__
PackedStorageType
smem_storage_packed
[];
for
(
int
c_blk_index
=
blockIdx
.
y
;
c_blk_index
<
params
.
c_blks
;
c_blk_index
+=
gridDim
.
y
)
{
// The position in the NHW dimension where the CTA starts.
int
cta_nhw_regs
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_REGISTERS
;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int
cta_nhw_smem
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_SMEM
;
// Compute the NHW coordinate of the thread in the CTA.
const
int
thread_in_cta_nhw
=
threadIdx
.
x
/
THREADS_PER_PIXEL
;
// The position in the C dimension where the CTA starts.
const
int
cta_c
=
c_blk_index
*
C_ELEMENTS_PER_CTA
;
// Compute the C coordinate of the thread in the CTA.
const
int
thread_in_cta_c
=
threadIdx
.
x
%
THREADS_PER_PIXEL
;
// Compute the C coordinate of the thread.
const
int
thread_c
=
cta_c
+
thread_in_cta_c
*
ELEMENTS_PER_LDG
;
// Is the thread working on a valid C dimension?
const
int
is_valid_c
=
thread_c
<
params
.
c
;
float
mean
[
ELEMENTS_PER_LDG
];
zero_array
(
mean
);
if
(
is_valid_c
)
{
read_from_gmem
(
mean
,
params
.
gmem_saved_mean
,
thread_c
/
ELEMENTS_PER_LDG
);
}
// accumulation related registers
float
count
=
0.
f
,
dscale
[
ELEMENTS_PER_LDG
],
dbias
[
ELEMENTS_PER_LDG
];
zero_array
(
dscale
);
zero_array
(
dbias
);
// The number of elements loaded by this CTA.
int
cta_count
=
0
;
// The base pointers to load from.
const
uint16_t
*
gmem_src
=
&
params
.
gmem_src
[
thread_c
];
const
uint16_t
*
gmem_dy
=
&
params
.
gmem_dy
[
thread_c
];
uint16_t
*
gmem_dst1
=
&
params
.
gmem_dst1
[
thread_c
];
// outer loops
int
OUTER_LOOPS
=
OUTER_LOOPS_
==
1
?
1
:
params
.
outer_loops
;
// Load the batch of elements. Compute sum across them
const
int
pixels_per_iteration
=
PIXELS_PER_CTA_IN_REGISTERS
*
gridDim
.
x
;
if
(
OUTER_LOOPS_
!=
1
)
{
// We cannot load everything to store persistently, so let's makes sure registers and
// smem are fully utilized, offset is evenly divisible by 32
int
offset
=
(
pixels_per_iteration
*
OUTER_LOOPS
+
PIXELS_PER_CTA_IN_SMEM
*
gridDim
.
x
-
params
.
nhw
)
&
~
31
;
cta_nhw_regs
-=
offset
;
cta_nhw_smem
-=
offset
;
}
const
unsigned
int
*
const
gmem_relu_bitmask
=
params
.
gmem_relu_bitmask
+
((
params
.
nhw
+
31
)
&
~
31
)
*
2
*
c_blk_index
;
#pragma unroll 1
for
(
int
loop_i
=
0
;
loop_i
<
OUTER_LOOPS
;
++
loop_i
)
{
// The nhw position.
int
nhw_regs
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count
+=
max
(
0
,
min
(
PIXELS_PER_CTA_IN_REGISTERS
,
params
.
nhw
-
nhw_regs
));
int
lane_id
=
threadIdx
.
x
&
31
;
// Read the elements from memory.
float
is_valid
[
PIXELS_PER_THREAD_IN_REGISTERS
];
unsigned
int
relu_mask
[
PIXELS_PER_THREAD_IN_REGISTERS
];
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
nhw_regs
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
zero_array
(
x_storage
[
i
]);
zero_array
(
dy_storage
[
i
]);
is_valid
[
i
]
=
0.
f
;
const
bool
is_valid_nhw
=
static_cast
<
unsigned
int
>
(
idx
)
<
static_cast
<
unsigned
int
>
(
params
.
nhw
);
if
(
is_valid_nhw
)
{
if
(
is_valid_c
)
{
if
(
loop_i
==
OUTER_LOOPS
-
1
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
else
{
ldg
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
is_valid
[
i
]
=
1.
f
;
}
if
(
lane_id
<
ELEMENTS_PER_LDG
)
{
relu_mask
[
i
]
=
gmem_relu_bitmask
[
idx
*
2
+
lane_id
];
}
}
}
// Do the math.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
nhw_regs
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
// Convert to float and update
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
bool
rectified
[
ELEMENTS_PER_LDG
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
rectified
[
j
]
=
((
__shfl_sync
(
0xFFFFFFFFU
,
relu_mask
[
i
],
j
)
&
(
1U
<<
lane_id
))
!=
0
);
}
to_float
(
x_math
,
x_storage
[
i
]);
to_float
(
dy_math
,
dy_storage
[
i
]);
// Update the count.
count
+=
is_valid
[
i
];
// Invert the count.
float
inv_count
=
is_valid
[
i
]
?
1.
f
/
count
:
0.
f
;
relu_bwd
(
dy_math
,
rectified
,
is_valid
[
i
]);
bwd_update
(
dscale
,
dbias
,
dy_math
,
x_math
,
mean
,
inv_count
);
// Lastly we need 'dy' only for BN, so store the 'relu-dgrad'ed version
from_float
(
dy_storage
[
i
],
dy_math
);
// dZ for elementwise add
if
(
is_valid
[
i
])
{
if
(
loop_i
==
OUTER_LOOPS
-
1
)
{
stg_stream
(
&
gmem_dst1
[
idx
*
params
.
c
],
dy_storage
[
i
]);
}
else
{
stg
(
&
gmem_dst1
[
idx
*
params
.
c
],
dy_storage
[
i
]);
}
}
}
}
// The elements to load and store in SMEM.
int
smem_nhw
=
OUTER_LOOPS
*
pixels_per_iteration
+
cta_nhw_smem
;
// Load elements from SMEM, update the CTA count.
int
pixels_in_smem
=
min
(
PIXELS_PER_CTA_IN_SMEM
,
params
.
nhw
-
smem_nhw
);
if
(
pixels_in_smem
>
0
)
{
cta_count
+=
pixels_in_smem
;
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
const
bool
is_pixel_valid_nhw
=
static_cast
<
unsigned
int
>
(
idx
)
<
static_cast
<
unsigned
int
>
(
params
.
nhw
);
const
bool
is_pixel_valid
=
is_pixel_valid_nhw
&&
is_valid_c
;
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
],
dy_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
unsigned
int
relu_mask
;
int
lane_id
=
threadIdx
.
x
&
31
;
zero_array
(
x_storage_local
);
zero_array
(
dy_storage_local
);
if
(
is_pixel_valid_nhw
)
{
if
(
is_valid_c
)
{
ldg_stream
(
x_storage_local
,
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage_local
,
&
gmem_dy
[
idx
*
params
.
c
]);
}
if
(
lane_id
<
ELEMENTS_PER_LDG
)
{
relu_mask
=
gmem_relu_bitmask
[
idx
*
2
+
lane_id
];
}
}
bool
rectified
[
ELEMENTS_PER_LDG
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
rectified
[
j
]
=
((
__shfl_sync
(
0xFFFFFFFFU
,
relu_mask
,
j
)
&
(
1U
<<
lane_id
))
!=
0
);
}
// The offset to store in SMEM.
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
// Store in SMEM.
write_to_smem
(
&
smem_storage_packed
[
offset
],
threadIdx
.
x
,
x_storage_local
);
offset
+=
PIXELS_PER_THREAD_IN_SMEM
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
// Update the count.
count
+=
is_pixel_valid
;
// Invert the count.
float
inv_count
=
is_pixel_valid
?
1.
f
/
count
:
0.
f
;
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
to_float
(
dy_math
,
dy_storage_local
);
relu_bwd
(
dy_math
,
rectified
,
is_pixel_valid
);
bwd_update
(
dscale
,
dbias
,
dy_math
,
x_math
,
mean
,
inv_count
);
from_float
(
dy_storage_local
,
dy_math
);
// dZ for elementwise add
if
(
is_pixel_valid
)
{
stg_stream
(
&
gmem_dst1
[
idx
*
params
.
c
],
dy_storage_local
);
}
// only store the 'relu-dgrad'ed version!
write_to_smem
(
&
smem_storage_packed
[
offset
],
threadIdx
.
x
,
dy_storage_local
);
}
}
// We scale the mean by the number of elements. It brings more stability.
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
dbias
[
i
]
*=
count
;
dscale
[
i
]
*=
count
;
}
// dscale parallel sum
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
);
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dscale
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// dbias parallel sum
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
);
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dbias
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// The workspace in global memory is distributed across the different CTA.
int
gmem_sums_offset
=
c_blk_index
*
gridDim
.
x
*
C_ELEMENTS_PER_CTA
*
2
;
// Write the data for the CTA to global memory.
float
*
gmem_sums
=
&
params
.
gmem_sums
[
gmem_sums_offset
];
if
(
threadIdx
.
x
<
THREADS_PER_PIXEL
)
{
const
int
idx
=
blockIdx
.
x
*
THREADS_PER_PIXEL
+
threadIdx
.
x
;
write_to_gmem
(
&
gmem_sums
[
0
],
idx
,
dscale
);
write_to_gmem
(
&
gmem_sums
[
C_ELEMENTS_PER_CTA
*
gridDim
.
x
],
idx
,
dbias
);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int
*
gmem_retired_ctas
=
&
params
.
gmem_retired_ctas
[
c_blk_index
%
(
2
*
gridDim
.
y
)];
inter_block_sync
(
gmem_retired_ctas
,
gridDim
.
x
,
blockIdx
.
x
==
0
);
// Reset the accumulators for global summation
zero_array
(
dscale
);
zero_array
(
dbias
);
// Build the global accumulation
#pragma unroll 1
for
(
int
idx
=
threadIdx
.
x
;
idx
<
THREADS_PER_PIXEL
*
gridDim
.
x
;
idx
+=
THREADS_PER_CTA
)
{
float
tmp1
[
ELEMENTS_PER_LDG
],
tmp2
[
ELEMENTS_PER_LDG
];
read_from_gmem
(
tmp1
,
gmem_sums
,
idx
);
read_from_gmem
(
tmp2
,
gmem_sums
+
C_ELEMENTS_PER_CTA
*
gridDim
.
x
,
idx
);
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
dscale
[
i
]
+=
tmp1
[
i
];
dbias
[
i
]
+=
tmp2
[
i
];
}
}
// dscale parallel sum
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
1
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
);
}
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dscale
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// dbias parallel sum
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
0
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
);
}
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dbias
,
smem
,
thread_in_cta_c
);
// Normalize the dscale.
float
var
[
ELEMENTS_PER_LDG
];
zero_array
(
var
);
if
(
is_valid_c
)
{
read_from_gmem
(
var
,
params
.
gmem_saved_var
,
thread_c
/
ELEMENTS_PER_LDG
);
}
multiply
(
dscale
,
var
);
// store dscale/dbias
bool
is_valid_for_saving
=
is_valid_c
&&
blockIdx
.
x
==
0
&&
thread_in_cta_nhw
==
0
;
if
(
is_valid_for_saving
)
{
if
(
params
.
sync_iters
>
0
)
{
scaled_write_to_gmem
(
params
.
gmem_dscale
,
thread_c
/
ELEMENTS_PER_LDG
,
dscale
,
params
.
wgrad_coeff
);
scaled_write_to_gmem
(
params
.
gmem_dbias
,
thread_c
/
ELEMENTS_PER_LDG
,
dbias
,
params
.
wgrad_coeff
);
}
else
{
write_to_gmem
(
params
.
gmem_dscale
,
thread_c
/
ELEMENTS_PER_LDG
,
dscale
);
write_to_gmem
(
params
.
gmem_dbias
,
thread_c
/
ELEMENTS_PER_LDG
,
dbias
);
}
}
// Further normalize the dscale to be used in dx calculation
float
scale
[
ELEMENTS_PER_LDG
];
zero_array
(
scale
);
if
(
is_valid_c
)
{
read_from_gmem
(
scale
,
params
.
gmem_scale
,
thread_c
/
ELEMENTS_PER_LDG
);
}
multiply
(
dscale
,
var
);
// scale the inv-var as well, afterwards
multiply
(
var
,
scale
);
// inverse count
float
inv_count
=
params
.
svar_inv_count
;
// The base pointer to write to.
uint16_t
*
const
gmem_dst
=
&
params
.
gmem_dst
[
thread_c
];
// Store the elements in registers.
#pragma unroll 1
for
(
int
loop_i
=
OUTER_LOOPS
-
1
;
loop_i
>=
0
;
--
loop_i
)
{
// The value for nhw.
int
out_nhw
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Normalize the elements and write to memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
const
bool
is_valid
=
((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
;
// Convert to float.
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
to_float
(
dy_math
,
dy_storage
[
i
]);
float
dx
[
ELEMENTS_PER_LDG
];
bwd_dx
(
dx
,
dy_math
,
var
,
x_math
,
mean
,
dscale
,
dbias
,
inv_count
);
// Write back.
if
(
is_valid
)
{
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
dx
);
}
}
// The next value of nhw.
out_nhw
-=
pixels_per_iteration
;
// Read the next elements from memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
float
y
[
ELEMENTS_PER_LDG
];
zero_array
(
y
);
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage
[
i
],
&
gmem_dst1
[
idx
*
params
.
c
]);
}
}
}
// Normalize the elements from SMEM and write them out.
if
(
pixels_in_smem
>
0
)
{
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
const
bool
is_valid
=
((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
;
if
(
is_valid
)
{
// Read from SMEM.
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
],
dy_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
read_from_smem
(
x_storage_local
,
&
smem_storage_packed
[
offset
],
threadIdx
.
x
);
offset
+=
PIXELS_PER_THREAD_IN_SMEM
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
read_from_smem
(
dy_storage_local
,
&
smem_storage_packed
[
offset
],
threadIdx
.
x
);
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
to_float
(
dy_math
,
dy_storage_local
);
float
dx
[
ELEMENTS_PER_LDG
];
bwd_dx
(
dx
,
dy_math
,
var
,
x_math
,
mean
,
dscale
,
dbias
,
inv_count
);
// Write back.
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
dx
);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads
();
}
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
apex/contrib/csrc/xentropy/interface.cpp
0 → 100644
View file @
4d6ed501
#include <torch/extension.h>
// CUDA forward declarations
std::vector<at::Tensor> softmax_xentropy_cuda(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing,
    const bool half_to_float);

at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &grad_loss,
    const at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing);

// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> softmax_xentropy_forward(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing,
    const bool half_to_float) {
  CHECK_CUDA(input);
  CHECK_INPUT(labels);

  return softmax_xentropy_cuda(input, labels, smoothing, half_to_float);
}

at::Tensor softmax_xentropy_backward(
    const at::Tensor &grad_loss,
    const at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing) {
  CHECK_CUDA(grad_loss);
  CHECK_CUDA(logits);
  CHECK_INPUT(max_log_sum_exp);
  CHECK_INPUT(labels);

  return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)");
  m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)");
}
apex/contrib/csrc/xentropy/xentropy_kernel.cu
0 → 100644
View file @
4d6ed501
/**
* From PyTorch:
*
* Copyright (c) 2016- Facebook, Inc (Adam Paszke)
* Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
* Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
* Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
* Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
* Copyright (c) 2011-2013 NYU (Clement Farabet)
* Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
* Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
* Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
*
* From Caffe2:
*
* Copyright (c) 2016-present, Facebook Inc. All rights reserved.
*
* All contributions by Facebook:
* Copyright (c) 2016 Facebook Inc.
*
* All contributions by Google:
* Copyright (c) 2015 Google Inc.
* All rights reserved.
*
* All contributions by Yangqing Jia:
* Copyright (c) 2015 Yangqing Jia
* All rights reserved.
*
* All contributions from Caffe:
* Copyright(c) 2013, 2014, 2015, the respective contributors
* All rights reserved.
*
* All other contributions:
* Copyright(c) 2015, 2016 the respective contributors
* All rights reserved.
*
* Caffe2 uses a copyright model similar to Caffe: each contributor holds
* copyright over their contributions to Caffe2. The project versioning records
* all such contribution and copyright details. If a contributor wants to further
* mark their specific copyright on a particular contribution, they should
* indicate their copyright solely in the commit message of the change when it is
* committed.
*
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
* and IDIAP Research Institute nor the names of its contributors may be
* used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <THC/THC.h>
#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>
#include "type_shim.h"
using Tensor = at::Tensor;
using TensorList = at::TensorList;
using ScalarType = at::ScalarType;
using at::acc_type;

template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxForwardEpilogue {
  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
    : logsum(max_input + std::log(sum)) {}

  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp)
    : logsum(max_log_sum_exp) {}

  __device__ __forceinline__ OutT operator()(T input) const {
    return static_cast<OutT>(input - logsum);
  }

  const AccumT logsum;
};
template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxBackwardEpilogue {
  __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)
    : sum(sum) {}

  __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
    return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);
  }

  const AccumT sum;
};

const int max_threads = 1024;

inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
  uint64_t block_size = 1;
  uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
  while (block_size < max_block_size) block_size *= 2;
  // Launch at least a single warp - the kernel assumes that.
  block_size = std::max(block_size, static_cast<uint64_t>(32));
  return dim3(block_size);
}

template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a + b;
  }
};

template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a < b ? b : a;
  }
};
////////////////////////////////////////////////////////////////////////////////
// Regular kernel (fast when dim_size is large; requires inner_size == 1)
////////////////////////////////////////////////////////////////////////////////
template <typename T, typename AccumT>
struct MaxFloat {
  __device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
    return ::max(max, (AccumT)v);
  }
};

template <typename T, typename AccumT>
struct AddFloat {
  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
    return sum + v;
  }
};

template <typename T, typename AccumT>
struct SumExpFloat {
  __device__ __forceinline__ SumExpFloat(AccumT v)
    : max_k(v) {}

  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
    return sum + std::exp(v - max_k);
  }

  const AccumT max_k;
};

template <template<typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT
blockReduce(AccumT* smem, AccumT val,
            const Reduction<AccumT>& r,
            AccumT defaultVal)
{
  // To avoid RaW races from chaining blockReduce calls together, we need a sync here
  __syncthreads();

  smem[threadIdx.x] = val;

  __syncthreads();

  AccumT warpVal = defaultVal;

  // First warp will perform per-warp reductions for the remaining warps
  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
  if (threadIdx.x < 32) {
    int lane = threadIdx.x % 32;
    if (lane < blockDim.x / 32) {
#pragma unroll
      for (int i = 0; i < 32; ++i) {
        warpVal = r(warpVal, smem[lane * 32 + i]);
      }
      __syncwarp(mask);
      smem[lane] = warpVal;
    }
  }

  __syncthreads();

  // First thread will perform a reduction of the above per-warp reductions
  AccumT blockVal = defaultVal;

  if (threadIdx.x == 0) {
    for (int i = 0; i < blockDim.x / 32; ++i) {
      blockVal = r(blockVal, smem[i]);
    }
    smem[0] = blockVal;
  }

  // Sync and broadcast
  __syncthreads();
  return smem[0];
}
template <template<typename> class Reduction1, template<typename> class Reduction2, typename AccumT>
__device__ __forceinline__ void
blockReduce(AccumT* smem,
            AccumT* reducVal1, AccumT val1, const Reduction1<AccumT>& r1, AccumT defaultVal1,
            AccumT* reducVal2, AccumT val2, const Reduction2<AccumT>& r2, AccumT defaultVal2)
{
  // To avoid RaW races from chaining blockReduce calls together, we need a sync here
  __syncthreads();

  smem[threadIdx.x] = val1;
  smem[blockDim.x + threadIdx.x] = val2;

  __syncthreads();

  AccumT warpVal1 = defaultVal1;
  AccumT warpVal2 = defaultVal2;

  // First warp will perform per-warp reductions for the remaining warps
  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
  if (threadIdx.x < 32) {
    int lane = threadIdx.x % 32;
    if (lane < blockDim.x / 32) {
#pragma unroll
      for (int i = 0; i < 32; ++i) {
        warpVal1 = r1(warpVal1, smem[lane * 32 + i]);
        warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]);
      }
      __syncwarp(mask);
      smem[lane] = warpVal1;
      smem[lane + blockDim.x] = warpVal2;
    }
  }

  __syncthreads();

  // First thread will perform a reduction of the above per-warp reductions
  AccumT blockVal1 = defaultVal1;
  AccumT blockVal2 = defaultVal2;

  if (threadIdx.x == 0) {
    for (int i = 0; i < blockDim.x / 32; ++i) {
      blockVal1 = r1(blockVal1, smem[i]);
      blockVal2 = r2(blockVal2, smem[i + blockDim.x]);
    }
    smem[0] = blockVal1;
    smem[blockDim.x] = blockVal2;
  }

  // Sync and broadcast
  __syncthreads();
  *reducVal1 = smem[0];
  *reducVal2 = smem[blockDim.x];
  __syncthreads();
}
template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>
__device__ __forceinline__ AccumT
ilpReduce(T* data,
          int size,
          const Reduction<T, AccumT>& r,
          AccumT defaultVal)
{
  AccumT threadVal = defaultVal;
  int offset = threadIdx.x;

  int last = size % (ILP * blockDim.x);

  // Body (unroll by ILP times)
  for (; offset < size - last; offset += blockDim.x * ILP) {
    T tmp[ILP];

#pragma unroll
    for (int j = 0; j < ILP; ++j)
      tmp[j] = data[offset + j * blockDim.x];

#pragma unroll
    for (int j = 0; j < ILP; ++j)
      threadVal = r(threadVal, tmp[j]);
  }

  // Epilogue
  for (; offset < size; offset += blockDim.x)
    threadVal = r(threadVal, data[offset]);

  return threadVal;
}
template <template<typename, typename> class Reduction1, template<typename, typename> class Reduction2,
          int ILP, typename T, typename AccumT>
__device__ __forceinline__ void
ilpReduce(T* data,
          int size,
          AccumT* reducVal1,
          const Reduction1<T, AccumT>& r1,
          AccumT defaultVal1,
          AccumT* reducVal2,
          const Reduction2<T, AccumT>& r2,
          AccumT defaultVal2)
{
  AccumT threadVal1 = defaultVal1;
  AccumT threadVal2 = defaultVal2;
  int offset = threadIdx.x;

  int last = size % (ILP * blockDim.x);

  // Body (unroll by ILP times)
  for (; offset < size - last; offset += blockDim.x * ILP) {
    T tmp[ILP];

#pragma unroll
    for (int j = 0; j < ILP; ++j)
      tmp[j] = data[offset + j * blockDim.x];

#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      threadVal1 = r1(threadVal1, tmp[j]);
      threadVal2 = r2(threadVal2, tmp[j]);
    }
  }

  // Epilogue
  for (; offset < size; offset += blockDim.x) {
    threadVal1 = r1(threadVal1, data[offset]);
    threadVal2 = r2(threadVal2, data[offset]);
  }

  *reducVal1 = threadVal1;
  *reducVal2 = threadVal2;
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
          template <typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyForward(
    accscalar_t *losses,
    outscalar_t *max_log_sum_exp,
    scalar_t *input,
    int64_t *labels,
    int64_t classes,
    const float smoothing)
{
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accscalar_t*>(smem);
  // forward pointers to batch[blockIdx.x]
  // each block handles a sample in the mini-batch
  input += blockIdx.x * classes;
  //output += blockIdx.x * classes;
  int64_t label = labels[blockIdx.x];

  // find the max and sum
  accscalar_t threadMax, threadSum, max_k, sum_k;
  ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(
      input, classes,
      &threadMax, MaxFloat<scalar_t, accscalar_t>(), -at::numeric_limits<accscalar_t>::max(),
      &threadSum, AddFloat<scalar_t, accscalar_t>(), static_cast<accscalar_t>(0));
  blockReduce<Max, Add, accscalar_t>(
      sdata,
      &max_k, threadMax, Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max(),
      &sum_k, threadSum, Add<accscalar_t>(), static_cast<accscalar_t>(0));

  // reduce all values
  accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
      input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k), static_cast<accscalar_t>(0));
  accscalar_t sumAll = blockReduce<Add, accscalar_t>(
      sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));

  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);

  // calculate per element loss with label smoothing
  // reserve max + log_sum_exp for bprop
  if (threadIdx.x == 0) {
    accscalar_t log_prob = epilogue(static_cast<accscalar_t>(input[label]));
    losses[blockIdx.x] = (max_k + std::log(sumAll) - sum_k / classes) \
      * smoothing - log_prob * (1 - smoothing);
    max_log_sum_exp[blockIdx.x] = max_k + std::log(sumAll);
  }
}
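For a row of logits z with target y, writing lse(z) = max_j z_j + log sum_j exp(z_j - max_j z_j) (the quantity saved in max_log_sum_exp), the value written to losses above corresponds to the label-smoothed cross entropy; this is my reading of the arithmetic, with epsilon = smoothing and C = classes:

$$ L = (1-\varepsilon)\bigl(\mathrm{lse}(z) - z_y\bigr) + \varepsilon\Bigl(\mathrm{lse}(z) - \tfrac{1}{C}\textstyle\sum_j z_j\Bigr) $$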
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
          template <typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyBackward(
    scalar_t *gradInput,
    scalar_t *logits,
    outscalar_t *max_log_sum_exp,
    outscalar_t *gradOutput,
    int64_t *labels,
    const float smoothing,
    int classes)
{
  gradInput += blockIdx.x * classes;
  logits += blockIdx.x * classes;
  accscalar_t smooth_positives = 1.0 - smoothing;
  accscalar_t smooth_negatives = smoothing / classes;
  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
  int64_t label = labels[blockIdx.x];
  accscalar_t coeff = max_log_sum_exp[blockIdx.x];

  int offset = threadIdx.x;
  int last = classes % (ILP * blockDim.x);

  for (; offset < classes - last; offset += blockDim.x * ILP) {
    accscalar_t tmpLogits[ILP];

    #pragma unroll
    for (int j = 0; j < ILP; ++j) {
      tmpLogits[j] = static_cast<accscalar_t>(logits[offset + j * blockDim.x]);
    }

    #pragma unroll
    for (int j = 0; j < ILP; ++j)
      gradInput[offset + j * blockDim.x] = tmpGradOutput * (
          std::exp(tmpLogits[j] - coeff) -
          static_cast<accscalar_t>((offset + j * blockDim.x == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
  }

  for (; offset < classes; offset += blockDim.x)
    gradInput[offset] = tmpGradOutput * (
        std::exp(static_cast<accscalar_t>(logits[offset]) - coeff) -
        static_cast<accscalar_t>((offset == label) ? 1 : 0) *
        smooth_positives - smooth_negatives);
}
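The backward kernel can likewise be read as the usual smoothed-softmax gradient per class j, scaled by the incoming `gradOutput`. A sketch in the same notation (coeff = m + log Σ_k e^{x_k − m}, g = `gradOutput`):

\frac{\partial\,\mathrm{loss}}{\partial x_j}
  = g \Big( e^{x_j - \mathrm{coeff}} - (1-\varepsilon)\,\mathbb{1}[j = y] - \tfrac{\varepsilon}{C} \Big),
\qquad e^{x_j - \mathrm{coeff}} = \mathrm{softmax}(x)_j

with `smooth_positives` = 1 − ε and `smooth_negatives` = ε / C as defined above.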
template<template<typename, typename, typename> class Epilogue>
std::vector<Tensor> host_softmax_xentropy(
    const Tensor & input_,
    const Tensor & labels_,
    const float smoothing,
    const bool half_to_float){
  if (half_to_float)
    AT_ASSERTM(input_.type().scalarType() == ScalarType::Half,
               "conversion is supported for Half type only");
  AT_ASSERTM(labels_.type().scalarType() == ScalarType::Long,
             "Label type should be CUDA Long");

  auto input = input_.contiguous();
  Tensor max_log_sum_exp = at::empty_like(
      labels_, half_to_float ? input.options().dtype(ScalarType::Float) : input.options());
  Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float));

  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
                std::is_same<acc_type<at::Half, true>, double>::value,
                "accscalar_t for half should be float or double");
  AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported");
  AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional");
  AT_ASSERTM(input.size(0) == labels_.size(0),
             "Input and label should have same number of examples");
  AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0");

  const int64_t dim = 1;
  int64_t outer_size = 1;
  int64_t dim_size = input.size(dim);
  int64_t inner_size = 1;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  for (int64_t i = 0; i < dim; ++i)
    outer_size *= input.size(i);
  for (int64_t i = dim + 1; i < input.dim(); ++i)
    inner_size *= input.size(i);
  // This kernel spawns a block per each element in the batch.
  // XXX: it assumes that inner_size == 1
  AT_CHECK(inner_size == 1, "Currently only inner size 1 supported");

  const int ILP = 2;
  dim3 grid(outer_size);
  dim3 block = SoftMax_getBlockSize(ILP, dim_size);

  using namespace at;
  DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "host_softmax_xentropy",
    using accscalar_t = at::acc_type<scalar_t_0, true>;
    if (!half_to_float) {
      cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>
        <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
          losses.data<accscalar_t>(), max_log_sum_exp.data<scalar_t_0>(),
          input.data<scalar_t_0>(), labels_.data<int64_t>(),
          dim_size, smoothing
      );
    } else {
      cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
        <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
          losses.data<accscalar_t>(), max_log_sum_exp.data<accscalar_t>(),
          input.data<scalar_t_0>(), labels_.data<int64_t>(),
          dim_size, smoothing
      );
    }
  );

  THCudaCheck(cudaGetLastError());

  std::vector<at::Tensor> ret = {losses, max_log_sum_exp};
  return ret;
}
template<template<typename, typename, typename> class Epilogue>
Tensor host_softmax_xentropy_backward(
    const at::Tensor &grad_loss,
    const at::Tensor &logits_,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    bool half_to_float) {
  const int64_t dim = 1;
  Tensor gI = at::empty_like(logits_);
  if (grad_loss.numel() == 0) {
    return gI;
  }

  auto grad = grad_loss.contiguous();
  auto logits = logits_.contiguous();

  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
                std::is_same<acc_type<at::Half, true>, double>::value,
                "accscalar_t for half should be float or double");
  if (grad.dim() == 0) grad = grad.view(1);

  AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported");
  AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional");
  AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0");
  AT_ASSERTM(logits_.size(0) == labels.size(0),
             "Input and label should have same number of examples");
  AT_ASSERTM(labels.size(0) == grad.size(0),
             "Label and loss should have same number of examples");

  int64_t outer_size = 1;
  int64_t dim_size = logits.size(dim);
  int64_t inner_size = 1;
  for (int64_t i = 0; i < dim; ++i)
    outer_size *= logits.size(i);
  for (int64_t i = dim + 1; i < logits.dim(); ++i)
    inner_size *= logits.size(i);
  // See descriptions of kernels above.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_CHECK(inner_size == 1, "Currently only inner size 1 supported");

  const int ILP = 2;
  dim3 grid(outer_size);
  dim3 block = SoftMax_getBlockSize(ILP, dim_size);

  DISPATCH_FLOAT_AND_HALF(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
    using accscalar_t = acc_type<scalar_t_0, true>;
    if (!half_to_float) {
      cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>
        <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
          gI.data<scalar_t_0>(), logits.data<scalar_t_0>(),
          max_log_sum_exp.data<scalar_t_0>(),
          grad.data<scalar_t_0>(), labels.data<int64_t>(),
          smoothing, dim_size
      );
    } else {
      cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
        <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
          gI.data<scalar_t_0>(), logits.data<scalar_t_0>(),
          max_log_sum_exp.data<accscalar_t>(),
          grad.data<accscalar_t>(), labels.data<int64_t>(),
          smoothing, dim_size
      );
    }
  );

  THCudaCheck(cudaGetLastError());
  return gI;
}
std::vector<Tensor> softmax_xentropy_cuda(
    const Tensor &input,
    const Tensor &labels,
    const float smoothing,
    const bool half_to_float){
  return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, half_to_float);
}

at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &grad_loss,
    const at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing) {
  bool half_to_float = grad_loss.type().scalarType() != logits.type().scalarType();
  if (half_to_float) {
    AT_ASSERTM((grad_loss.type().scalarType() == ScalarType::Float &&
                logits.type().scalarType() == ScalarType::Half),
               "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
  }
  return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(
      grad_loss, logits, max_log_sum_exp, labels, smoothing, half_to_float);
}
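These host wrappers expect a 2-D [N, C] logits tensor and a 1-D int64 label tensor of length N, and return the per-example losses plus the saved max + log_sum_exp needed for backward. From Python they are reached through the SoftmaxCrossEntropyLoss autograd function exposed by apex/contrib/xentropy (added later in this commit). A minimal usage sketch mirroring the call pattern of the test file further down; the sizes are illustrative and it assumes apex was built with the --xentropy flag and a CUDA device is available:

import torch
from apex.contrib import xentropy as label_smoothing

N, C = 128, 32320                                     # illustrative sizes
logits = torch.randn(N, C, dtype=torch.half, device='cuda', requires_grad=True)
labels = torch.randint(0, C, (N,), device='cuda')     # int64, one class index per row
half_to_float = (logits.dtype == torch.half)

# Arguments: (logits, labels, smoothing, padding_idx, half_to_float)
losses = label_smoothing.SoftmaxCrossEntropyLoss.apply(
    logits, labels, 0.1, 0, half_to_float)
losses.sum().backward()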
apex/contrib/groupbn/__init__.py
0 → 100644
View file @
4d6ed501
try:
    import torch
    import bnp
    from .batch_norm import BatchNorm2d_NHWC
    del torch
    del bnp
    del batch_norm
except ImportError as err:
    print("apex was installed without --bnp flag, contrib.groupbn is not available")
apex/contrib/groupbn/batch_norm.py
0 → 100644
View file @
4d6ed501
import torch
import numpy as np
from torch.nn.modules.batchnorm import _BatchNorm

import bnp


class bn_NHWC_impl(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, is_train,
                bn_group, my_data, pair_data, magic, pair_data2, pair_data3,
                fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
        if is_train:
            ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv)
            ctx.epsilon = epsilon
            ctx.momentum = mom
            ctx.ret_cta = ret_cta
            ctx.fuse_relu = fuse_relu
            ctx.my_data = my_data
            ctx.pair_data = pair_data
            ctx.magic = magic
            ctx.pair_data2 = pair_data2
            ctx.pair_data3 = pair_data3
            ctx.bn_group = bn_group
            ctx.bwd_occup = bwd_occup
            ctx.bwd_grid_x = bwd_grid_x
            ctx.multi_stream = multi_stream

            res = bnp.bn_fwd_nhwc(x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu,
                                  my_data, pair_data, pair_data2, pair_data3, bn_group, magic,
                                  fwd_occup, fwd_grid_x, multi_stream)
            return res
        else:
            return bnp.bn_fwd_eval_nhwc(x, s, b, rm, riv, ret_cta, bn_group, mom, epsilon, fuse_relu)

    @staticmethod
    def backward(ctx, grad_y):
        x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables
        epsilon = ctx.epsilon
        mom = ctx.momentum
        ret_cta = ctx.ret_cta
        fuse_relu = ctx.fuse_relu
        my_data = ctx.my_data
        pair_data = ctx.pair_data
        magic = ctx.magic
        pair_data2 = ctx.pair_data2
        pair_data3 = ctx.pair_data3
        bn_group = ctx.bn_group
        bwd_occup = ctx.bwd_occup
        bwd_grid_x = ctx.bwd_grid_x
        multi_stream = ctx.multi_stream

        dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu,
                                            my_data, pair_data, pair_data2, pair_data3, bn_group, magic,
                                            bwd_occup, bwd_grid_x, multi_stream)

        return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


class bn_addrelu_NHWC_impl(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom, epsilon, is_train,
                bn_group, my_data, pair_data, magic, pair_data2, pair_data3,
                fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
        if is_train:
            bitmask = torch.cuda.IntTensor(((x.numel() + 31) // 32) * 2 * grid_dim_y)
            ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)
            ctx.epsilon = epsilon
            ctx.momentum = mom
            ctx.ret_cta = ret_cta
            ctx.my_data = my_data
            ctx.pair_data = pair_data
            ctx.magic = magic
            ctx.pair_data2 = pair_data2
            ctx.pair_data3 = pair_data3
            ctx.bn_group = bn_group
            ctx.bwd_occup = bwd_occup
            ctx.bwd_grid_x = bwd_grid_x
            ctx.multi_stream = multi_stream

            res = bnp.bn_addrelu_fwd_nhwc(x, z, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon,
                                          my_data, pair_data, pair_data2, pair_data3, bn_group, magic,
                                          fwd_occup, fwd_grid_x, multi_stream)
            return res
        else:
            return bnp.bn_addrelu_fwd_eval_nhwc(x, z, s, b, rm, riv, ret_cta, bn_group, mom, epsilon)

    @staticmethod
    def backward(ctx, grad_y):
        x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables
        epsilon = ctx.epsilon
        mom = ctx.momentum
        ret_cta = ctx.ret_cta
        my_data = ctx.my_data
        pair_data = ctx.pair_data
        magic = ctx.magic
        pair_data2 = ctx.pair_data2
        pair_data3 = ctx.pair_data3
        bn_group = ctx.bn_group
        bwd_occup = ctx.bwd_occup
        bwd_grid_x = ctx.bwd_grid_x
        multi_stream = ctx.multi_stream

        dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon,
                                                        my_data, pair_data, pair_data2, pair_data3, bn_group, magic,
                                                        bwd_occup, bwd_grid_x, multi_stream)

        return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


class BatchNorm2d_NHWC(_BatchNorm):
    # if using BatchNorm2d_NHWC simultaneously with multiple streams set multi_stream to True
    def __init__(self, num_features, fuse_relu=False, bn_group=1, max_cta_per_sm=2, cta_launch_margin=12, multi_stream=False):
        super(BatchNorm2d_NHWC, self).__init__(num_features)

        self.fuse_relu = fuse_relu
        self.multi_stream = multi_stream

        self.minibatch_mean = torch.cuda.FloatTensor(num_features)
        self.minibatch_riv = torch.cuda.FloatTensor(num_features)

        #defaut to distributed bn disabled
        self.bn_group = bn_group
        self.max_cta_per_sm = max_cta_per_sm        #used only in training fwd and bwd
        self.cta_launch_margin = cta_launch_margin  #used only in training fwd and bwd
        self.my_data = None
        self.pair_data = None
        self.pair_data2 = None
        self.pair_data3 = None
        self.local_rank = 0
        self.magic = torch.IntTensor([0])

        #calculate cta per sm occupancies
        assert(max_cta_per_sm > 0)  # won't be able to do much with 0 CTAs :)
        self.fwd_occupancy = min(bnp.bn_fwd_nhwc_occupancy(), max_cta_per_sm)
        self.bwd_occupancy = min(bnp.bn_bwd_nhwc_occupancy(), max_cta_per_sm)
        self.addrelu_fwd_occupancy = min(bnp.bn_addrelu_fwd_nhwc_occupancy(), max_cta_per_sm)
        self.addrelu_bwd_occupancy = min(bnp.bn_addrelu_bwd_nhwc_occupancy(), max_cta_per_sm)

        #calculate grid dimentions based on occupancy numbers
        mp_count = torch.cuda.get_device_properties(None).multi_processor_count
        self.fwd_grid_dim_x = max(mp_count * self.fwd_occupancy - cta_launch_margin, 1)
        self.bwd_grid_dim_x = max(mp_count * self.bwd_occupancy - cta_launch_margin, 1)
        self.addrelu_fwd_grid_dim_x = max(mp_count * self.addrelu_fwd_occupancy - cta_launch_margin, 1)
        self.addrelu_bwd_grid_dim_x = max(mp_count * self.addrelu_bwd_occupancy - cta_launch_margin, 1)
        self.grid_dim_y = (num_features + 63) // 64

        # allocate scratch space used by implementation
        # TODO: scratch space that is not supposed to be exposed at user code. We only need one time initialization, the
        # same buffer could be reused in future iterations. Currently we exposed it here instead of requesting new
        # buffer from cache allocator to avoid unnecessary initialization at future iterations.
        self.ret_cta = torch.cuda.ByteTensor(8192).fill_(0)

        #FIXME: turn pair handles into an array
        if bn_group > 1:
            local_rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
            assert(world_size >= bn_group)
            assert(world_size % bn_group == 0)

            bn_sync_steps = 1
            if (bn_group == 4):
                bn_sync_steps = 2
            if (bn_group == 8):
                bn_sync_steps = 3

            self.ipc_buffer = torch.cuda.ByteTensor(bnp.get_buffer_size(bn_sync_steps))
            self.my_data = bnp.get_data_ptr(self.ipc_buffer)
            # we are walking on very thin ice here by utilizing internal `_share_cuda_()`
            self.storage = self.ipc_buffer.storage()
            self.share_cuda = self.storage._share_cuda_()
            internal_cuda_mem = self.share_cuda
            # internal_cuda_mem[1]: ipc_mem_handle
            my_handle = torch.cuda.ByteTensor(np.frombuffer(internal_cuda_mem[1], dtype=np.uint8))
            # internal_cuda_mem[3]: offset
            my_offset = torch.cuda.IntTensor([internal_cuda_mem[3]])

            handles_all = torch.empty(world_size, my_handle.size(0), dtype=my_handle.dtype, device=my_handle.device)
            handles_l = list(handles_all.unbind(0))
            torch.distributed.all_gather(handles_l, my_handle)

            offsets_all = torch.empty(world_size, my_offset.size(0), dtype=my_offset.dtype, device=my_offset.device)
            offsets_l = list(offsets_all.unbind(0))
            torch.distributed.all_gather(offsets_l, my_offset)

            #whom do I actually care about? that would be local_rank XOR 1
            self.pair_handle = handles_l[local_rank ^ 1].cpu().contiguous()
            pair_offset = offsets_l[local_rank ^ 1].cpu()
            self.pair_data = bnp.get_remote_data_ptr(self.pair_handle, pair_offset)

            if bn_group > 2:
                self.pair_handle2 = handles_l[local_rank ^ 2].cpu().contiguous()
                pair_offset2 = offsets_l[local_rank ^ 2].cpu()
                self.pair_data2 = bnp.get_remote_data_ptr(self.pair_handle2, pair_offset2)

            if bn_group > 4:
                self.pair_handle3 = handles_l[local_rank ^ 4].cpu().contiguous()
                pair_offset3 = offsets_l[local_rank ^ 4].cpu()
                self.pair_data3 = bnp.get_remote_data_ptr(self.pair_handle3, pair_offset3)

            #FIXME: get magic value into C code and eliminate from here
            self.magic = torch.IntTensor([2])
            self.local_rank = local_rank

    def forward(self, x, z=None):
        if z is not None:
            assert(self.fuse_relu == True)
            return bn_addrelu_NHWC_impl.apply(x, z,
                                              self.weight, self.bias,
                                              self.running_mean, self.running_var,
                                              self.minibatch_mean, self.minibatch_riv, self.grid_dim_y, self.ret_cta,
                                              self.momentum,
                                              self.eps, self.training, self.bn_group,
                                              self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
                                              self.addrelu_fwd_occupancy, self.addrelu_fwd_grid_dim_x,
                                              self.addrelu_bwd_occupancy, self.addrelu_bwd_grid_dim_x,
                                              self.multi_stream)
        else:
            return bn_NHWC_impl.apply(x,
                                      self.weight, self.bias,
                                      self.running_mean, self.running_var,
                                      self.minibatch_mean, self.minibatch_riv, self.ret_cta,
                                      self.momentum,
                                      self.eps, self.fuse_relu, self.training, self.bn_group,
                                      self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
                                      self.fwd_occupancy, self.fwd_grid_dim_x,
                                      self.bwd_occupancy, self.bwd_grid_dim_x,
                                      self.multi_stream)

    def __del__(self):
        if self.bn_group > 1:
            bnp.close_remote_data(self.pair_handle)
            if self.bn_group > 2:
                bnp.close_remote_data(self.pair_handle2)
                if self.bn_group > 4:
                    bnp.close_remote_data(self.pair_handle3)
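For orientation, a minimal single-process usage sketch of BatchNorm2d_NHWC as defined above. It assumes apex was built with --bnp, a CUDA device is available, and that the kernels take half-precision NHWC activations with fp32 affine parameters (my reading of the kernel names, not something this diff states); bn_group=1 keeps the distributed IPC path disabled.

import torch
from apex.contrib.groupbn.batch_norm import BatchNorm2d_NHWC

bn = BatchNorm2d_NHWC(64, fuse_relu=True, bn_group=1).cuda()
bn.train()

# Hypothetical shapes: batch of 8, 32x32 spatial, 64 channels, laid out as NHWC.
x = torch.randn(8, 32, 32, 64, dtype=torch.half, device='cuda', requires_grad=True)
z = torch.randn_like(x)    # optional residual branch
y = bn(x, z)               # add + BN + ReLU fused path; bn(x) alone takes the plain BN path
y.float().sum().backward()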
apex/contrib/test/test_label_smoothing.py
0 → 100644
View file @
4d6ed501
import torch
from apex.contrib import xentropy as label_smoothing
import unittest

import warnings
import random
import numpy as np
import time


def label_smoothing_raw(x, target, padding_idx, smoothing):
    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)

    non_pad_mask = (target != padding_idx)
    nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
    nll_loss = nll_loss.squeeze(1)[non_pad_mask]
    smooth_loss = -logprobs.mean(dim=-1)[non_pad_mask]
    loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss
    return loss


def label_smoothing_opt_1(x, target, padding_idx, smoothing):
    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)

    pad_mask = (target == padding_idx)
    ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
    smooth_loss = logprobs.mean(dim=-1)
    loss = (smoothing - 1.0) * ll_loss - smoothing * smooth_loss
    loss.masked_fill_(pad_mask, 0)
    return loss


class LabelSmoothingTest(unittest.TestCase):
    def setUp(self, seed=1234):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # Set pytorch print precision
        torch.set_printoptions(precision=10)

    def gen_test_inputs(self, N, T, H, smoothing, padding_idx):
        logits = torch.randn((N * T, H), dtype=torch.half, device='cuda',
                             requires_grad=True)
        labels = torch.randint(0, H, [N * T], device='cuda')
        for i in random.sample(range(N * T), N * T // 6):
            labels[i] = padding_idx
        half_to_float = (logits.dtype == torch.half)

        return logits, labels, half_to_float

    def print_max_diff_elem(self, ref, tst):
        ref, tst = ref.flatten(), tst.flatten()
        diff = (ref - tst).abs().max()
        idx = (ref - tst).abs().argmax()
        print("Max atol idx: {}, diff: {:.6f}, ref: {:.6f}, tst: {:.6f}".format(
            idx, diff, ref[idx], tst[idx]))

    def test_label_smoothing_function(self):
        # Set label smoothing configuration
        smoothing, padding_idx = 0.1, 0
        N, T, H = 128, 74, 32320
        iters = 10
        loss_func = label_smoothing.SoftmaxCrossEntropyLoss.apply

        for i in range(iters):
            logits, labels, half_to_float = self.gen_test_inputs(
                N, T, H, smoothing, padding_idx)

            # Run original softmax cross entropy with label smoothing
            logits.grad = None
            losses = label_smoothing_raw(logits, labels, padding_idx, smoothing)
            loss = losses.sum()
            loss.backward()

            ref_loss = loss.clone().detach()
            ref_grad = logits.grad.clone().detach()

            # Run optimized softmax cross entropy with label smoothing
            logits.grad = None
            losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float)
            loss = losses.sum()
            loss.backward()

            val_loss = loss.clone().detach()
            val_grad = logits.grad.clone().detach()

            # Validate
            self.print_max_diff_elem(ref_grad, val_grad)
            self.assertTrue(torch.allclose(ref_loss, val_loss, atol=1e-5, rtol=1e-5))
            self.assertTrue(torch.allclose(ref_grad, val_grad, atol=1e-5, rtol=1e-5))

    def test_label_smoothing_perf(self):
        # Set label smoothing configuration
        smoothing, padding_idx = 0.1, 0
        N, T, H = 128, 74, 32320
        iters = 1000
        loss_func = label_smoothing.SoftmaxCrossEntropyLoss.apply
        print()

        logits, labels, half_to_float = self.gen_test_inputs(
            N, T, H, smoothing, padding_idx)

        # Run original softmax cross entropy with label smoothing
        torch.cuda.synchronize()
        ts = time.time()
        for i in range(iters):
            logits.grad = None
            losses = label_smoothing_raw(logits, labels, padding_idx, smoothing)
            loss = losses.sum() / N
            loss.backward()
        torch.cuda.synchronize()
        print("Raw time {:.2f} s elapsed for {} iterations, norm {:.4f}".format(
            time.time() - ts, iters, logits.grad.norm()))

        # Run optimized softmax cross entropy with label smoothing
        torch.cuda.synchronize()
        ts = time.time()
        for i in range(iters):
            logits.grad = None
            losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float)
            loss = losses.sum() / N
            loss.backward()
        torch.cuda.synchronize()
        print("Opt time {:.2f} s elapsed for {} iterations, norm {:.4f}".format(
            time.time() - ts, iters, logits.grad.norm()))


if __name__ == '__main__':
    unittest.main()
apex/contrib/xentropy/__init__.py
0 → 100644
View file @
4d6ed501
try:
    import torch
    import xentropy_cuda
    from .softmax_xentropy import SoftmaxCrossEntropyLoss
    del torch
    del xentropy_cuda
    del softmax_xentropy
except ImportError as err:
    print("apex was installed without --xentropy flag, contrib.xentropy is not available")