OpenDAS / apex · Commits

Commit 4d6ed501, authored Aug 12, 2019 by Deyu Fu
Merge branch 'multi_tensor_sgd' into deyuf/fused_optimizer_v2
Parents: 690b1f71, 9f64bf27

Changes: 40. Showing 20 changed files with 6397 additions and 142 deletions (+6397, -142).
README.md                                             +4    -4
apex/amp/_initialize.py                               +4   -25
apex/amp/_process_optimizer.py                      +229   -90
apex/amp/handle.py                                    +9   -11
apex/amp/scaler.py                                   +19   -12
apex/contrib/__init__.py                              +0    -0
apex/contrib/csrc/groupbn/batch_norm.cu             +331    -0
apex/contrib/csrc/groupbn/batch_norm.h              +734    -0
apex/contrib/csrc/groupbn/batch_norm_add_relu.cu    +343    -0
apex/contrib/csrc/groupbn/batch_norm_add_relu.h     +681    -0
apex/contrib/csrc/groupbn/cuda_utils.h               +20    -0
apex/contrib/csrc/groupbn/interface.cpp             +175    -0
apex/contrib/csrc/groupbn/ipc.cu                    +130    -0
apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h +2685    -0
apex/contrib/csrc/xentropy/interface.cpp             +52    -0
apex/contrib/csrc/xentropy/xentropy_kernel.cu       +610    -0
apex/contrib/groupbn/__init__.py                      +9    -0
apex/contrib/groupbn/batch_norm.py                  +225    -0
apex/contrib/test/test_label_smoothing.py           +128    -0
apex/contrib/xentropy/__init__.py                     +9    -0
README.md

apex/amp/_initialize.py
@@ -124,29 +124,13 @@ def check_optimizers(optimizers):
         raise RuntimeError("An incoming optimizer is an instance of {}. ".format(bad_optim_type) +
                            "The optimizer(s) passed to amp.initialize() must be bare \n"
                            "instances of either ordinary Pytorch optimizers, or Apex fused \n"
-                           "optimizers (currently just FusedAdam, but FusedSGD will be added \n"
-                           "soon).  You should not manually wrap your optimizer in either \n"
+                           "optimizers (FusedAdam or FusedSGD). \n"
+                           "You should not manually wrap your optimizer in either \n"
                            "apex.fp16_utils.FP16_Optimizer or apex.optimizers.FP16_Optimizer. \n"
                            "amp.initialize will take care of that for you (if necessary) based \n"
                            "on the specified opt_level (and optional overridden properties).")
 
-def wrap_fused_adam(optimizer, properties):
-    msg = 'Currently, the usage of FusedAdam is restricted to '\
-          'amp.initialize(..., opt_level="O2", keep_batchnorm_fp32=False, '\
-          'loss_scale=float or "dynamic"). We are working on enabling more general usage.'
-
-    assert properties.master_weights is True, msg
-    assert properties.cast_model_type is torch.float16, msg
-    assert (properties.keep_batchnorm_fp32 is False or
-            properties.keep_batchnorm_fp32 is None), msg
-
-    if properties.loss_scale == "dynamic":
-        return FP16_Optimizer_for_fused(optimizer, dynamic_loss_scale=True)
-    else:
-        return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)
-
 def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
     from apex.parallel import DistributedDataParallel as apex_DDP
     from .amp import init as amp_init
@@ -176,7 +160,6 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     if not _amp_state.allow_incoming_model_not_fp32:
         check_params_fp32(models)
 
     # In the future, when FP16_Optimizer can be deprecated and master weights can
     # become an attribute, remember to stash master weights before casting the model.
@@ -223,10 +206,6 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
         model.forward = patch_forward(model.forward)
 
     for i, optimizer in enumerate(optimizers):
-        # Still need to special case this for the first pass
-        if isinstance(optimizer, FusedAdam):
-            optimizers[i] = wrap_fused_adam(optimizer, properties)
-        else:
-            optimizers[i] = _process_optimizer(optimizer, properties)
+        optimizers[i] = _process_optimizer(optimizer, properties)
 
     _amp_state.loss_scalers = []
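With wrap_fused_adam() gone, fused optimizers go through the same _process_optimizer() path as ordinary optimizers. Below is a minimal usage sketch of that flow; it assumes a CUDA build of apex with the fused optimizers compiled in, and the model, sizes, and hyperparameters are placeholders rather than anything taken from this commit.

# Illustrative only: a bare FusedSGD is handed straight to amp.initialize(),
# which patches it via _process_optimizer() instead of wrapping it in FP16_Optimizer.
import torch
from apex import amp
from apex.optimizers import FusedSGD

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = FusedSGD(model.parameters(), lr=0.01, momentum=0.9)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic")

loss = model(torch.randn(8, 1024, device="cuda")).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()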
apex/amp/_process_optimizer.py
@@ -3,6 +3,7 @@ from ..fp16_utils import master_params_to_model_params
 from ..multi_tensor_apply import multi_tensor_applier
 from ._amp_state import maybe_print
 import torch
+from ..optimizers import FusedAdam, FusedSGD
 
 class AmpOptimizerState(object):
@@ -10,6 +11,20 @@ class AmpOptimizerState(object):
     pass
 
+def _master_params_to_model_params(self):
+    stash = self._amp_stash
+    if multi_tensor_applier.available:
+        if len(stash.all_fp16_params) > 0:
+            multi_tensor_applier(
+                stash.multi_tensor_scale,
+                stash.dummy_overflow_buf,
+                [stash.all_fp32_from_fp16_params, stash.all_fp16_params],
+                1.0)
+    else:
+        for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
+            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
+
 def lazy_init_with_master_weights(self):
     stash = self._amp_stash
     stash.fp16_groups = []
@@ -60,6 +75,8 @@ def lazy_init_with_master_weights(self):
     for group in stash.fp32_from_fp32_groups:
         stash.all_fp32_from_fp32_params += group
 
+    # all_fp16_grad_stash is only needed for fused optimizers.
+    stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
     # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
     stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]
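The multi_tensor_applier branch above performs the same master-to-model copy as the per-group fallback, just batched into one fused amp_C kernel launch. The snippet below is a semantics-only sketch in plain PyTorch (the tensor lists are made up for illustration), not the fused kernel itself.

# Reference semantics of the multi_tensor_scale call: copy each fp32 master param
# into its fp16 model param, scaled by 1.0, one pair at a time.
import torch

def master_to_model_reference(fp32_masters, fp16_models, scale=1.0):
    for master, model in zip(fp32_masters, fp16_models):
        model.data.copy_(master.data * scale)   # copy_ casts fp32 -> fp16

fp32 = [torch.randn(4), torch.randn(3)]
fp16 = [torch.zeros(4, dtype=torch.half), torch.zeros(3, dtype=torch.half)]
master_to_model_reference(fp32, fp16)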
@@ -73,15 +90,55 @@ def lazy_init_with_master_weights(self):
     self.load_state_dict(self.state_dict())
 
+def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):
+    grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0
+    if scale_override is not None:
+        grads_have_scale, stashed_have_scale, out_scale = scale_override
+
+    # This is a lot of python overhead...
+    grads_needing_unscale = []
+    grads_needing_unscale_with_stash = []
+    stashed = []
+    for param, stashed_grad in zip(params, stashed_grads):
+        if param.grad is None and stashed_grad is not None:
+            param.grad = stashed_grad
+        elif param.grad is not None and stashed_grad is None:
+            grads_needing_unscale.append(param.grad)
+        elif param.grad is not None and stashed_grad is not None:
+            grads_needing_unscale_with_stash.append(param.grad)
+            stashed.append(stashed_grad)
+        else:  # param.grad is None and stashed_grad is None
+            continue
+
+    # unscale() implements grads*(1/scale), so "scale" should be grads_have_scale/out_scale.
+    if len(grads_needing_unscale) > 0:
+        scaler.unscale(
+            grads_needing_unscale,
+            grads_needing_unscale,
+            None,  # unused_scale, currently present to avoid API breakage elsewhere
+            models_are_masters=True,
+            scale_override=grads_have_scale/out_scale)
+
+    if len(grads_needing_unscale_with_stash) > 0:
+        scaler.unscale_with_stashed(
+            grads_needing_unscale_with_stash,
+            stashed,
+            grads_needing_unscale_with_stash,
+            scale_override=(grads_have_scale, stashed_have_scale, out_scale))
+
+    # Clear the stash.
+    for i in range(len(stashed_grads)):
+        stashed_grads[i] = None
+
 def prepare_backward_with_master_weights(self):
     stash = self._amp_stash
 
-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()
 
     for i, param in enumerate(stash.all_fp16_params):
-        # Set up to leverage grad copy elision:
+        # Set up to leverage grad copy elision.
+        # This may behave differently from an unpatched optimizer if zero_grad is used and the param is unused.
         param.grad = None
 
     # for i, param in enumerate(stash.all_fp32_from_fp16_params):
@@ -96,6 +153,8 @@ def prepare_backward_with_master_weights(self):
 def post_backward_with_master_weights(self, scaler):
     stash = self._amp_stash
 
+    self._amp_lazy_init()
+
     # This is a lot of python overhead...
     fp16_grads_needing_unscale = []
     new_fp32_grads = []
@@ -129,37 +188,10 @@ def post_backward_with_master_weights(self, scaler):
                                 preexisting_fp32_grads)
 
     # fp32 params can be treated as they would be in the "no_master_weights" case.
-    grads_needing_unscale = []
-    grads_needing_unscale_with_stash = []
-    stashed = []
-    for param, stashed_grad in zip(stash.all_fp32_from_fp32_params,
-                                   stash.all_fp32_from_fp32_grad_stash):
-        if param.grad is None and stashed_grad is not None:
-            param.grad = stashed_grad
-        elif param.grad is not None and stashed_grad is None:
-            grads_needing_unscale.append(param.grad)
-        elif param.grad is not None and stashed_grad is not None:
-            grads_needing_unscale_with_stash.append(param.grad)
-            stashed.append(stashed_grad)
-        else:  # param.grad is None and stashed_grad is None:
-            continue
-
-    if len(grads_needing_unscale) > 0:
-        scaler.unscale(
-            grads_needing_unscale,
-            grads_needing_unscale,
-            scaler.loss_scale(),
-            models_are_masters=True)
-
-    if len(grads_needing_unscale_with_stash) > 0:
-        scaler.unscale_with_stashed(
-            grads_needing_unscale_with_stash,
-            stashed,
-            grads_needing_unscale_with_stash)
-
-    # Clear the stash.
-    for i in range(len(stash.all_fp32_from_fp32_grad_stash)):
-        stash.all_fp32_from_fp32_grad_stash[i] = None
+    post_backward_models_are_masters(scaler,
                                      stash.all_fp32_from_fp32_params,
                                      stash.all_fp32_from_fp32_grad_stash)
 
 def lazy_init_no_master_weights(self):
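The scale_override triple (grads_have_scale, stashed_have_scale, out_scale) generalizes the old behavior of dividing everything by the current loss scale: gradients carrying different scales are combined onto one output scale. The short sketch below works the arithmetic on plain tensors standing in for gradients.

# Worked example of the scale bookkeeping used by post_backward_models_are_masters().
import torch

S_g, S_s, out = 128.0, 64.0, 64.0              # grads_have_scale, stashed_have_scale, out_scale
true_new, true_old = torch.tensor(2.0), torch.tensor(3.0)

grad = true_new * S_g                          # what the latest backward() produced
stashed = true_old * S_s                       # gradient stashed from an earlier backward
combined = grad * (out / S_g) + stashed * (out / S_s)

# The combined gradient carries out_scale exactly once.
assert torch.allclose(combined / out, true_new + true_old)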
@@ -184,9 +216,7 @@ def lazy_init_no_master_weights(self):
 def prepare_backward_no_master_weights(self):
     stash = self._amp_stash
 
-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()
 
     for i, param in enumerate(stash.all_fp16_params):
         stash.all_fp16_grad_stash[i] = param.grad
@@ -202,55 +232,141 @@ def prepare_backward_no_master_weights(self):
 def post_backward_no_master_weights(self, scaler):
     stash = self._amp_stash
 
+    self._amp_lazy_init()
+
     split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
                    (stash.all_fp32_params, stash.all_fp32_grad_stash))
 
     for params, stashed_grads in split_types:
-        # This is a lot of python overhead...
-        grads_needing_unscale = []
-        grads_needing_unscale_with_stash = []
-        stashed = []
-        for param, stashed_grad in zip(params, stashed_grads):
-            if param.grad is None and stashed_grad is not None:
-                param.grad = stashed_grad
-            elif param.grad is not None and stashed_grad is None:
-                grads_needing_unscale.append(param.grad)
-            elif param.grad is not None and stashed_grad is not None:
-                grads_needing_unscale_with_stash.append(param.grad)
-                stashed.append(stashed_grad)
-            else:  # param.grad is None and stashed_grad is None
-                continue
-
-        if len(grads_needing_unscale) > 0:
-            scaler.unscale(
-                grads_needing_unscale,
-                grads_needing_unscale,
-                scaler.loss_scale(),
-                models_are_masters=True)
-
-        if len(grads_needing_unscale_with_stash) > 0:
-            scaler.unscale_with_stashed(
-                grads_needing_unscale_with_stash,
-                stashed,
-                grads_needing_unscale_with_stash)
-
-        # Clear the stash.
-        for i in range(len(stashed_grads)):
-            stashed_grads[i] = None
-
-def _master_params_to_model_params(self):
-    stash = self._amp_stash
-    if multi_tensor_applier.available:
-        if len(stash.all_fp16_params) > 0:
-            multi_tensor_applier(
-                stash.multi_tensor_scale,
-                stash.dummy_overflow_buf,
-                [stash.all_fp32_from_fp16_params, stash.all_fp16_params],
-                1.0)
-    else:
-        for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
-            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
+        post_backward_models_are_masters(scaler, params, stashed_grads)
+
+#####################################################################################
+# FusedAdam versions
+#####################################################################################
+
+def prepare_backward_with_master_weights_FusedAdam(self):
+    stash = self._amp_stash
+
+    self._amp_lazy_init()
+
+def post_backward_with_master_weights_FusedAdam(self, scaler):
+    stash = self._amp_stash
+
+    self._amp_lazy_init()
+
+    stash.scale = scaler.loss_scale()
+    stash.grads = [[param.grad.data for param in group] for group in stash.fp16_groups]
+    stash.output_params = [[param for param in group] for group in stash.fp16_groups]
+
+    norm_groups = []
+    skip = False
+    for grad_group in stash.grads:
+        norm, _ = multi_tensor_applier(
+            stash.multi_tensor_l2norm,
+            stash.dummy_overflow_buf,
+            [grad_group],
+            False)
+        # Still syncing here for now.
+        norm = float(norm)
+        norm_groups.append(norm)
+        if norm == float('inf') or norm == -float('inf') or norm != norm:
+            skip = True
+
+    if skip:
+        scaler._overflow_buf.fill_(1.)
+        scaler._has_overflow = True
+
+    stash.grad_norms = norm_groups
+
+def prepare_backward_no_master_weights_FusedAdam(self):
+    stash = self._amp_stash
+
+    self._amp_lazy_init()
+
+def post_backward_no_master_weights_FusedAdam(self, scaler):
+    stash = self._amp_stash
+
+    self._amp_lazy_init()
+
+    stash.scale = scaler.loss_scale()
+    stash.grads = None
+    stash.output_params = None
+    stash.grad_norms = None
+
+#####################################################################################
+# FusedSGD versions
+# Eat this ugly code duplication for now.  First make it work, then make it clean.
+# It's difficult to anticipate what can be unified between the FusedAdam and FusedSGD
+# implementations until I have them both working.
+#####################################################################################
+
+# FusedSGD never explicitly materializes the fp32 gradients for "fp32 from fp16" master params
+# outside the kernel, so we must accumulate directly into the model grads.
+def prepare_backward_with_master_weights_FusedSGD(self):
+    if self.materialize_master_grads:
+        prepare_backward_with_master_weights(self)
+    else:
+        stash = self._amp_stash
+
+        self._amp_lazy_init()
+
+        for i, param in enumerate(stash.all_fp16_params):
+            stash.all_fp16_grad_stash[i] = param.grad
+            # Set up to leverage grad copy elision:
+            param.grad = None
+
+        for i, param in enumerate(stash.all_fp32_from_fp32_params):
+            stash.all_fp32_from_fp32_grad_stash[i] = param.grad
+            # Set up to leverage grad copy elision:
+            param.grad = None
+
+def post_backward_with_master_weights_FusedSGD(self, scaler):
+    if self.materialize_master_grads:
+        post_backward_with_master_weights(self, scaler)
+    else:
+        stash = self._amp_stash
+
+        self._amp_lazy_init()
+
+        grads_have_scale = scaler.loss_scale()
+        stashed_have_scale = self.most_recent_scale
+        out_scale = grads_have_scale
+        if self.scale_set_by_backward:
+            out_scale = min(grads_have_scale, self.most_recent_scale)
+
+        split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
+                       (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
+
+        # unscale_with_stashed() implements grads*1/scale + stashed_grads*1.
+        # stashed_grads are scaled by self.most_recent_scale.
+        for params, stashed_grads in split_types:
+            post_backward_models_are_masters(scaler, params, stashed_grads,
+                                             (grads_have_scale, stashed_have_scale, out_scale))
+
+        self.most_recent_scale = out_scale
+        self.scale_set_by_backward = True
+
+def prepare_backward_no_master_weights_FusedSGD(self):
+    prepare_backward_no_master_weights(self)
+
+def post_backward_no_master_weights_FusedSGD(self, scaler):
+    post_backward_no_master_weights(self, scaler)
+
+def _amp_lazy_init(self):
+    stash = self._amp_stash
+
+    if not stash.lazy_init_called:
+        self._lazy_init_maybe_master_weights()
+        stash.lazy_init_called = True
 
 def _process_optimizer(optimizer, properties):
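post_backward_with_master_weights_FusedAdam() decides whether the step should be skipped by inspecting each per-group gradient norm for inf or NaN (norm != norm is the NaN test). A tiny standalone sketch of that check:

# Sketch of the overflow test applied to the per-group L2 norms above.
def should_skip_step(norm_groups):
    # +/-inf means a gradient overflowed in fp16; NaN compares unequal to itself.
    return any(n == float('inf') or n == -float('inf') or n != n for n in norm_groups)

assert should_skip_step([1.5, float('inf')]) is True
assert should_skip_step([1.5, float('nan')]) is True
assert should_skip_step([1.5, 0.25]) is False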
@@ -266,7 +382,8 @@ def _process_optimizer(optimizer, properties):
     for name in ("_lazy_init_maybe_master_weights",
                  "_master_params_to_model_params",
                  "_prepare_amp_backward",
-                 "_post_amp_backward"):
+                 "_post_amp_backward",
+                 "_amp_lazy_init"):
         if hasattr(optimizer, name):
             raise RuntimeError("Incoming optimizer already has {} defined.".format(name))
@@ -274,6 +391,7 @@ def _process_optimizer(optimizer, properties):
     if multi_tensor_applier.available:
         import amp_C
         optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
+        optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
         optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0]);
 
     if properties.master_weights:
@@ -288,6 +406,7 @@ def _process_optimizer(optimizer, properties):
             if closure is not None:
                 raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
             retval = old_step()
-            self._master_params_to_model_params()
+            if not (isinstance(self, FusedAdam) or isinstance(self, FusedSGD)):
+                self._master_params_to_model_params()
             # Clear the master grads that wouldn't be zeroed by model.zero_grad()
             for param in self._amp_stash.all_fp32_from_fp16_params:
@@ -298,9 +417,7 @@ def _process_optimizer(optimizer, properties):
         old_zero_grad = optimizer.zero_grad
         def new_zero_grad(self):
             stash = self._amp_stash
-            if not stash.lazy_init_called:
-                self._lazy_init_maybe_master_weights()
-                stash.lazy_init_called = True
+            self._amp_lazy_init()
             # Zero the model grads.
             for param in stash.all_fp16_params:
                 if param.grad is not None:
@@ -315,21 +432,43 @@ def _process_optimizer(optimizer, properties):
                     param.grad = None
 
         optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)
 
-        optimizer._prepare_amp_backward = types.MethodType(
-            prepare_backward_with_master_weights, optimizer)
-
-        optimizer._post_amp_backward = types.MethodType(
-            post_backward_with_master_weights, optimizer)
+        if isinstance(optimizer, FusedAdam):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_with_master_weights_FusedAdam, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_with_master_weights_FusedAdam, optimizer)
+        elif isinstance(optimizer, FusedSGD):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_with_master_weights_FusedSGD, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_with_master_weights_FusedSGD, optimizer)
+        else:
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_with_master_weights, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_with_master_weights, optimizer)
     else:
         optimizer._lazy_init_maybe_master_weights = types.MethodType(
             lazy_init_no_master_weights, optimizer)
 
-        optimizer._prepare_amp_backward = types.MethodType(
-            prepare_backward_no_master_weights, optimizer)
-
-        optimizer._post_amp_backward = types.MethodType(
-            post_backward_no_master_weights, optimizer)
+        if isinstance(optimizer, FusedAdam):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_no_master_weights_FusedAdam, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_no_master_weights_FusedAdam, optimizer)
+        elif isinstance(optimizer, FusedSGD):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_no_master_weights_FusedSGD, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_no_master_weights_FusedSGD, optimizer)
+        else:
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_no_master_weights, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_no_master_weights, optimizer)
+
+    optimizer._amp_lazy_init = types.MethodType(_amp_lazy_init, optimizer)
 
     old_add_param_group = optimizer.add_param_group
 
     def new_add_param_group(self, new_group):
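All of the per-optimizer behavior above is attached by binding module-level functions as instance methods with types.MethodType. A stripped-down sketch of that pattern (the class and method names here are illustrative, not apex's):

# Generic sketch of the patching pattern used by _process_optimizer().
import types

class TinyOptimizer:
    def step(self):
        return "stepped"

def _post_amp_backward(self, scaler):
    # A real implementation would unscale this optimizer's gradients here.
    return ("post_backward on", type(self).__name__, scaler)

opt = TinyOptimizer()
opt._post_amp_backward = types.MethodType(_post_amp_backward, opt)
print(opt._post_amp_backward("dummy_scaler"))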
apex/amp/handle.py
@@ -6,8 +6,6 @@ from . import utils
 from .opt import OptimWrapper
 from .scaler import LossScaler
 from ._amp_state import _amp_state, master_params, maybe_print
-from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
-from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
 from ..parallel.LARC import LARC
@@ -89,11 +87,6 @@ def scale_loss(loss,
     if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
         optimizers = [optimizers]
 
-    # this is what happens when i have to support tools from different sources under the same API...
-    # TODO:  Rewrite FusedAdam to use multi-tensor apply and the same loss scaler.
-    if isinstance(optimizers, FP16_Optimizer_for_fused):
-        loss_scale = optimizers.cur_scale
-    else:
-        loss_scaler = _amp_state.loss_scalers[loss_id]
-        loss_scale = loss_scaler.loss_scale()
+    loss_scaler = _amp_state.loss_scalers[loss_id]
+    loss_scale = loss_scaler.loss_scale()
@@ -120,8 +113,8 @@ def scale_loss(loss,
         for optimizer in optimizers:
             optimizer._amp_stash.params_have_scaled_gradients = True
     else:
-        # FusedAdam and FusedSGD will take care of unscaling as part of their step() methods.
-        if not isinstance(optimizers, FP16_Optimizer_for_fused):
+        # FusedAdam and FusedSGD may take care of unscaling as part of their step() methods.
+        # if not isinstance(optimizers, FP16_Optimizer_for_fused):
         loss_scaler.clear_overflow_state()
         for optimizer in optimizers:
             optimizer._post_amp_backward(loss_scaler)
@@ -142,10 +135,15 @@ def scale_loss(loss,
             maybe_print(("Gradient overflow.  Skipping step, loss scaler " +
                          "{} reducing loss scale to {}").format(loss_id,
                          loss_scaler.loss_scale()))
+            # TODO:  I don't like the special casing for different optimizer implementations.
+            # Maybe skip should delegate to a method owned by the optimizers themselves.
             if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
                 # Clear the master grads that wouldn't be zeroed by model.zero_grad()
                 for param in opt._amp_stash.all_fp32_from_fp16_params:
                     param.grad = None
+            if hasattr(opt, "most_recent_scale"):
+                opt.most_recent_scale = 1.0
+                opt.scale_set_by_backward = False
             opt.step = opt_step
             opt._amp_stash.already_patched = False
     return skip_step
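From the user's side the training loop does not change: the same scale_loss context manager now covers fused and non-fused optimizers, because unscaling is routed through each optimizer's patched _post_amp_backward(). A hedged sketch (names are placeholders; assumes model and optimizer came from amp.initialize()):

# Illustrative training step; identical whether optimizer is FusedAdam, FusedSGD,
# or a stock torch.optim optimizer.
from apex import amp

def train_step(model, optimizer, data, target, criterion):
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    # On exit, scale_loss calls optimizer._post_amp_backward(loss_scaler) to unscale.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()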
apex/amp/scaler.py
@@ -16,7 +16,7 @@ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=F
         master_grad.mul_(scale)
     return False
 
-def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, scale, check_overflow=False):
+def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):
     # Exception handling for 18.04 compatibility
     if check_overflow:
         cpu_sum = float(model_grad.float().sum())
@@ -26,9 +26,8 @@ def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, scale, ch
     # if master_grad is not model_grad: # copy_ probably internally short-circuits this
     # master_grad.copy_(model_grad)
     assert stashed_grad.dtype == master_grad.dtype
-    converted_model_grad = model_grad.to(master_grad.dtype)
-    stashed_grad.add_(scale, converted_model_grad)
-    master_grad.data = stashed_grad.data
+    converted_model_grad = model_grad.data.to(master_grad.dtype)
+    master_grad.data = a*converted_model_grad.data + b*stashed_grad.data
     return False
 
 class LossScaler(object):
@@ -92,11 +91,13 @@ class LossScaler(object):
                 break
 
     # unused_scale keeps some of the old API alive for hopefully a short time.
-    def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False):
+    def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False, scale_override=None):
         if self._has_overflow:
             return
 
         scale = self._loss_scale
+        if scale_override is not None:
+            scale = scale_override
 
         if scale == 1.0 and models_are_masters and not self.dynamic:
             return
@@ -126,7 +127,8 @@ class LossScaler(object):
                                      model_grads,
                                      stashed_master_grads,
                                      master_grads,
-                                     scale):
+                                     a,
+                                     b):
         for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
             if model is None and stashed is None:
                 continue
@@ -141,7 +143,8 @@ class LossScaler(object):
                 self._has_overflow = axpby_check_overflow_python(model,
                                                                  stashed,
                                                                  master,
-                                                                 1./scale,
+                                                                 a,
+                                                                 b,
                                                                  self.dynamic)
                 if self._has_overflow and self.dynamic:
                     break
@@ -149,11 +152,14 @@ class LossScaler(object):
     def unscale_with_stashed(self,
                              model_grads,
                              stashed_master_grads,
-                             master_grads):
+                             master_grads,
+                             scale_override=None):
         if self._has_overflow:
             return
 
-        scale = self._loss_scale
+        grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0
+        if scale_override is not None:
+            grads_have_scale, stashed_have_scale, out_scale = scale_override
 
         if LossScaler.has_fused_kernel:
             if (not LossScaler.warned_unscaling_non_fp32_grad
@@ -167,14 +173,15 @@ class LossScaler(object):
                 multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,
                                      self._overflow_buf,
                                      [model_grads, stashed_master_grads, master_grads],
-                                     1./scale,
-                                     1.0,
+                                     out_scale/grads_have_scale,
+                                     out_scale/stashed_have_scale,
                                      0) # check only arg 0, aka the incoming model grads, for infs
         else:
             self.unscale_with_stashed_python(model_grads,
                                              stashed_master_grads,
                                              master_grads,
-                                             scale)
+                                             out_scale/grads_have_scale,
+                                             out_scale/stashed_have_scale)
 
         # Defer to update_scale
         # If the fused kernel is available, we only need one D2H memcopy and sync.
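unscale_with_stashed() now receives the two input scales separately and folds them into the axpby coefficients a = out_scale/grads_have_scale and b = out_scale/stashed_have_scale, so master = a*model_grad + b*stashed_grad. A quick numeric check of that formula on made-up values:

# Numeric check of the coefficients passed to multi_tensor_axpby_cuda /
# axpby_check_overflow_python above.
import torch

grads_have_scale, stashed_have_scale, out_scale = 1024.0, 512.0, 512.0
a = out_scale / grads_have_scale    # applied to the incoming model grad
b = out_scale / stashed_have_scale  # applied to the stashed grad

true_new, true_old = torch.tensor(0.5), torch.tensor(1.25)
model_grad = true_new * grads_have_scale
stashed_grad = true_old * stashed_have_scale

master_grad = a * model_grad + b * stashed_grad
assert torch.allclose(master_grad, (true_new + true_old) * out_scale)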
apex/contrib/__init__.py (new, empty file, mode 100644)
apex/contrib/csrc/groupbn/batch_norm.cu (new file, mode 100644)
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include "batch_norm.h"
#include <cuda.h>
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
static size_t round_up_to_multiple(size_t x, int multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
  Workspace(size_t size) : size(size), data(NULL) {
    data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);
  }
  Workspace(const Workspace&) = delete;
  Workspace(Workspace&&) = default;
  Workspace& operator=(Workspace&&) = default;
  ~Workspace() {
    if (data) {
      THCudaFree(at::globalContext().lazyInitCUDA(), data);
    }
  }

  size_t size;
  void* data;
};

// Return {y}
at::Tensor nhwc_bn_fwd_train(
                       const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias,
                       const at::Tensor& running_mean, const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var,
                       const at::Tensor& ret_cta,
                       const float momentum, const float epsilon, const bool fuse_relu,
                       void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
                       const int bn_group, const at::Tensor& magic_tensor,
                       const int occupancy, const int grid_dim_x, const bool coop) {

  const int N = x.size(0);
  const int H = x.size(1);
  const int W = x.size(2);
  const int C = x.size(3);

  // generating new magic number and use that for sync
  int* magic = magic_tensor.data<int>();
  *magic = (*magic + 1) & 0xff;

  // Allocate output tensor
  at::Tensor y = at::empty({N, H, W, C}, x.options());

  // Create wrapper
  NhwcBatchNorm *bn = new NhwcBatchNorm();

  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.data<at::Half>(), nullptr, y.data<at::Half>(), nullptr);

  bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
  bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(minibatch_mean.data<float>());
  workspace.push_back(minibatch_inv_var.data<float>());

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[2];
  void* retired_ctas = ret_cta.data<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index - 3];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  // Don't fuse in ReLU for now at least
  bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3,
          bn_group, *magic, occupancy, grid_dim_x, coop);

  return y;
}

at::Tensor nhwc_bn_fwd_eval(
                       const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias,
                       const at::Tensor& running_mean, const at::Tensor& running_inv_var,
                       const at::Tensor& ret_cta, const int bn_group,
                       const float momentum, const float epsilon, const bool fuse_relu) {

  const int N = x.size(0);
  const int H = x.size(1);
  const int W = x.size(2);
  const int C = x.size(3);

  // Allocate output tensor
  at::Tensor y = at::empty({N, H, W, C}, x.options());

  // Create wrapper
  NhwcBatchNorm *bn = new NhwcBatchNorm();

  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.data<at::Half>(), nullptr, y.data<at::Half>(), nullptr);

  bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
  bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(nullptr);
  workspace.push_back(nullptr);

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[2];
  void* retired_ctas = ret_cta.data<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index - 3];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  // Don't fuse in ReLU for now at least
  bn->fwdInference(stream, fuse_relu);

  return y;
}

std::vector<at::Tensor> nhwc_bn_bwd(
                       const at::Tensor& x, const at::Tensor& dy,
                       const at::Tensor& scale, const at::Tensor& bias,
                       const at::Tensor& running_mean, const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var,
                       const at::Tensor& ret_cta,
                       const float momentum, const float epsilon, const bool fuse_relu,
                       void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
                       const int bn_group, const at::Tensor& magic_tensor,
                       const int occupancy, const int grid_dim_x, const bool coop) {
  // shape
  const int N = x.size(0);
  const int H = x.size(1);
  const int W = x.size(2);
  const int C = x.size(3);

  // generating new magic number and use that for sync
  int* magic = magic_tensor.data<int>();
  *magic = (*magic + 1) & 0xff;

  // outputs
  at::Tensor x_grad, scale_grad, bias_grad;

  // Allocate outputs
  x_grad = at::empty_like(x);
  scale_grad = at::empty_like(scale);
  bias_grad = at::empty_like(bias);

  // Create wrapper
  NhwcBatchNorm *bn = new NhwcBatchNorm();

  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.data<at::Half>(), x_grad.data<at::Half>(), nullptr, dy.data<at::Half>());

  bn->setWeightPointers({scale.data<float>(), bias.data<float>()},
                        {scale_grad.data<float>(), bias_grad.data<float>()});
  bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(minibatch_mean.data<float>());
  workspace.push_back(minibatch_inv_var.data<float>());

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[2];
  void* retired_ctas = ret_cta.data<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index - 3];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3,
            bn_group, *magic, occupancy, grid_dim_x, coop);

  return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};
}

int nhwc_bn_fwd_occupancy() {
  int device_id = -1;
  cudaGetDevice(&device_id);

  //max occupancy supported by the code is 2
  return NhwcBatchNorm::smem_driven_fwd_occupancy(device_id, 2);
}

int nhwc_bn_bwd_occupancy() {
  int device_id = -1;
  cudaGetDevice(&device_id);

  //max occupancy supported by the code is 2
  return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2);
}
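The three entry points above share one workspace layout: slots 0 and 1 hold the minibatch mean and inverse-variance buffers, slot 2 the retired-CTA counter, and every remaining request is packed into a single allocation at 512-byte-aligned offsets. The Python sketch below mirrors that packing (the byte counts are illustrative, not values produced by numWorkspaceBytes()).

# Python mirror of the workspace packing in nhwc_bn_fwd_train / nhwc_bn_fwd_eval / nhwc_bn_bwd.
def round_up_to_multiple(x, multiple):
    return ((x + multiple - 1) // multiple) * multiple

def pack_workspaces(workspace_bytes, alignment=512):
    total, offsets = 0, []
    for request in workspace_bytes[3:]:       # first three slots are handled separately
        total = round_up_to_multiple(total, alignment)
        offsets.append(total)
        total += request
    return total, offsets

total, offsets = pack_workspaces([0, 0, 1280, 100, 4000, 7])
print(total, offsets)    # 4615 [0, 512, 4608]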
apex/contrib/csrc/groupbn/batch_norm.h (new file, mode 100644)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm.h
* \brief CUDA NHWC Batch Normalization code
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#include <cudnn.h>
#include <algorithm>
#include <vector>
#include <string>
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#define VERBOSE_DEFAULT false
class NhwcBatchNorm {
 public:
  NhwcBatchNorm() {
    name_ = "nhwc_batchnorm";
    createTensorDescriptor(&X_tensor_desc_);
    createTensorDescriptor(&Y_tensor_desc_);
  }

  ~NhwcBatchNorm() {
    destroyTensorDescriptor(X_tensor_desc_);
    destroyTensorDescriptor(Y_tensor_desc_);
  }

  void die() {
    std::cerr << "batchnorm not initialized" << std::endl;
    exit(-1);
  }

  void fwd(cudaStream_t stream, bool use_relu,
           void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
           const int bn_group, const int magic, const int occupancy,
           const int grid_dim_x, const bool coop);
  void dgrad(cudaStream_t stream, bool use_relu,
             void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
             const int bn_group, const int magic, const int occupancy,
             const int grid_dim_x, const bool coop);
  void fwdInference(cudaStream_t stream, bool use_relu);
  dim3 calc_fwd_grid(int* loop, const int grid_dim_x);
  dim3 calc_bwd_grid(int* loop, const int grid_dim_x);

  void setInputDescriptor(const cudnnTensorFormat_t format,
                          const cudnnDataType_t     data_type,
                          int n, int c, int h, int w, int bn_group) {
    m_ = n * h * w;
    int m_bn_adjusted = m_ * bn_group;
    c_ = c;
    // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.
    svar_inv_count_ = 1.f / m_bn_adjusted;
    // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).
    int divisor = m_bn_adjusted - 1;
    // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.
    rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor;
    setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
  }

  void setOutputDescriptor(const cudnnTensorFormat_t format,
                           const cudnnDataType_t     data_type,
                           int n, int c, int h, int w) {
    setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
  }

  const std::vector<size_t> numWorkspaceBytes() const;

  void setWorkspacePointers(const std::vector<void*>&  workspace,
                            const std::vector<size_t>& num_workspace_bytes);

  void setInputOutputPointers(void* X, void* dX, void* Y, void *dY) {
    X_  = X;
    dX_ = dX;
    Y_  = Y;
    dY_ = dY;
  }

  // Sets the pointers for the scale and weight (in that order) data and derivative buffers.
  void setWeightPointers(const std::vector<void*>& weight_pointers,
                         const std::vector<void*>& deriv_pointers) {
    assert(weight_pointers.size() == 2);
    assert(deriv_pointers.size()  == 2);
    scale_  = static_cast<float*>(weight_pointers[0]);
    bias_   = static_cast<float*>(weight_pointers[1]);
    dscale_ = static_cast<float*>(deriv_pointers[0]);
    dbias_  = static_cast<float*>(deriv_pointers[1]);
  }

  // Sets the pointers for the population mean and variance buffers, in that order.
  void setParameterPointers(const std::vector<void*>& param_pointers) {
    assert(param_pointers.size() == 2);
    population_mean_     = static_cast<float*>(param_pointers[0]);
    population_variance_ = static_cast<float*>(param_pointers[1]);
  }

  void setConstants(const double exp_avg_factor, const double eps) {
    exp_avg_factor_ = exp_avg_factor;
    eps_ = eps;
  }

  void processCudnnStatus(const cudnnStatus_t& status,
                          const std::string& string = std::string(),
                          bool verbose = VERBOSE_DEFAULT) {
    if (status != CUDNN_STATUS_SUCCESS)
      LOG(FATAL) << string << " " << cudnnGetErrorString(status);
    else if (verbose)
      LOG(INFO) << string << " " << cudnnGetErrorString(status);
  }

  void checkCudaStatus(const std::string& string = std::string(),
                       bool verbose = VERBOSE_DEFAULT) {
    cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess)
      LOG(FATAL) << string << " " << cudaGetErrorString(status);
    else if (verbose)
      LOG(INFO) << string << " " << cudaGetErrorString(status);
  }

  size_t size_retired_ctas(int grid_y) const {
    // Note that the value of max_grid_y to handle known GPUs is about 160.
    const int max_grid_y = 1024;
    if (grid_y > max_grid_y)
      LOG(INFO) << "GPU capabilities exceeds assumptions.";
    const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);
    // Since the region will be initialized once and used for many kernels,
    // the idea is to return an ample size that will cover all uses.
    return retired_cta_bytes;
  }

  cudnnTensorDescriptor_t X_tensor_desc_ = nullptr;
  cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr;

  void* X_  = nullptr;
  void* dX_ = nullptr;
  void* Y_  = nullptr;
  void* dY_ = nullptr;

  // Learned scale and bias weights.
  float* scale_  = nullptr;
  float* dscale_ = nullptr;
  float* bias_   = nullptr;
  float* dbias_  = nullptr;

  // Computed population mean and variance parameters.
  float* population_mean_     = nullptr;
  float* population_variance_ = nullptr;

  // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).
  float* minibatch_mean_     = nullptr;
  float* minibatch_variance_ = nullptr;

  int m_ = 0;  // Number of values per channel that BN is normalizing.
  int c_ = 0;  // Number of channels over which BN is normalizing.

  float svar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get saved variance
  float rvar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get running variance

  double exp_avg_factor_ = 0.;
  double eps_            = 0.;
  std::string name_;

 private:
  void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,
                           cudnnTensorFormat_t format,
                           cudnnDataType_t     data_type,
                           int n, int c, int h, int w) {
    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
    status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);
    processCudnnStatus(status, "set tensor descriptor");
  }

  void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {
    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
    status = cudnnCreateTensorDescriptor(descriptor);
    processCudnnStatus(status, "create tensor_descriptor");
  }

  void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {
    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
    status = cudnnDestroyTensorDescriptor(descriptor);
    processCudnnStatus(status, "destroy tensor_descriptor");
  }

 protected:
  float* partial_sums_   = nullptr;
  int*   partial_counts_ = nullptr;
  int*   retired_ctas_   = nullptr;

  void _setFwdParams(NhwcBatchNormFwdParams *params) const;
  void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;
  void _setBwdParams(NhwcBatchNormBwdParams *params) const;

  // @todo: ability to configure these?
  // Kernel params
  static const int USE_ONLINE_APPROACH = 1;
  static const int THREADS_PER_CTA = 512;
  static const int THREADS_PER_PIXEL = 16;
  static const int C_ELEMENTS_PER_CTA = 64;
  static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
  static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;

  typedef uint16_t StorageType;
  //typedef float StorageType;
  // increasing this to 6 causes spills in fwd kernel!
  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;
  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;
  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;
  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;

  static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \
      PIXELS_PER_THREAD_IN_SMEM_FWD;
  static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \
      PIXELS_PER_THREAD_IN_SMEM_BWD;
  static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;

  // Derived params
  static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD * THREADS_PER_CTA * \
      ELEMENTS_PER_LDG * sizeof(StorageType);
  static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD * THREADS_PER_CTA * \
      ELEMENTS_PER_LDG * 2 * sizeof(StorageType);
  static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
  static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA / THREADS_PER_PIXEL * \
      PIXELS_PER_THREAD_FWD;
  static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA / THREADS_PER_PIXEL * \
      PIXELS_PER_THREAD_BWD;
  static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA / THREADS_PER_PIXEL * \
      PIXELS_PER_THREAD_FWD_INFERENCE;

  // max grid.y in case of group bn is limited by exchange buffer size
  static const int MAX_GBN_BLOCK_Y = 256;

  // Helper function to launch the forward kernel.

  // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel
  // version that was compiled with that occupancy in its launch bounds.  This way, we avoid
  // needless register spills.
  void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
                          dim3 grid_dim, int outer_loops, bool use_relu,
                          const int occupancy, const bool coop) {
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
cudaLaunchKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
    if (outer_loops == 1 && use_relu) {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(1, true, false, 2, coop);
      else
        LAUNCH_FWD_KERNEL(1, true, false, 1, coop);
    } else if (outer_loops == 1 && !use_relu) {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(1, false, false, 2, coop);
      else
        LAUNCH_FWD_KERNEL(1, false, false, 1, coop);
    } else if (use_relu) {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(0, true, false, 2, coop);
      else
        LAUNCH_FWD_KERNEL(0, true, false, 1, coop);
    } else {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(0, false, false, 2, coop);
      else
        LAUNCH_FWD_KERNEL(0, false, false, 1, coop);
    }
#undef LAUNCH_FWD_KERNEL
  }
// Helper function to launch the backward kernel.
  void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
                          dim3 grid_dim, int outer_loops, bool use_relu,
                          const int occupancy, const bool coop) {
#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_func = nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd coop serial kernel"); \
} while (0)
#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
    if (outer_loops == 1 && use_relu) {
      if (occupancy >= 2)
        LAUNCH_BWD_RELU_KERNEL(1, 2, coop);
      else
        LAUNCH_BWD_RELU_KERNEL(1, 1, coop);
    } else if (outer_loops == 1 && !use_relu) {
      if (occupancy >= 2)
        LAUNCH_BWD_KERNEL(1, 2, coop);
      else
        LAUNCH_BWD_KERNEL(1, 1, coop);
    } else if (use_relu) {
      if (occupancy >= 2)
        LAUNCH_BWD_RELU_KERNEL(0, 2, coop);
      else
        LAUNCH_BWD_RELU_KERNEL(0, 1, coop);
    } else {
      if (occupancy >= 2)
        LAUNCH_BWD_KERNEL(0, 2, coop);
      else
        LAUNCH_BWD_KERNEL(0, 1, coop);
    }
#undef LAUNCH_BWD_KERNEL
}
 public:
  // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
  static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
    using namespace at::cuda::utils;
    int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
    int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
    return std::min(max_cta_per_sm, occupancy);
  }

  // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
  static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
    using namespace at::cuda::utils;
    int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
    int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
    return std::min(max_cta_per_sm, occupancy);
  }
};
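// Illustrative standalone sketch (exposition only, not used by the wrapper above): the same
// shared-memory-driven occupancy estimate with the class constants inlined. The shared memory
// budget is passed in explicitly here; a 96 KB-per-SM figure is an assumed example value,
// whereas the real code queries it via at::cuda::utils::MaxSharedMemoryPerMultiprocessor.
static inline int sketch_fwd_occupancy(int smem_per_sm_bytes, int max_cta_per_sm) {
  const int threads_per_cta = 512, threads_per_pixel = 16, elements_per_ldg = 4;
  const int smem_size_fwd   = 10 * threads_per_cta * elements_per_ldg * 2;                       // 40960 B (10 smem pixels, fp16)
  const int reduction_bytes = threads_per_pixel * (threads_per_cta / 32) * elements_per_ldg * 4; // 4096 B
  // e.g. with 96 KB per SM: 98304 / 45056 == 2, then capped at max_cta_per_sm.
  return std::min(max_cta_per_sm, smem_per_sm_bytes / (smem_size_fwd + reduction_bytes));
}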
const std::vector<size_t> NhwcBatchNorm::numWorkspaceBytes() const {
  assert(c_ > 0);

  // choose the max memory required between fwd/bwd passes
  int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);
  int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);
  int grid_x = max(grid_x_fwd, grid_x_bwd);
  int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);

  const size_t num_mean_bytes     = c_ * sizeof(float);
  const size_t num_variance_bytes = num_mean_bytes;
  const size_t size_sums          = grid_y*grid_x*THREADS_PER_PIXEL* \
      ELEMENTS_PER_LDG*2*sizeof(float);
  const size_t size_counts        = grid_y*grid_x*sizeof(int);

  return {num_mean_bytes, num_variance_bytes,
          size_retired_ctas(grid_y), size_sums, size_counts};
}
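// Worked sizing sketch (exposition only; the shape is a hypothetical example, and it assumes
// <vector>/<algorithm> are available as in this file): reproduces the arithmetic of
// numWorkspaceBytes() for N=32, H=W=56, C=64 with the constants defined in the class above.
static inline std::vector<size_t> sketch_workspace_bytes_example() {
  const int m = 32 * 56 * 56, c = 64;                   // m_ = N*H*W = 100352, c_ = C
  const int grid_x_fwd = (m + 480 - 1) / 480;           // div_up(m_, PIXELS_PER_CTA_FWD)  -> 210
  const int grid_x_bwd = (m + 256 - 1) / 256;           // div_up(m_, PIXELS_PER_CTA_BWD)  -> 392
  const int grid_x = std::max(grid_x_fwd, grid_x_bwd);  // 392
  const int grid_y = (c + 64 - 1) / 64;                 // div_up(c_, C_ELEMENTS_PER_CTA)  -> 1
  const size_t mean_bytes = c * sizeof(float);          // 256 B (variance buffer is the same size)
  const size_t retired    = 1024 * 2 * sizeof(int);     // 8192 B, as in size_retired_ctas()
  const size_t sums   = size_t(grid_y) * grid_x * 16 * 4 * 2 * sizeof(float);  // 200704 B
  const size_t counts = size_t(grid_y) * grid_x * sizeof(int);                 // 1568 B
  return {mean_bytes, mean_bytes, retired, sums, counts};
}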
void NhwcBatchNorm::setWorkspacePointers(
      const std::vector<void*>&  workspace,
      const std::vector<size_t>& num_workspace_bytes) {
  assert(workspace.size() == 5);
  assert(num_workspace_bytes.size() == 5);

  minibatch_mean_     = static_cast<float*>(workspace[0]);
  minibatch_variance_ = static_cast<float*>(workspace[1]);
  retired_ctas_       = static_cast<int*>(workspace[2]);
  partial_sums_       = static_cast<float*>(workspace[3]);
  partial_counts_     = static_cast<int*>(workspace[4]);
}
void NhwcBatchNorm::_setFwdParams(NhwcBatchNormFwdParams *params) const {
  params->gmem_src          = static_cast<uint16_t*>(X_);
  params->gmem_dst          = static_cast<uint16_t*>(Y_);
  params->gmem_src1         = nullptr;
  params->gmem_bias         = bias_;
  params->gmem_scale        = scale_;
  params->gmem_running_mean = population_mean_;
  params->gmem_running_var  = population_variance_;
  params->gmem_saved_mean   = minibatch_mean_;
  params->gmem_saved_var    = minibatch_variance_;
  params->gmem_relu_bitmask = nullptr;
  params->nhw               = m_;
  params->c                 = c_;
  params->svar_inv_count    = svar_inv_count_;
  params->rvar_inv_count    = rvar_inv_count_;
  params->gmem_sums         = partial_sums_;
  params->gmem_counts       = partial_counts_;
  params->gmem_retired_ctas = retired_ctas_;
  params->var_eps           = eps_;
  params->outer_loops       = 0;
  params->exp_avg_factor    = static_cast<float>(exp_avg_factor_);
  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNorm::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const {
  params->gmem_src   = static_cast<uint16_t*>(X_);
  params->gmem_dst   = static_cast<uint16_t*>(Y_);
  params->gmem_src1  = nullptr;
  params->gmem_bias  = bias_;
  params->gmem_scale = scale_;
  params->gmem_mean  = population_mean_;
  params->gmem_var   = population_variance_;
  params->nhw        = m_;
  params->c          = c_;
  params->var_eps    = eps_;
}
void NhwcBatchNorm::_setBwdParams(NhwcBatchNormBwdParams *params) const {
  params->gmem_src          = static_cast<uint16_t*>(X_);
  params->gmem_dy           = static_cast<uint16_t*>(dY_);
  params->gmem_dst          = static_cast<uint16_t*>(dX_);
  params->gmem_dst1         = nullptr;
  params->gmem_relu_bitmask = nullptr;
  params->gmem_dscale       = dscale_;
  params->gmem_dbias        = dbias_;
  params->gmem_scale        = scale_;
  params->gmem_bias         = bias_;
  params->gmem_saved_mean   = minibatch_mean_;
  params->gmem_saved_var    = minibatch_variance_;
  params->nhw               = m_;
  params->c                 = c_;
  params->svar_inv_count    = svar_inv_count_;
  params->gmem_sums         = partial_sums_;
  params->gmem_retired_ctas = retired_ctas_;
  params->outer_loops       = 0;
  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) {
  bool ptrs_are_set =
      X_tensor_desc_ != nullptr
      && Y_tensor_desc_ != nullptr
      && scale_ != nullptr
      && bias_ != nullptr
      //      && minibatch_mean_ != nullptr
      //      && minibatch_variance_ != nullptr
      && population_mean_ != nullptr
      && population_variance_ != nullptr
      && X_ != nullptr
      //      && dX_ != nullptr
      && Y_ != nullptr
      //      && dY_ != nullptr
      //      && dscale_ != nullptr
      //      && dbias_ != nullptr
      && partial_sums_ != nullptr
      && partial_counts_ != nullptr;

  if (!ptrs_are_set)
    die();

  dim3 grid_dim;
  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE);
  grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);

  // @todo: maybe just move this inside initialize routine?
  NhwcBatchNormFwdInferenceParams params;
  _setFwdInferenceParams(&params);

  if (use_relu) {
    nhwc_batch_norm_fwd_inference
      <StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, true, false>
      <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
    checkCudaStatus(name_ + " fwd_inference-relu kernel");
  } else {
    nhwc_batch_norm_fwd_inference
      <StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, false>
      <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
    checkCudaStatus(name_ + " fwd_inference kernel");
  }
}
dim3 NhwcBatchNorm::calc_fwd_grid(int *loop, const int grid_dim_x) {
  dim3 grid_dim;
  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);
  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
  unsigned int max_grid_x = grid_dim_x;
  if (grid_dim.x <= max_grid_x) {
    *loop = 1;
    if (max_grid_x / grid_dim.x > 1) {
      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
      assert(grid_dim.y < MAX_GBN_BLOCK_Y);  // FIXME: turn into a loop
    } else {
      grid_dim.y = 1;
    }
  } else {
    grid_dim.x = max_grid_x;
    grid_dim.y = 1;
    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD*PIXELS_PER_LDG*grid_dim.x;
    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD*PIXELS_PER_LDG*grid_dim.x;
    *loop = div_up(nhw_in_regs, pixels_per_iteration);
  }
  return grid_dim;
}
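// Exposition-only sketch of the second branch above: when the pixel-dimension grid would
// exceed grid_dim_x, the launch is clamped and the kernel loops instead. The constants are
// the class values inlined; the shape below is a hypothetical example.
static inline int sketch_fwd_outer_loops(int m, int max_grid_x) {
  const int pixels_per_ldg = 512 / 16;                                     // PIXELS_PER_LDG = 32
  const int nhw_in_regs = m - 10 /*smem pixels*/ * pixels_per_ldg * max_grid_x;
  const int pixels_per_iteration = 5 /*register pixels*/ * pixels_per_ldg * max_grid_x;
  return (nhw_in_regs + pixels_per_iteration - 1) / pixels_per_iteration;  // div_up
}
// e.g. sketch_fwd_outer_loops(802816, 160) == 30 for a (hypothetical) 256x56x56 activation.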
dim3 NhwcBatchNorm::calc_bwd_grid(int *loop, const int grid_dim_x) {
  dim3 grid_dim;
  grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);
  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
  unsigned int max_grid_x = grid_dim_x;
  if (grid_dim.x <= max_grid_x) {
    *loop = 1;
    if (max_grid_x / grid_dim.x > 1) {
      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
      assert(grid_dim.y < MAX_GBN_BLOCK_Y);  // FIXME: turn into a loop
    } else {
      grid_dim.y = 1;
    }
  } else {
    grid_dim.x = max_grid_x;
    grid_dim.y = 1;
    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD*PIXELS_PER_LDG*grid_dim.x;
    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD*PIXELS_PER_LDG*grid_dim.x;
    *loop = div_up(nhw_in_regs, pixels_per_iteration);
  }
  return grid_dim;
}
void NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data,
                        void* pair_data2, void* pair_data3, const int bn_group, const int magic,
                        const int occupancy, const int grid_dim_x, const bool coop) {
  bool ptrs_are_set =
      X_tensor_desc_ != nullptr
      && Y_tensor_desc_ != nullptr
      && scale_ != nullptr
      && bias_ != nullptr
      && minibatch_mean_ != nullptr
      && minibatch_variance_ != nullptr
      && population_mean_ != nullptr
      && population_variance_ != nullptr
      && X_ != nullptr
      //      && dX_ != nullptr
      && Y_ != nullptr
      //      && dY_ != nullptr
      //      && dscale_ != nullptr
      //      && dbias_ != nullptr
      && partial_sums_ != nullptr
      && partial_counts_ != nullptr
      && retired_ctas_ != nullptr;

  if (!ptrs_are_set)
    die();

  // reset of retired_cta_count no longer needed

  NhwcBatchNormFwdParams params;
  _setFwdParams(&params);
  params.my_data = my_data;
  params.pair_datas[0] = pair_data;
  params.pair_datas[1] = pair_data2;
  params.pair_datas[2] = pair_data3;
  params.magic = magic;
  params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1);

  dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);
  _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);
}
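// The sync_iters expression above maps the group size to the number of synchronization
// rounds: bn_group 1 -> 0, 2 -> 1, 4 -> 2, 8 -> 3, i.e. log2 of the group for the supported
// sizes (presumably one round per halving of the group). Exposition-only sketch:
static inline int sketch_sync_iters(int bn_group) {
  return (bn_group == 8) ? 3 : (bn_group >> 1);  // assumes bn_group is one of {1, 2, 4, 8}
}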
void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data,
                          void* pair_data2, void* pair_data3, const int bn_group, const int magic,
                          const int occupancy, const int grid_dim_x, const bool coop) {
  bool ptrs_are_set =
      X_tensor_desc_ != nullptr
      && Y_tensor_desc_ != nullptr
      && scale_ != nullptr
      && (bias_ != nullptr || !use_relu)
      && minibatch_mean_ != nullptr
      && minibatch_variance_ != nullptr
      //      && population_mean_ != nullptr
      //      && population_variance_ != nullptr
      && X_ != nullptr
      && dX_ != nullptr
      //      && Y_ != nullptr
      && dY_ != nullptr
      && dscale_ != nullptr
      && dbias_ != nullptr;

  if (!ptrs_are_set)
    die();

  // reset of retired_cta_count no longer needed

  NhwcBatchNormBwdParams params;
  _setBwdParams(&params);
  params.my_data = my_data;
  params.pair_datas[0] = pair_data;
  params.pair_datas[1] = pair_data2;
  params.pair_datas[2] = pair_data3;
  params.magic = magic;
  params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1);
  params.wgrad_coeff = 1.0 / bn_group;

  dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);
  _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
apex/contrib/csrc/groupbn/batch_norm_add_relu.cu
0 → 100644
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include "batch_norm_add_relu.h"
#include <cuda.h>
//FIXME move the common stuff to common h file
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
static size_t round_up_to_multiple(size_t x, int multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}
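// Usage sketch (exposition only, with made-up buffer sizes): how the workspace loops below
// use this helper to pack several buffers into one allocation at 512-byte-aligned offsets.
static inline void sketch_offsets_example() {
  size_t total = 0;
  size_t off0 = round_up_to_multiple(total, 512);  // 0: first buffer starts at offset 0
  total = off0 + 1000;                             // hypothetical first buffer of 1000 bytes
  size_t off1 = round_up_to_multiple(total, 512);  // 1024: second buffer starts on the next 512 B boundary
  (void)off1;
}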
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
  Workspace(size_t size) : size(size), data(NULL) {
    data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);
  }
  Workspace(const Workspace&) = delete;
  Workspace(Workspace&&) = default;
  Workspace& operator=(Workspace&&) = default;
  ~Workspace() {
    if (data) {
      THCudaFree(at::globalContext().lazyInitCUDA(), data);
    }
  }

  size_t size;
  void* data;
};
// Return {y}
at::Tensor nhwc_bn_addrelu_fwd_train(
                       const at::Tensor& x,
                       const at::Tensor& z,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& bitmask,
                       const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       void* my_data,
                       void* pair_data,
                       void* pair_data2,
                       void* pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
                       const int occupancy,
                       const int grid_dim_x,
                       const bool coop) {
  const int N = x.size(0);
  const int H = x.size(1);
  const int W = x.size(2);
  const int C = x.size(3);

  // generating new magic number and use that for sync
  int* magic = magic_tensor.data<int>();
  *magic = (*magic + 1) & 0xff;

  // Allocate output tensor
  at::Tensor y = at::empty({N, H, W, C}, x.options());

  // Create wrapper
  NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();

  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.data<at::Half>(),
                             nullptr,
                             y.data<at::Half>(),
                             nullptr,
                             z.data<at::Half>(),
                             nullptr);

  bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
  bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(minibatch_mean.data<float>());
  workspace.push_back(minibatch_inv_var.data<float>());
  workspace.push_back(bitmask.data<int32_t>());

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[3];
  void* retired_ctas = ret_cta.data<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  // Don't fuse in ReLU for now at least
  bn->fwd(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic,
          occupancy, grid_dim_x, coop);

  return y;
}
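// The output allocated above is channels-last (NHWC), so a logical element (n, h, w, c)
// lives at flat offset ((n*H + h)*W + w)*C + c. Exposition-only helper illustrating that
// indexing convention (not used by the code in this file):
static inline int64_t sketch_nhwc_offset(int n, int h, int w, int c, int H, int W, int C) {
  return ((int64_t(n) * H + h) * W + w) * C + c;
}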
at::Tensor nhwc_bn_addrelu_fwd_eval(
                       const at::Tensor& x,
                       const at::Tensor& z,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& ret_cta,
                       const int bn_group,
                       const float momentum,
                       const float epsilon) {
  const int N = x.size(0);
  const int H = x.size(1);
  const int W = x.size(2);
  const int C = x.size(3);

  // Allocate output tensor
  at::Tensor y = at::empty({N, H, W, C}, x.options());

  // Create wrapper
  NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();

  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.data<at::Half>(),
                             nullptr,
                             y.data<at::Half>(),
                             nullptr,
                             z.data<at::Half>(),
                             nullptr);

  bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
  bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(nullptr);
  workspace.push_back(nullptr);
  workspace.push_back(nullptr);

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[3];
  void* retired_ctas = ret_cta.data<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  // Don't fuse in ReLU for now at least
  bn->fwdInference(stream);

  return y;
}
std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
                       const at::Tensor& x,
                       const at::Tensor& dy,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& bitmask,
                       const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       void* my_data,
                       void* pair_data,
                       void* pair_data2,
                       void* pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
                       const int occupancy,
                       const int grid_dim_x,
                       const bool coop) {
  // shape
  const int N = x.size(0);
  const int H = x.size(1);
  const int W = x.size(2);
  const int C = x.size(3);

  // generating new magic number and use that for sync
  int* magic = magic_tensor.data<int>();
  *magic = (*magic + 1) & 0xff;

  // outputs
  at::Tensor x_grad, z_grad, scale_grad, bias_grad;

  // Allocate outputs
  x_grad = at::empty_like(x);
  z_grad = at::empty_like(x);
  scale_grad = at::empty_like(scale);
  bias_grad = at::empty_like(bias);

  // Create wrapper
  NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();

  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.data<at::Half>(),
                             x_grad.data<at::Half>(),
                             nullptr,
                             dy.data<at::Half>(),
                             nullptr,
                             z_grad.data<at::Half>());

  bn->setWeightPointers({scale.data<float>(), bias.data<float>()},
                        {scale_grad.data<float>(), bias_grad.data<float>()});
  bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(minibatch_mean.data<float>());
  workspace.push_back(minibatch_inv_var.data<float>());
  workspace.push_back(bitmask.data<int32_t>());

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[3];
  void* retired_ctas = ret_cta.data<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  bn->dgrad(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic,
            occupancy, grid_dim_x, coop);

  return std::vector<at::Tensor>{x_grad, z_grad, scale_grad, bias_grad};
}
int nhwc_bn_addrelu_fwd_occupancy() {
  int device_id = -1;
  cudaGetDevice(&device_id);

  // max occupancy supported by the code is 2
  return NhwcBatchNormAddRelu::smem_driven_fwd_occupancy(device_id, 2);
}

int nhwc_bn_addrelu_bwd_occupancy() {
  int device_id = -1;
  cudaGetDevice(&device_id);

  // max occupancy supported by the code is 2
  return NhwcBatchNormAddRelu::smem_driven_bwd_occupancy(device_id, 2);
}
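// Exposition-only sketch of how a caller might turn these occupancy queries into a
// grid_dim_x for the cooperative launch: scale the SM count by the achievable occupancy
// and keep a few CTAs free, in the spirit of NHWC_BATCHNORM_LAUNCH_MARGIN_* from the
// kernel header. The margin value and this exact policy are assumptions, not a description
// of what the Python wrapper actually does.
static inline int sketch_grid_dim_x(int occupancy, int launch_margin = 3) {
  int device_id = -1;
  cudaGetDevice(&device_id);
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);
  int ctas = prop.multiProcessorCount * occupancy - launch_margin;
  return ctas > 0 ? ctas : 1;
}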
apex/contrib/csrc/groupbn/batch_norm_add_relu.h
0 → 100644
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm_add_relu.h
* \brief CUDA NHWC Batch Normalization code with fused addition
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#include <cudnn.h>
#include <algorithm>
#include <vector>
#include <string>
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#define VERBOSE_DEFAULT false
class NhwcBatchNormAddRelu {
 public:
  NhwcBatchNormAddRelu() {
    name_ = "nhwc_batchnormaddrelu";
    createTensorDescriptor(&X_tensor_desc_);
    createTensorDescriptor(&Y_tensor_desc_);
  }

  ~NhwcBatchNormAddRelu() {
    destroyTensorDescriptor(X_tensor_desc_);
    destroyTensorDescriptor(Y_tensor_desc_);
  }

  void die() {
    std::cerr << "batchnormaddrelu not initialized" << std::endl;
    exit(-1);
  }
  void fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
           const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
  void dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
             const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
  void fwdInference(cudaStream_t stream);
  dim3 calc_fwd_grid(int *loop, const int grid_dim_x);
  dim3 calc_bwd_grid(int *loop, const int grid_dim_x);
  void setInputDescriptor(const cudnnTensorFormat_t format,
                          const cudnnDataType_t     data_type,
                          int n, int c, int h, int w, int bn_group) {
    m_ = n * h * w;
    int m_bn_adjusted = m_ * bn_group;
    c_ = c;
    // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.
    svar_inv_count_ = 1.f / m_bn_adjusted;
    // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).
    int divisor = m_bn_adjusted - 1;
    // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.
    rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor;
    setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
  }
  void setOutputDescriptor(const cudnnTensorFormat_t format,
                           const cudnnDataType_t     data_type,
                           int n, int c, int h, int w) {
    setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
  }
  const std::vector<size_t> numWorkspaceBytes() const;

  void setWorkspacePointers(
      const std::vector<void*>&  workspace,
      const std::vector<size_t>& num_workspace_bytes);

  void setInputOutputPointers(void* X, void* dX, void* Y, void* dY, void* addend, void* dAddend) {
    X_       = X;
    dX_      = dX;
    Y_       = Y;
    dY_      = dY;
    addend_  = addend;
    dAddend_ = dAddend;
  }
// Sets the pointers for the scale and weight (in that order) data and derivative buffers.
  void setWeightPointers(const std::vector<void*>& weight_pointers,
                         const std::vector<void*>& deriv_pointers) {
    assert(weight_pointers.size() == 2);
    assert(deriv_pointers.size()  == 2);
    scale_  = static_cast<float*>(weight_pointers[0]);
    bias_   = static_cast<float*>(weight_pointers[1]);
    dscale_ = static_cast<float*>(deriv_pointers[0]);
    dbias_  = static_cast<float*>(deriv_pointers[1]);
  }
// Sets the pointers for the population mean and variance buffers, in that order.
  void setParameterPointers(const std::vector<void*>& param_pointers) {
    assert(param_pointers.size() == 2);
    population_mean_     = static_cast<float*>(param_pointers[0]);
    population_variance_ = static_cast<float*>(param_pointers[1]);
  }
  void setConstants(const double exp_avg_factor, const double eps) {
    exp_avg_factor_ = exp_avg_factor;
    eps_ = eps;
  }
  void processCudnnStatus(const cudnnStatus_t& status,
                          const std::string& string = std::string(),
                          bool verbose = VERBOSE_DEFAULT) {
    if (status != CUDNN_STATUS_SUCCESS)
      LOG(FATAL) << string << " " << cudnnGetErrorString(status);
    else if (verbose)
      LOG(INFO) << string << " " << cudnnGetErrorString(status);
  }
  void checkCudaStatus(const std::string& string = std::string(),
                       bool verbose = VERBOSE_DEFAULT) {
    cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess)
      LOG(FATAL) << string << " " << cudaGetErrorString(status);
    else if (verbose)
      LOG(INFO) << string << " " << cudaGetErrorString(status);
  }
  size_t size_retired_ctas(int grid_y) const {
    // Note that the value of max_grid_y to handle known GPUs is about 160.
    const int max_grid_y = 1024;
    if (grid_y > max_grid_y)
      LOG(INFO) << "GPU capabilities exceeds assumptions.";
    const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);
    // Since the region will be initialized once and used for many kernels,
    // the idea is to return an ample size that will cover all uses.
    return retired_cta_bytes;
  }
  cudnnTensorDescriptor_t X_tensor_desc_ = nullptr;
  cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr;

  void*  X_ = nullptr;
  void* dX_ = nullptr;
  void*  Y_ = nullptr;
  void* dY_ = nullptr;
  void*  addend_ = nullptr;
  void* dAddend_ = nullptr;
  // Learned scale and bias weights.
  float* scale_  = nullptr;
  float* dscale_ = nullptr;
  float* bias_   = nullptr;
  float* dbias_  = nullptr;

  // Computed population mean and variance parameters.
  float* population_mean_     = nullptr;
  float* population_variance_ = nullptr;

  // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).
  float* minibatch_mean_     = nullptr;
  float* minibatch_variance_ = nullptr;

  int m_ = 0;  // Number of values per channel that BN is normalizing.
  int c_ = 0;  // Number of channels over which BN is normalizing.

  float svar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get saved variance
  float rvar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get running variance

  double exp_avg_factor_ = 0.;
  double eps_            = 0.;
  std::string name_;
private:
  void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,
                           cudnnTensorFormat_t format,
                           cudnnDataType_t     data_type,
                           int n, int c, int h, int w) {
    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
    status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);
    processCudnnStatus(status, "set tensor descriptor");
  }

  void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {
    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
    status = cudnnCreateTensorDescriptor(descriptor);
    processCudnnStatus(status, "create tensor_descriptor");
  }

  void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {
    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
    status = cudnnDestroyTensorDescriptor(descriptor);
    processCudnnStatus(status, "destroy tensor_descriptor");
  }
protected:
  float *partial_sums_   = nullptr;
  int   *partial_counts_ = nullptr;
  int   *retired_ctas_   = nullptr;
  unsigned int *relu_bitmask_ = nullptr;

  void _setFwdParams(NhwcBatchNormFwdParams *params) const;
  void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;
  void _setBwdParams(NhwcBatchNormBwdParams *params) const;
// @todo: ability to configure these?
// Kernel params
  static const int USE_ONLINE_APPROACH = 1;
  static const int THREADS_PER_CTA = 512;
  static const int THREADS_PER_PIXEL = 16;
  static const int C_ELEMENTS_PER_CTA = 64;
  static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
  static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;

  typedef uint16_t StorageType;
  // increasing this to 6 causes spills in fwd kernel!
  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;
  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;
  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;
  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;

  static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \
      PIXELS_PER_THREAD_IN_SMEM_FWD;
  static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \
      PIXELS_PER_THREAD_IN_SMEM_BWD;
  static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;
// Derived params
  static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD*THREADS_PER_CTA* \
      ELEMENTS_PER_LDG*sizeof(StorageType);
  static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD*THREADS_PER_CTA* \
      ELEMENTS_PER_LDG*2*sizeof(StorageType);
  static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
  static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \
      PIXELS_PER_THREAD_FWD;
  static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \
      PIXELS_PER_THREAD_BWD;
  static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA/THREADS_PER_PIXEL * \
      PIXELS_PER_THREAD_FWD_INFERENCE;

  // max grid.y in case of group bn is limited by exchange buffer size
  static const int MAX_GBN_BLOCK_Y = 256;
// Helper function to launch the forward kernel.
// We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel
// version that was compiled with that occupancy in its launch bounds. This way, we avoid
// needless register spills.
  void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
                          dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
cudaLaunchKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
    if (outer_loops == 1) {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(1, false, true, 2, coop);
      else
        LAUNCH_FWD_KERNEL(1, false, true, 1, coop);
    } else {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(0, false, true, 2, coop);
      else
        LAUNCH_FWD_KERNEL(0, false, true, 1, coop);
    }
#undef LAUNCH_FWD_KERNEL
}
// Helper function to launch the backward kernel.
  void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
                          dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {
#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto bwd_add_relu_func = nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_add_relu_func, \
cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + \
" bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using BWD_ADD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
    if (outer_loops == 1) {
      if (occupancy >= 2)
        LAUNCH_BWD_ADD_RELU_KERNEL(1, 2, coop);
      else
        LAUNCH_BWD_ADD_RELU_KERNEL(1, 1, coop);
    } else {
      if (occupancy >= 2)
        LAUNCH_BWD_ADD_RELU_KERNEL(0, 2, coop);
      else
        LAUNCH_BWD_ADD_RELU_KERNEL(0, 1, coop);
    }
#undef LAUNCH_BWD_KERNEL
}
public:
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
  static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
    using namespace at::cuda::utils;
    int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
    int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
    return std::min(max_cta_per_sm, occupancy);
  }

  // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
  static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
    using namespace at::cuda::utils;
    int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
    int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
    return std::min(max_cta_per_sm, occupancy);
  }
};
const std::vector<size_t> NhwcBatchNormAddRelu::numWorkspaceBytes() const {
  assert(c_ > 0);

  // choose the max memory required between fwd/bwd passes
  int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);
  int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);
  int grid_x = max(grid_x_fwd, grid_x_bwd);
  int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);

  const size_t num_mean_bytes     = c_ * sizeof(float);
  const size_t num_variance_bytes = num_mean_bytes;

  int elems_per_group = ((m_ + 31) & ~31) * 2;
  int group_count = div_up(c_, C_ELEMENTS_PER_CTA);
  const size_t bitmask_bytes = elems_per_group * group_count * sizeof(unsigned int);

  const size_t size_sums          = grid_y*grid_x*THREADS_PER_PIXEL* \
      ELEMENTS_PER_LDG*2*sizeof(float);
  const size_t size_counts        = grid_y*grid_x*sizeof(int);

  return {num_mean_bytes, num_variance_bytes, bitmask_bytes,
          size_retired_ctas(grid_y), size_sums, size_counts};
}
void NhwcBatchNormAddRelu::setWorkspacePointers(
      const std::vector<void*>&  workspace,
      const std::vector<size_t>& num_workspace_bytes) {
  assert(workspace.size() == 6);
  assert(num_workspace_bytes.size() == 6);

  minibatch_mean_     = static_cast<float*>(workspace[0]);
  minibatch_variance_ = static_cast<float*>(workspace[1]);
  relu_bitmask_       = static_cast<unsigned int*>(workspace[2]);
  retired_ctas_       = static_cast<int*>(workspace[3]);
  partial_sums_       = static_cast<float*>(workspace[4]);
  partial_counts_     = static_cast<int*>(workspace[5]);
}
void NhwcBatchNormAddRelu::_setFwdParams(NhwcBatchNormFwdParams *params) const {
  params->gmem_src          = static_cast<uint16_t*>(X_);
  params->gmem_dst          = static_cast<uint16_t*>(Y_);
  params->gmem_src1         = static_cast<uint16_t*>(addend_);
  params->gmem_bias         = bias_;
  params->gmem_scale        = scale_;
  params->gmem_running_mean = population_mean_;
  params->gmem_running_var  = population_variance_;
  params->gmem_saved_mean   = minibatch_mean_;
  params->gmem_saved_var    = minibatch_variance_;
  params->gmem_relu_bitmask = relu_bitmask_;
  params->nhw               = m_;
  params->c                 = c_;
  params->svar_inv_count    = svar_inv_count_;
  params->rvar_inv_count    = rvar_inv_count_;
  params->gmem_sums         = partial_sums_;
  params->gmem_counts       = partial_counts_;
  params->gmem_retired_ctas = retired_ctas_;
  params->var_eps           = eps_;
  params->outer_loops       = 0;
  params->exp_avg_factor    = static_cast<float>(exp_avg_factor_);
  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNormAddRelu::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const {
  params->gmem_src   = static_cast<uint16_t*>(X_);
  params->gmem_dst   = static_cast<uint16_t*>(Y_);
  params->gmem_src1  = static_cast<uint16_t*>(addend_);
  params->gmem_bias  = bias_;
  params->gmem_scale = scale_;
  params->gmem_mean  = population_mean_;
  params->gmem_var   = population_variance_;
  params->nhw        = m_;
  params->c          = c_;
  params->var_eps    = eps_;
}
void NhwcBatchNormAddRelu::_setBwdParams(NhwcBatchNormBwdParams *params) const {
  params->gmem_src          = static_cast<uint16_t*>(X_);
  params->gmem_dy           = static_cast<uint16_t*>(dY_);
  params->gmem_dst          = static_cast<uint16_t*>(dX_);
  params->gmem_dst1         = static_cast<uint16_t*>(dAddend_);
  params->gmem_relu_bitmask = relu_bitmask_;
  params->gmem_dscale       = dscale_;
  params->gmem_dbias        = dbias_;
  params->gmem_scale        = scale_;
  params->gmem_bias         = bias_;
  params->gmem_saved_mean   = minibatch_mean_;
  params->gmem_saved_var    = minibatch_variance_;
  params->nhw               = m_;
  params->c                 = c_;
  params->svar_inv_count    = svar_inv_count_;
  params->gmem_sums         = partial_sums_;
  params->gmem_retired_ctas = retired_ctas_;
  params->outer_loops       = 0;
  params->c_blks            = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNormAddRelu::fwdInference(cudaStream_t stream) {
  bool ptrs_are_set =
      X_tensor_desc_ != nullptr
      && Y_tensor_desc_ != nullptr
      && scale_ != nullptr
      && bias_ != nullptr
      //      && minibatch_mean_ != nullptr
      //      && minibatch_variance_ != nullptr
      && population_mean_ != nullptr
      && population_variance_ != nullptr
      && X_ != nullptr
      //      && dX_ != nullptr
      && Y_ != nullptr
      && addend_ != nullptr
      //      && dY_ != nullptr
      //      && dscale_ != nullptr
      //      && dbias_ != nullptr
      && partial_sums_ != nullptr
      && partial_counts_ != nullptr;

  if (!ptrs_are_set)
    die();

  dim3 grid_dim;
  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE);
  grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);

  // @todo: maybe just move this inside initialize routine?
  NhwcBatchNormFwdInferenceParams params;
  _setFwdInferenceParams(&params);

  nhwc_batch_norm_fwd_inference
    <StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, true>
    <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
  checkCudaStatus(name_ + " fwd_inference-relu kernel");
}
dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int *loop, const int grid_dim_x) {
  dim3 grid_dim;
  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);
  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
  unsigned int max_grid_x = grid_dim_x;
  if (grid_dim.x <= max_grid_x) {
    *loop = 1;
    if (max_grid_x / grid_dim.x > 1) {
      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
      assert(grid_dim.y < MAX_GBN_BLOCK_Y);  // FIXME: turn into a loop
    } else {
      grid_dim.y = 1;
    }
  } else {
    grid_dim.x = max_grid_x;
    grid_dim.y = 1;
    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD*PIXELS_PER_LDG*grid_dim.x;
    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD*PIXELS_PER_LDG*grid_dim.x;
    *loop = div_up(nhw_in_regs, pixels_per_iteration);
  }
  return grid_dim;
}
dim3 NhwcBatchNormAddRelu::calc_bwd_grid(int *loop, const int grid_dim_x) {
  dim3 grid_dim;
  grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);
  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
  unsigned int max_grid_x = grid_dim_x;
  if (grid_dim.x <= max_grid_x) {
    *loop = 1;
    if (max_grid_x / grid_dim.x > 1) {
      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
      assert(grid_dim.y < MAX_GBN_BLOCK_Y);  // FIXME: turn into a loop
    } else {
      grid_dim.y = 1;
    }
  } else {
    grid_dim.x = max_grid_x;
    grid_dim.y = 1;
    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD*PIXELS_PER_LDG*grid_dim.x;
    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD*PIXELS_PER_LDG*grid_dim.x;
    *loop = div_up(nhw_in_regs, pixels_per_iteration);
  }
  return grid_dim;
}
void NhwcBatchNormAddRelu::fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2,
                               void* pair_data3, const int bn_group, const int magic,
                               const int occupancy, const int grid_dim_x, const bool coop) {
  bool ptrs_are_set =
      X_tensor_desc_ != nullptr
      && Y_tensor_desc_ != nullptr
      && scale_ != nullptr
      && bias_ != nullptr
      && minibatch_mean_ != nullptr
      && minibatch_variance_ != nullptr
      && relu_bitmask_ != nullptr
      && population_mean_ != nullptr
      && population_variance_ != nullptr
      && X_ != nullptr
      //      && dX_ != nullptr
      && Y_ != nullptr
      && addend_ != nullptr
      //      && dY_ != nullptr
      //      && dscale_ != nullptr
      //      && dbias_ != nullptr
      && partial_sums_ != nullptr
      && partial_counts_ != nullptr
      && retired_ctas_ != nullptr;

  if (!ptrs_are_set)
    die();

  // reset of retired_cta_count no longer needed

  NhwcBatchNormFwdParams params;
  _setFwdParams(&params);
  params.my_data = my_data;
  params.pair_datas[0] = pair_data;
  params.pair_datas[1] = pair_data2;
  params.pair_datas[2] = pair_data3;
  params.magic = magic;
  params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1);

  dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);
  _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);
}
void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2,
                                 void* pair_data3, const int bn_group, const int magic,
                                 const int occupancy, const int grid_dim_x, const bool coop) {
  bool ptrs_are_set =
      X_tensor_desc_ != nullptr
      && Y_tensor_desc_ != nullptr
      && scale_ != nullptr
      && bias_ != nullptr
      && minibatch_mean_ != nullptr
      && minibatch_variance_ != nullptr
      && relu_bitmask_ != nullptr
      //      && population_mean_ != nullptr
      //      && population_variance_ != nullptr
      && X_ != nullptr
      && dX_ != nullptr
      //      && Y_ != nullptr
      && dY_ != nullptr
      && dAddend_ != nullptr
      && dscale_ != nullptr
      && dbias_ != nullptr
      && retired_ctas_ != nullptr;

  if (!ptrs_are_set)
    die();

  // reset of retired_cta_count no longer needed

  NhwcBatchNormBwdParams params;
  _setBwdParams(&params);
  params.my_data = my_data;
  params.pair_datas[0] = pair_data;
  params.pair_datas[1] = pair_data2;
  params.pair_datas[2] = pair_data3;
  params.magic = magic;
  params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1);
  params.wgrad_coeff = 1.0 / bn_group;

  dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);
  _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
apex/contrib/csrc/groupbn/cuda_utils.h
0 → 100644
#include <ATen/cuda/CUDAContext.h>
#ifndef CUDA_UTILS_H
#define CUDA_UTILS_H
namespace at {
namespace cuda {
namespace utils {

static inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
  return getDeviceProperties(device_id)->sharedMemPerMultiprocessor;
}

}
}
}
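// Exposition-only alternative (a sketch, not used by the code above): the same per-SM
// shared memory figure can also be read straight from the CUDA runtime, without ATen.
static inline int MaxSharedMemoryPerMultiprocessorViaRuntime(int device_id) {
  int smem_bytes = 0;
  cudaDeviceGetAttribute(&smem_bytes, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id);
  return smem_bytes;
}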
#endif
apex/contrib/csrc/groupbn/interface.cpp
0 → 100644
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/ArrayRef.h>
#include <ATen/ScalarType.h>
#include "ATen/Scalar.h"
#ifndef VERSION_GE_1_1
#include "ATen/Type.h"
#endif
#include "ATen/Tensor.h"
#include "ATen/Storage.h"
#include "ATen/Generator.h"
namespace py = pybind11;
int64_t get_buffer_size(const int bn_sync_steps);

void* get_data_ptr(const at::Tensor& data);

void* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset);

void close_remote_data(const at::Tensor& handle);
at::Tensor nhwc_bn_fwd_train(
                       const at::Tensor& x,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       const bool fuse_relu,
                       void* my_data,
                       void* pair_data,
                       void* pair_data2,
                       void* pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
                       const int occupancy,
                       const int grid_dim_x,
                       const bool coop);
at::Tensor nhwc_bn_fwd_eval(
                       const at::Tensor& x,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& ret_cta,
                       const int bn_group,
                       const float momentum,
                       const float epsilon,
                       const bool fuse_relu);
std::vector<at::Tensor> nhwc_bn_bwd(
                       const at::Tensor& x,
                       const at::Tensor& dy,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       const bool fuse_relu,
                       void* my_data,
                       void* pair_data,
                       void* pair_data2,
                       void* pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
                       const int occupancy,
                       const int grid_dim_x,
                       const bool coop);
at::Tensor nhwc_bn_addrelu_fwd_train(
                       const at::Tensor& x,
                       const at::Tensor& z,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& bitmask,
                       const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       void* my_data,
                       void* pair_data,
                       void* pair_data2,
                       void* pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
                       const int occupancy,
                       const int grid_dim_x,
                       const bool coop);
at::Tensor nhwc_bn_addrelu_fwd_eval(
                       const at::Tensor& x,
                       const at::Tensor& z,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& ret_cta,
                       const int bn_group,
                       const float momentum,
                       const float epsilon);
std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
                       const at::Tensor& x,
                       const at::Tensor& dy,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& bitmask,
                       const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       void* my_data,
                       void* pair_data,
                       void* pair_data2,
                       void* pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
                       const int occupancy,
                       const int grid_dim_x,
                       const bool coop);
int nhwc_bn_fwd_occupancy();
int nhwc_bn_bwd_occupancy();

int nhwc_bn_addrelu_fwd_occupancy();
int nhwc_bn_addrelu_bwd_occupancy();
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("get_buffer_size", &get_buffer_size, "get_buffer_size");
  m.def("get_data_ptr", &get_data_ptr, "get_data_ptr");
  m.def("get_remote_data_ptr", &get_remote_data_ptr, "get_remote_data_ptr");
  m.def("close_remote_data", &close_remote_data, "close_remote_data");

  m.def("bn_fwd_nhwc", &nhwc_bn_fwd_train, "bn_fwd_nhwc");
  m.def("bn_fwd_eval_nhwc", &nhwc_bn_fwd_eval, "bn_fwd_eval_nhwc");
  m.def("bn_bwd_nhwc", &nhwc_bn_bwd, "bn_bwd_nhwc");

  m.def("bn_fwd_nhwc_occupancy", &nhwc_bn_fwd_occupancy, "bn_fwd_nhwc_occupancy");
  m.def("bn_bwd_nhwc_occupancy", &nhwc_bn_bwd_occupancy, "bn_bwd_nhwc_occupancy");

  m.def("bn_addrelu_fwd_nhwc", &nhwc_bn_addrelu_fwd_train, "bn_addrelu_fwd_nhwc");
  m.def("bn_addrelu_fwd_eval_nhwc", &nhwc_bn_addrelu_fwd_eval, "bn_addrelu_fwd_eval_nhwc");
  m.def("bn_addrelu_bwd_nhwc", &nhwc_bn_addrelu_bwd, "bn_addrelu_bwd_nhwc");

  m.def("bn_addrelu_fwd_nhwc_occupancy", &nhwc_bn_addrelu_fwd_occupancy, "bn_addrelu_fwd_nhwc_occupancy");
  m.def("bn_addrelu_bwd_nhwc_occupancy", &nhwc_bn_addrelu_bwd_occupancy, "bn_addrelu_bwd_nhwc_occupancy");
}
apex/contrib/csrc/groupbn/ipc.cu
0 → 100644
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include <cuda.h>
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
template<>
struct std::hash<cudaIpcMemHandle_t> {
  size_t operator() (const cudaIpcMemHandle_t& handle) const {
    size_t hash = 0;
    uint8_t* ptr = (uint8_t*)&handle;
    assert(sizeof(uint8_t) == 1);
    for (int i = 0; i < sizeof(cudaIpcMemHandle_t); i++) {
      hash += *ptr;
      ptr++;
    }
    return hash;
  }
};
template<>
struct std::equal_to<cudaIpcMemHandle_t> {
  bool operator() (const cudaIpcMemHandle_t &lhs,
                   const cudaIpcMemHandle_t &rhs) const {
    return (std::memcmp((void*)&lhs, (void*)&rhs,
                        sizeof(cudaIpcMemHandle_t)) == 0);
  }
};
namespace {

namespace gpuipc {
//from: src/operator/nn/cudnn/nhwc_batch_norm_kernel.h
// The number of threads per pixel.
const int THREADS_PER_PIXEL = 16;
// The number of elements per ldg.
const int ELEMENTS_PER_LDG = 4;
// The number of reducing ops, each uses its own space : mean, var, dscale, dbias
const int REDUCE_OPS = 4;
// Maximum block.y supported - limited due to buffer allocation
const int MAX_BLOCK_Y = 256;
const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;
const int BYTES_PER_ELEM = 4;
// Buffer size per sync step
const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET*THREADS_PER_PIXEL*2*ELEMENTS_PER_LDG*BYTES_PER_ELEM;
};
class IpcMemHandleRegistry {
 public:
  void* getPtr(const cudaIpcMemHandle_t& handle, int64_t offset) {
    if (registry_.count(handle) == 0) {
      registry_.insert(std::make_pair(handle, RegistryEntry()));
      registry_[handle].dev_ptr = ipcOpenMem(handle);
    }
    registry_[handle].ref_count++;
    return (((uint8_t*)registry_[handle].dev_ptr) + offset);
  }

  void releasePtr(const cudaIpcMemHandle_t& handle) {
    if (registry_.count(handle) == 0) {
    }
    if (--registry_[handle].ref_count == 0) {
      ipcCloseMem(registry_[handle].dev_ptr);
      registry_.erase(handle);
    }
  }

  struct RegistryEntry {
    void* dev_ptr;
    int   ref_count;
    RegistryEntry() : dev_ptr(NULL), ref_count(0) {}
  };

 protected:
  std::unordered_map<cudaIpcMemHandle_t, RegistryEntry> registry_;

  void* ipcOpenMem(const cudaIpcMemHandle_t& handle) {
    void *data;
    cudaIpcOpenMemHandle(&data, handle, cudaIpcMemLazyEnablePeerAccess);
    cudaCheckErrors("ipc init");
    return data;
  }

  void ipcCloseMem(void* dev_ptr) {
    cudaIpcCloseMemHandle(dev_ptr);
    cudaCheckErrors("ipc close");
  }
};

}
static IpcMemHandleRegistry ipc_mem_registry;

int64_t get_buffer_size(const int bn_sync_steps) {
  return bn_sync_steps * gpuipc::SINGLE_SYNC_BUFFER_BYTES;
}
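// Worked example (exposition only): with the gpuipc constants above,
// SINGLE_SYNC_BUFFER_BYTES = (4*256) * 16 * 2 * 4 * 4 = 524288 bytes (512 KiB) per sync step,
// so e.g. get_buffer_size(3) -- the sync_iters value used for an 8-GPU group -- reserves 1.5 MiB.
static_assert(4 * 256 * 16 * 2 * 4 * 4 == 524288, "sketch of the sync-buffer arithmetic");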
void
*
get_remote_data_ptr
(
const
at
::
Tensor
&
handle
,
const
int64_t
offset
)
{
cudaIpcMemHandle_t
my_handle
;
memcpy
((
unsigned
char
*
)(
&
my_handle
),
handle
.
data
<
uint8_t
>
(),
sizeof
(
my_handle
));
return
ipc_mem_registry
.
getPtr
(
my_handle
,
offset
);
}
void
close_remote_data
(
const
at
::
Tensor
&
handle
)
{
cudaIpcMemHandle_t
my_handle
;
memcpy
((
unsigned
char
*
)(
&
my_handle
),
handle
.
data
<
uint8_t
>
(),
sizeof
(
my_handle
));
ipc_mem_registry
.
releasePtr
(
my_handle
);
}
void
*
get_data_ptr
(
const
at
::
Tensor
&
data
)
{
return
data
.
data
<
uint8_t
>
();
}
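Note that ipc.cu only implements the import side of CUDA IPC: opening, ref-counting, and closing memory handles received from a peer process. For orientation, a minimal sketch of the matching export side is shown below; `export_sync_buffer` and its parameters are hypothetical names used for illustration and are not part of apex, and error handling is omitted.

#include <cuda_runtime.h>
#include <cstddef>

// Allocate a device buffer and export an IPC handle that a peer process can hand to
// get_remote_data_ptr() above (e.g. serialized through a CPU uint8 tensor).
static cudaIpcMemHandle_t export_sync_buffer(void** dev_ptr, size_t bytes) {
  cudaMalloc(dev_ptr, bytes);              // buffer to be shared across processes
  cudaMemset(*dev_ptr, 0, bytes);          // sync flags must start cleared
  cudaIpcMemHandle_t handle;
  cudaIpcGetMemHandle(&handle, *dev_ptr);  // the handle is a plain struct, copyable to other processes
  return handle;
}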
apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h
0 → 100644
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm_kernel.h
* \brief CUDA NHWC Batch Normalization code
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#include <stdint.h>
#include <algorithm>
#define DEVICE_FUNCTION static inline __device__
// CTA margin used by cooperative launch. Can be overridden by env var NHWC_BATCHNORM_LAUNCH_MARGIN.
#define NHWC_BATCHNORM_LAUNCH_MARGIN_MIN 3
#define NHWC_BATCHNORM_LAUNCH_MARGIN_DEFAULT NHWC_BATCHNORM_LAUNCH_MARGIN_MIN
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, int ELEMENTS_PER_LDG>
struct PackedStorage {
  enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG };
  typedef T Type;
};

template <int ELEMENTS_PER_LDG>
struct PackedStorage<uint16_t, ELEMENTS_PER_LDG> {
  enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG / 2 };
  typedef int Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void from_float(int (&dst)[N], const float (&src)[2*N]) {
  #pragma unroll
  for (int i = 0; i < N; ++i) {
    uint16_t lo, hi;
    asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(lo) : "f"(src[2*i+0]));
    asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(hi) : "f"(src[2*i+1]));
    asm volatile("mov.b32 %0, {%1, %2};" : "=r"(dst[i]) : "h"(lo), "h"(hi));
  }
}

template <int N>
DEVICE_FUNCTION void from_float(float (&dst)[N], const float (&src)[N]) {
  #pragma unroll
  for (int i = 0; i < N; ++i) {
    dst[i] = src[i];
  }
}

template <int N>
DEVICE_FUNCTION void to_float(float (&dst)[2*N], int (&src)[N]) {
  #pragma unroll
  for (int i = 0; i < N; ++i) {
    uint16_t lo, hi;
    asm volatile("mov.b32 {%0, %1}, %2;" : "=h"(lo), "=h"(hi) : "r"(src[i]));
    asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+0]) : "h"(lo));
    asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+1]) : "h"(hi));
  }
}

template <int N>
DEVICE_FUNCTION void to_float(float (&dst)[N], float (&src)[N]) {
  #pragma unroll
  for (int i = 0; i < N; ++i) {
    dst[i] = src[i];
  }
}
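// ------------------------------------------------------------------------------------------------
// Illustrative aside (not part of the original header): the PTX above converts between a pair of
// fp16 values packed into one 32-bit register and two floats. A sketch of the same round trip
// using the cuda_fp16 intrinsics (assumes <cuda_fp16.h> is available) would look roughly like:
//
//   __device__ inline int pack2_from_float(float a, float b) {
//     __half2 h = __floats2half2_rn(a, b);       // round-to-nearest fp32 -> fp16x2
//     return *reinterpret_cast<int*>(&h);        // reinterpret the packed half2 as a 32-bit int
//   }
//   __device__ inline float2 unpack2_to_float(int packed) {
//     __half2 h = *reinterpret_cast<__half2*>(&packed);
//     return __half22float2(h);                  // fp16x2 -> two fp32 values
//   }
// ------------------------------------------------------------------------------------------------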
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void ldg(int (&dst)[1], const uint16_t *gmem) {
  dst[0] = __ldg((const int*)gmem);
}

DEVICE_FUNCTION void ldg_stream(int (&dst)[1], const uint16_t *gmem) {
  unsigned int tmp;
  asm volatile("ld.global.cs.nc.s32 %0, [%1];" : "=r"(tmp) : "l"((const uint*)gmem));
  dst[0] = tmp;
}

DEVICE_FUNCTION void ldg(int (&dst)[2], const uint16_t *gmem) {
  int2 tmp = __ldg((const int2*)gmem);
  dst[0] = tmp.x;
  dst[1] = tmp.y;
}

DEVICE_FUNCTION void ldg_stream(int (&dst)[2], const uint16_t *gmem) {
  int2 tmp;
  asm volatile("ld.global.cs.nc.v2.s32 {%0,%1}, [%2];"
      : "=r"(tmp.x), "=r"(tmp.y) : "l"((const int2*)gmem));
  dst[0] = tmp.x;
  dst[1] = tmp.y;
}

template <int N>
DEVICE_FUNCTION void ldg(float (&dst)[N], const uint16_t *gmem) {
  int tmp[N/2];
  ldg(tmp, gmem);
  to_float(dst, tmp);
}

template <int N>
DEVICE_FUNCTION void ldg_stream(float (&dst)[N], const uint16_t *gmem) {
  int tmp[N/2];
  ldg_stream(tmp, gmem);
  to_float(dst, tmp);
}

DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[1]) {
  reinterpret_cast<int*>(gmem)[0] = src[0];
}

DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[1]) {
  unsigned int tmp = src[0];
  asm volatile("st.global.cs.s32 [%0], %1;" :: "l"((uint*)gmem), "r"(tmp));
}

DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[2]) {
  reinterpret_cast<int2*>(gmem)[0] = make_int2(src[0], src[1]);
}

DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[2]) {
  asm volatile("st.global.cs.v2.s32 [%0], {%1,%2};"
      :: "l"((uint*)gmem), "r"(src[0]), "r"(src[1]));
}

template <int N>
DEVICE_FUNCTION void stg(uint16_t *gmem, float (&src)[N]) {
  int tmp[N/2];
  from_float(tmp, src);
  stg(gmem, tmp);
}

template <int N>
DEVICE_FUNCTION void stg_stream(uint16_t *gmem, float (&src)[N]) {
  int tmp[N/2];
  from_float(tmp, src);
  stg_stream(gmem, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_gmem(float (&dst)[2], const float *gmem, int idx) {
  float2 tmp = __ldg(reinterpret_cast<const float2*>(&gmem[2*idx]));
  dst[0] = tmp.x;
  dst[1] = tmp.y;
}

DEVICE_FUNCTION void read_from_gmem(float (&dst)[4], const float *gmem, int idx) {
  float4 tmp = __ldg(reinterpret_cast<const float4*>(&gmem[4*idx]));
  dst[0] = tmp.x;
  dst[1] = tmp.y;
  dst[2] = tmp.z;
  dst[3] = tmp.w;
}

DEVICE_FUNCTION void read_from_smem(float (&x)[2], const float *smem, int idx) {
  float2 tmp = *(const float2*)&smem[2*idx];
  x[0] = tmp.x;
  x[1] = tmp.y;
}

DEVICE_FUNCTION void read_from_smem(int (&x)[1], const int *smem, int idx) {
  x[0] = smem[idx];
}

DEVICE_FUNCTION void read_from_smem(float (&x)[4], const float *smem, int idx) {
  float4 tmp = *(const float4*)&smem[4*idx];
  x[0] = tmp.x;
  x[1] = tmp.y;
  x[2] = tmp.z;
  x[3] = tmp.w;
}

DEVICE_FUNCTION void read_from_smem(int (&x)[2], const int *smem, int idx) {
  int2 tmp = *(const int2*)&smem[2*idx];
  x[0] = tmp.x;
  x[1] = tmp.y;
}

DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[2]) {
  reinterpret_cast<float2*>(&gmem[2*idx])[0] = make_float2(src[0], src[1]);
}

DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[4]) {
  reinterpret_cast<float4*>(&gmem[4*idx])[0] = make_float4(src[0], src[1], src[2], src[3]);
}

DEVICE_FUNCTION void scaled_write_to_gmem(float *gmem, int idx, const float (&src)[4], const float coeff) {
  reinterpret_cast<float4*>(&gmem[4*idx])[0] =
      make_float4(src[0]*coeff, src[1]*coeff, src[2]*coeff, src[3]*coeff);
}

DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[2]) {
  reinterpret_cast<float2*>(&smem[2*idx])[0] = make_float2(x[0], x[1]);
}

DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[1]) {
  smem[idx] = x[0];
}

DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[4]) {
  reinterpret_cast<float4*>(&smem[4*idx])[0] = make_float4(x[0], x[1], x[2], x[3]);
}

DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[2]) {
  reinterpret_cast<int2*>(&smem[2*idx])[0] = make_int2(x[0], x[1]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void zero_array(int (&dst)[N]) {
  #pragma unroll
  for (int i = 0; i < N; ++i) {
    dst[i] = 0;
  }
}

template <int N>
DEVICE_FUNCTION void zero_array(float (&dst)[N]) {
  #pragma unroll
  for (int i = 0; i < N; ++i) {
    dst[i] = 0.f;
  }
}

template <int N>
DEVICE_FUNCTION void add(float (&x)[N], const float (&y)[N]) {
  #pragma unroll
  for (int i = 0; i < N; ++i) {
    x[i] += y[i];
  }
}

template <int N>
DEVICE_FUNCTION void multiply(float (&x)[N], const float (&y)[N]) {
  #pragma unroll
  for (int i = 0; i < N; ++i) {
    x[i] *= y[i];
  }
}

template <int N>
DEVICE_FUNCTION void scale_(float (&x)[N], float scalar) {
  #pragma unroll
  for (int i = 0; i < N; ++i) {
    x[i] *= scalar;
  }
}

template <int N>
DEVICE_FUNCTION void normalize(float (&x)[N], const float (&bias)[N],
                               const float (&scale)[N], const float (&m1)[N]) {
  #pragma unroll
  for (int i = 0; i < N; ++i) {
    x[i] = bias[i] + scale[i] * (x[i] - m1[i]);
  }
}

template <typename Storage>
DEVICE_FUNCTION Storage relu(Storage in) {
  Storage zero = (Storage)0.f;
  return (in < zero) ? zero : in;
}

template <int N>
DEVICE_FUNCTION void relu_activation(float (&x)[N]) {
  #pragma unroll
  for (int i = 0; i < N; ++i) {
    x[i] = relu(x[i]);
  }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int THREADS_PER_CTA>
DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
                                        void* params_my_data, void** params_pair_datas, int off,
                                        const int magic, const int sync_iters) {
  // The size of a warp.
  const int THREADS_PER_WARP = 32;
  // The number of warps in a CTA.
  const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
  // The number of threads per pixel.
  const int THREADS_PER_PIXEL = 16;
  // The number of elements per ldg.
  const int ELEMENTS_PER_LDG = 4;
  // The number of reducing ops, each uses its own space : mean, var, dscale, dbias
  const int REDUCE_OPS = 4;
  // Maximum block.y supported - limited due to buffer allocation
  const int MAX_BLOCK_Y = 256;
  const int MAX_OFFSET = REDUCE_OPS * MAX_BLOCK_Y;
  // The warp decomposition.
  const int warp_id = threadIdx.x / THREADS_PER_WARP;
  const int lane_id = threadIdx.x % THREADS_PER_WARP;
  // total size of data per sync iter
  const int data_total = MAX_OFFSET * THREADS_PER_PIXEL * ELEMENTS_PER_LDG * 2;

  #pragma unroll
  for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
    x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id);
  }
  // The warp leaders, write to SMEM.
  if (lane_id < THREADS_PER_PIXEL) {
    write_to_smem(smem, warp_id * THREADS_PER_PIXEL + lane_id, x);
  }
  // The data is in SMEM. Do the final reduction.
  __syncthreads();
  // The 1st warp does all the work.
  // We do the final reduction each half-warp sequentially reduces the final values.
  if (warp_id == 0) {
    read_from_smem(x, smem, threadIdx.x);
    #pragma unroll
    for (int offset = 1; offset < WARPS_PER_CTA / (THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
      float y[ELEMENTS_PER_LDG];
      // Read the mean and variance from the other pixel.
      read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_WARP);
      // Compute the updated sum.
      add(x, y);
    }
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id);
    }
    // Make sure the data was read from SMEM.
    __syncwarp();
    // Store the final values.
    if (threadIdx.x < THREADS_PER_PIXEL) {
      // probably could do it earlier, before sync
      for (int sync_iter = 0; sync_iter < sync_iters; ++sync_iter) {
        //float* params_pair_data = (reinterpret_cast<float**>(params_pair_datas))[sync_iter];
        void* params_pair_data = params_pair_datas[sync_iter];
        // skip the space consumed by previous sync iterations
        const int xbuf_offset = sync_iter * data_total;
        // data starts after flags, but have to skip previous
        const int data_offset = xbuf_offset
            + off * ELEMENTS_PER_LDG * THREADS_PER_PIXEL * 2
            + ELEMENTS_PER_LDG * threadIdx.x * 2;
        // after sums for this GPU were computed, let CTA0 broadcast the sum to over GPU
        if (blockIdx.x == 0) {
          volatile float* write_data = &((reinterpret_cast<float*>(params_pair_data))[data_offset]);
          // write the data to memory region to be reflected to other GPU
          asm volatile("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};"
              :: "l"(write_data), "f"(x[0]), "r"(magic), "f"(x[2]), "r"(magic));
          asm volatile("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};"
              :: "l"(write_data + 4), "f"(x[1]), "r"(magic), "f"(x[3]), "r"(magic));
        }
        // now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU
        volatile float* read_data = &((reinterpret_cast<float*>(params_my_data))[data_offset]);
        float other[4];
        uint32_t other_flag_a, other_flag_b;
        do {
          asm volatile("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
              : "=f"(other[0]), "=r"(other_flag_a), "=f"(other[2]), "=r"(other_flag_b)
              : "l"(read_data));
        } while ((other_flag_a != magic) || (other_flag_b != magic));
        do {
          asm volatile("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
              : "=f"(other[1]), "=r"(other_flag_a), "=f"(other[3]), "=r"(other_flag_b)
              : "l"(read_data + 4));
        } while ((other_flag_a != magic) || (other_flag_b != magic));
        add(x, other);
      }
      // finally, after syncing up and accounting for partial sums from
      // other GPUs as required, write the result
      write_to_smem(smem, threadIdx.x, x);
    }
  }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int THREADS_PER_CTA>
DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {
  // The size of a warp.
  const int THREADS_PER_WARP = 32;
  // The number of warps in a CTA.
  const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
  // The number of threads per pixel.
  const int THREADS_PER_PIXEL = 8;
  // The number of elements per ldg.
  const int ELEMENTS_PER_LDG = 4;
  // The warp decomposition.
  const int warp_id = threadIdx.x / THREADS_PER_WARP;
  const int lane_id = threadIdx.x % THREADS_PER_WARP;

  #pragma unroll
  for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
    x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id);
    x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL * 2 + lane_id);
  }
  // The warp leaders, write to SMEM.
  if (lane_id < THREADS_PER_PIXEL) {
    write_to_smem(smem, warp_id * THREADS_PER_PIXEL + lane_id, x);
  }
  // The data is in SMEM. Do the final reduction.
  __syncthreads();
  // The 1st warp does all the work.
  // We do the final reduction each half-warp sequentially reduces the final values.
  if (warp_id == 0) {
    read_from_smem(x, smem, threadIdx.x);
    #pragma unroll
    for (int offset = 1; offset < WARPS_PER_CTA / (THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
      float y[ELEMENTS_PER_LDG];
      // Read the mean and variance from the other pixel.
      read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_WARP);
      // Compute the updated sum.
      add(x, y);
    }
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id);
      x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL * 2 + lane_id);
    }
    // Make sure the data was read from SMEM.
    __syncwarp();
    // Store the final values.
    if (threadIdx.x < THREADS_PER_PIXEL) {
      write_to_smem(smem, threadIdx.x, x);
    }
  }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG>
DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
  // The size of a warp.
  const int THREADS_PER_WARP = 32;
  // The number of warps in a CTA.
  const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
  // The number of pixels computed by a single warp.
  const int PIXELS_PER_WARP = THREADS_PER_WARP / THREADS_PER_PIXEL;
  // The position in the warp.
  const int nhw_in_warp = nhw % PIXELS_PER_WARP;
  // The C in the warp.
  const int c_in_warp = threadIdx.x % THREADS_PER_PIXEL;

  // Store the values to shared memory.
  write_to_smem(smem, threadIdx.x, x);

  // Compute the parallel sums.
  for (int offset = PIXELS_PER_WARP / 2; offset > 0; offset /= 2) {
    // NOP.
    __syncwarp();
    // Read the running sum from the other thread.
    float y[ELEMENTS_PER_LDG];
    if (nhw_in_warp < offset) {
      read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_PIXEL);
    }
    // Compute the updated sum.
    add(x, y);
    // NOP.
    __syncwarp();
    // Update the sum in SMEM.
    if (offset > 1 && nhw_in_warp < offset) {
      write_to_smem(smem, threadIdx.x, x);
    }
  }

  // The warps are done. Do the final reduction at the CTA level.
  __syncthreads();

  // The warp leaders, write to SMEM.
  const int idx = (threadIdx.x / THREADS_PER_WARP) * THREADS_PER_PIXEL + c_in_warp;
  if (nhw_in_warp == 0) {
    write_to_smem(smem, idx, x);
  }

  // The data is in SMEM. Do the final reduction.
  __syncthreads();

  // Read the 1st element to prepare the work.
  if (nhw < WARPS_PER_CTA / 2) {
    read_from_smem(x, smem, threadIdx.x);
  }

  // We have the running mean and running m2. Let's build the mean/var of the CTA.
  for (int offset = WARPS_PER_CTA / 2; offset > 0; offset /= 2) {
    // NOP.
    __syncwarp();
    // Read the mean and variance from the other pixel.
    float y[ELEMENTS_PER_LDG];
    if (nhw < offset) {
      read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_PIXEL);
    }
    // Compute the updated sum.
    add(x, y);
    // NOP.
    __syncwarp();
    // Store the mean/var for the different pixels.
    if (nhw < offset) {
      write_to_smem(smem, threadIdx.x, x);
    }
  }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG>
struct ParallelSums {
  template <int THREADS_PER_CTA>
  DEVICE_FUNCTION void dispatch(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
    parallel_sums<THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG>(smem, x, nhw);
  }
};

template <>
struct ParallelSums<16, 4> {
  template <int THREADS_PER_CTA>
  DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {
    parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, 0, 0, 0, 0, 0);
  }

  template <int THREADS_PER_CTA>
  DEVICE_FUNCTION void dispatchX(float *smem, float (&x)[4], int nhw,
                                 void* params_my_data, void** params_pair_datas, int off,
                                 const int magic, const unsigned int& sync_iters) {
    parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw,
                                        params_my_data, params_pair_datas, off, magic, sync_iters);
  }
};

template <>
struct ParallelSums<8, 4> {
  template <int THREADS_PER_CTA>
  DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {
    parallel_sums_8x4<THREADS_PER_CTA>(smem, x, nhw);
  }
};
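// ------------------------------------------------------------------------------------------------
// Illustrative aside (not part of the original header): the dispatchers above reduce four partial
// sums per channel group across the CTA using lane shuffles plus shared memory. The core
// warp-level building block is the offset shuffle; a minimal standalone sketch of a full-warp
// float reduction using only __shfl_down_sync would be:
//
//   __device__ inline float warp_reduce_sum(float v) {
//     // All 32 lanes participate; after the loop, lane 0 holds the warp-wide sum.
//     for (int offset = 16; offset > 0; offset /= 2) {
//       v += __shfl_down_sync(0xffffffffU, v, offset);
//     }
//     return v;
//   }
// ------------------------------------------------------------------------------------------------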
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline int div_up(int m, int n) {
  return (m + n - 1) / n;
}

// It is expected that all threads in the CTA enter this function!
DEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expected_count, bool master) {
  // Register the CTA.
  if (threadIdx.x == 0) {
    // Issue the membar.
    __threadfence();
    // Notify that the CTA is done.
    int val_to_add = 1;
    if (master) {
      val_to_add = -(expected_count - 1);
    }
    atomicAdd(gmem_retired_ctas, val_to_add);
  }

  // Are all CTAs done?
  if (threadIdx.x == 0) {
    int retired_ctas = -1;
    do {
      __threadfence();
      asm volatile("ld.global.cg.b32 %0, [%1];" : "=r"(retired_ctas) : "l"(gmem_retired_ctas));
    } while (retired_ctas != 0);
  }
  __syncthreads();
}
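// ------------------------------------------------------------------------------------------------
// Illustrative aside (not part of the original header): the typical call site, as used by the
// kernels below, passes the per-c-block counter, the number of CTAs in the x dimension, and marks
// CTA 0 as the master whose negative contribution returns the counter to zero for the next round:
//
//   inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);
// ------------------------------------------------------------------------------------------------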
////////////////////////////////////////////////////////////////////////////////////////////////////
struct NhwcBatchNormFwdInferenceParams {
  // The input/output tensors.
  uint16_t *gmem_src, *gmem_dst, *gmem_src1;
  // the final mean and variance as calculated during the training process
  float *gmem_mean, *gmem_var;
  // The bias/scale.
  float *gmem_bias, *gmem_scale;
  // The dimensions.
  int nhw, c;
  // epsilon
  float var_eps;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// No DESIRED_OCCUPANCY launch bounds needed, as this is not launched cooperatively
template <typename Storage, int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG,
          bool USE_RELU, bool USE_ADD_RELU>
__global__ __launch_bounds__(THREADS_PER_CTA)
void nhwc_batch_norm_fwd_inference(NhwcBatchNormFwdInferenceParams params) {
  // The number of pixels loaded in a single LDG.
  const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
  // The number of C elements per CTA.
  const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG;

  // The start position in the NHW dimension where the CTA starts.
  const int cta_nhw_stride = gridDim.x * PIXELS_PER_LDG;
  // Compute the NHW coordinate of the thread in the CTA.
  const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
  // thread's starting point in NHW
  const int thread_nhw = thread_in_cta_nhw + blockIdx.x * PIXELS_PER_LDG;

  // The position in the C dimension where the CTA starts.
  const int cta_c = blockIdx.y * C_ELEMENTS_PER_CTA;
  // Compute the C coordinate of the thread in the CTA.
  const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
  // Compute the C coordinate of the thread.
  const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG;
  // Is the thread working on a valid C dimension?
  const int is_valid_c = thread_c < params.c;

  float mean[ELEMENTS_PER_LDG], var[ELEMENTS_PER_LDG];
  float scale[ELEMENTS_PER_LDG], bias[ELEMENTS_PER_LDG];
  zero_array(mean);
  zero_array(var);
  zero_array(scale);
  zero_array(bias);
  if (is_valid_c) {
    read_from_gmem(var, &params.gmem_var[cta_c], thread_in_cta_c);
    read_from_gmem(scale, &params.gmem_scale[cta_c], thread_in_cta_c);
    read_from_gmem(mean, &params.gmem_mean[cta_c], thread_in_cta_c);
    read_from_gmem(bias, &params.gmem_bias[cta_c], thread_in_cta_c);
  }

  // Update the scale with the stddev and eps.
  #pragma unroll
  for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
    scale[i] *= rsqrtf(var[i] + params.var_eps);
  }

  // The base pointers for reading/writing
  uint16_t *const gmem_src = &params.gmem_src[thread_c];
  uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
  const uint16_t *gmem_src1 = nullptr;
  if (USE_ADD_RELU) {
    gmem_src1 = &params.gmem_src1[thread_c];
  }

  // apply BN
  for (int nhw = thread_nhw; nhw < params.nhw; nhw += cta_nhw_stride) {
    float x_math[ELEMENTS_PER_LDG];
    zero_array(x_math);
    if (is_valid_c) {
      ldg(x_math, &gmem_src[nhw * params.c]);
    }

    // Normalize and apply activation function
    normalize(x_math, bias, scale, mean);
    if (USE_ADD_RELU) {
      float x1_math[ELEMENTS_PER_LDG];
      ldg(x1_math, &gmem_src1[nhw * params.c]);
      add(x_math, x1_math);
      relu_activation(x_math);
    } else if (USE_RELU) {
      relu_activation(x_math);
    }

    if (is_valid_c) {
      stg(&gmem_dst[nhw * params.c], x_math);
    }
  }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
struct NhwcBatchNormFwdParams {
  // The input/output tensors.
  uint16_t *gmem_src, *gmem_dst, *gmem_src1;
  // The bias/scale.
  float *gmem_bias, *gmem_scale;
  // running mean/var (refer BN API from cudnn doc)
  float *gmem_running_mean, *gmem_running_var;
  // saved mean/var (refer BN API from cudnn doc)
  float *gmem_saved_mean, *gmem_saved_var;
  // ReLU bitmask
  unsigned int *gmem_relu_bitmask;
  // The dimensions.
  int nhw, c;
  // factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
  float svar_inv_count;
  // factor to scale sum of squared errors to get running variance. Should be 1/nhw or 1/(nhw-1).
  float rvar_inv_count;
  // The buffer to do the reduction for mean, stddev and count.
  float *gmem_sums;
  // The buffer to count items in the different CTAs.
  int *gmem_counts;
  // The counters of retired CTAs.
  int *gmem_retired_ctas;
  // The epsilon to apply to the computation of the variance.
  float var_eps;
  // outer loop count
  int outer_loops;
  // exponential average factor
  float exp_avg_factor;
  // number of CTAs along .x dimension
  int c_blks;

  void* my_data;
  void* pair_datas[4];
  int magic;
  int sync_iters;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Storage
,
int
THREADS_PER_CTA
,
int
THREADS_PER_PIXEL
,
int
PIXELS_PER_THREAD_IN_REGISTERS
,
int
PIXELS_PER_THREAD_IN_SMEM
,
int
ELEMENTS_PER_LDG
,
int
USE_ONLINE_APPROACH
,
int
OUTER_LOOPS_
,
bool
USE_RELU
,
bool
USE_ADD_RELU
,
int
DESIRED_OCCUPANCY
>
__global__
__launch_bounds__
(
THREADS_PER_CTA
,
DESIRED_OCCUPANCY
)
void
nhwc_batch_norm_fwd
(
NhwcBatchNormFwdParams
params
)
{
// The number of pixels loaded in a single LDG.
const
int
PIXELS_PER_LDG
=
THREADS_PER_CTA
/
THREADS_PER_PIXEL
;
// The number of pixels computed per CTA stored in registers.
const
int
PIXELS_PER_CTA_IN_REGISTERS
=
PIXELS_PER_THREAD_IN_REGISTERS
*
PIXELS_PER_LDG
;
// The number of pixels computed per CTA stored in SMEM.
const
int
PIXELS_PER_CTA_IN_SMEM
=
PIXELS_PER_THREAD_IN_SMEM
*
PIXELS_PER_LDG
;
// The number of C elements per CTA.
const
int
C_ELEMENTS_PER_CTA
=
THREADS_PER_PIXEL
*
ELEMENTS_PER_LDG
;
// Shared memory to do CTA-wide parallel sums.
__shared__
float
smem
[
THREADS_PER_PIXEL
*
(
THREADS_PER_CTA
/
32
)
*
ELEMENTS_PER_LDG
];
// Compute the NHW coordinate of the thread in the CTA.
const
int
thread_in_cta_nhw
=
threadIdx
.
x
/
THREADS_PER_PIXEL
;
// The adapter for the storage.
typedef
PackedStorage
<
Storage
,
ELEMENTS_PER_LDG
>
PackedStorage_
;
// The data type for packed storage in SMEM.
typedef
typename
PackedStorage_
::
Type
PackedStorageType
;
// The number of elements in the packed storage.
const
int
PACKED_ELEMENTS_PER_LDG
=
PackedStorage_
::
PACKED_ELEMENTS_PER_LDG
;
// Registers to keep the data live for the persistent approach.
PackedStorageType
x_storage
[
PIXELS_PER_THREAD_IN_REGISTERS
][
PACKED_ELEMENTS_PER_LDG
];
// Shared memory buffer to store the extra pixels.
extern
__shared__
PackedStorageType
smem_storage_packed
[];
for
(
int
c_blk_index
=
blockIdx
.
y
;
c_blk_index
<
params
.
c_blks
;
c_blk_index
+=
gridDim
.
y
)
{
// The position in the NHW dimension where the CTA starts.
int
cta_nhw_regs
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_REGISTERS
;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int
cta_nhw_smem
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_SMEM
;
// The position in the C dimension where the CTA starts.
const
int
cta_c
=
c_blk_index
*
C_ELEMENTS_PER_CTA
;
// Compute the C coordinate of the thread in the CTA.
const
int
thread_in_cta_c
=
threadIdx
.
x
%
THREADS_PER_PIXEL
;
// Compute the C coordinate of the thread.
int
thread_c
=
cta_c
+
thread_in_cta_c
*
ELEMENTS_PER_LDG
;
// Is the thread working on a valid C dimension?
const
int
is_valid_c
=
thread_c
<
params
.
c
;
// Clamp thread_c so that we load from valid locations even if we don't use the value
if
(
!
is_valid_c
)
thread_c
=
params
.
c
-
4
;
// Single pass numerically stable algorithm, see:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
//
// n = 0, mean = 0.0, M2 = 0.0
//
// for x in data:
// n += 1
// delta = x - mean
// mean += delta/n
// delta2 = x - mean
// M2 += delta*delta2
//
// if n < 2:
// return float('nan')
// else:
// return M2 / (n - 1)
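// Illustrative aside (not part of the original kernel): the same single-pass (Welford) update,
// written as a tiny host-side C++ helper, to make the delta/delta2 bookkeeping below concrete:
//
//   struct Welford { long long n = 0; double mean = 0.0, m2 = 0.0; };
//   inline void welford_update(Welford& w, double x) {
//     w.n += 1;
//     double delta = x - w.mean;   // distance to the old mean
//     w.mean += delta / w.n;       // incremental mean update
//     double delta2 = x - w.mean;  // distance to the new mean
//     w.m2 += delta * delta2;      // accumulates the sum of squared deviations
//   }
//   // variance = w.m2 / w.n (population) or w.m2 / (w.n - 1) (sample), once w.n >= 2.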
// Register to store the number of elements read so far.
float
count
=
0.
f
,
mean
[
ELEMENTS_PER_LDG
],
m2
[
ELEMENTS_PER_LDG
];
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
mean
[
i
]
=
0.
f
;
m2
[
i
]
=
0.
f
;
}
// The number of elements loaded by this CTA.
int
cta_count
=
0
;
// The base pointer to load from.
const
uint16_t
*
gmem_src
=
&
params
.
gmem_src
[
thread_c
];
// outer loops
int
OUTER_LOOPS
=
OUTER_LOOPS_
==
1
?
1
:
params
.
outer_loops
;
// Load the batch of elements. Compute the mean/var across those elements.
const
int
pixels_per_iteration
=
PIXELS_PER_CTA_IN_REGISTERS
*
gridDim
.
x
;
if
(
OUTER_LOOPS_
!=
1
)
{
// We cannot load everything to store persistently, so let's makes sure registers and
// smem are fully utilized, offset is evenly divisible by 32
int
offset
=
(
pixels_per_iteration
*
OUTER_LOOPS
+
PIXELS_PER_CTA_IN_SMEM
*
gridDim
.
x
-
params
.
nhw
)
&
~
31
;
cta_nhw_regs
-=
offset
;
cta_nhw_smem
-=
offset
;
}
#pragma unroll 1
for
(
int
loop_i
=
0
;
loop_i
<
OUTER_LOOPS
;
++
loop_i
)
{
// The nhw position.
int
nhw_regs
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count
+=
max
(
min
(
nhw_regs
+
PIXELS_PER_CTA_IN_REGISTERS
,
params
.
nhw
)
-
max
(
nhw_regs
,
0
),
0
);
// Load the data and compute the local mean/sum and the variance.
if
(
USE_ONLINE_APPROACH
)
{
// Read the elements from memory.
float
is_valid
[
PIXELS_PER_THREAD_IN_REGISTERS
];
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
nhw_regs
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
zero_array
(
x_storage
[
i
]);
is_valid
[
i
]
=
0.
f
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
if
(
loop_i
==
OUTER_LOOPS
-
1
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
}
else
{
ldg
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
}
is_valid
[
i
]
=
1.
f
;
}
}
// Do the math.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
// Convert to float.
float
x_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
// Update the count.
count
+=
is_valid
[
i
];
// Invert the count.
float
inv_count
=
is_valid
[
i
]
?
1.
f
/
count
:
0.
f
;
// Update the mean and m2 using deltas.
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
float
delta0
=
x_math
[
j
]
-
mean
[
j
];
mean
[
j
]
+=
delta0
*
inv_count
;
float
delta1
=
x_math
[
j
]
-
mean
[
j
];
m2
[
j
]
+=
delta0
*
delta1
*
is_valid
[
i
];
}
}
}
else
{
// Read the elements from memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
nhw_regs
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
zero_array
(
x_storage
[
i
]);
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
if
(
loop_i
==
OUTER_LOOPS
-
1
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
}
else
{
ldg
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
}
count
+=
1.
f
;
}
}
// Sum the elements in registers.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
// Convert to float.
float
x_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
// Update the mean and m2 using deltas.
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
mean
[
j
]
+=
x_math
[
j
];
}
}
// Compute the mean.
float
inv_count
=
1.
f
/
count
;
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
mean
[
j
]
*=
inv_count
;
}
// Compute the variance.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
// Convert to float.
float
x_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
// Is it a valid pixel?
float
is_valid
=
i
<
static_cast
<
int
>
(
count
)
?
1.
f
:
0.
f
;
// Update the mean and m2 using deltas.
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
m2
[
j
]
+=
(
x_math
[
j
]
-
mean
[
j
])
*
(
x_math
[
j
]
-
mean
[
j
])
*
is_valid
;
}
}
}
}
// The elements to load and store in SMEM.
int
smem_nhw
=
OUTER_LOOPS
*
pixels_per_iteration
+
cta_nhw_smem
;
// Load elements from SMEM, update the CTA count.
int
pixels_in_smem
=
min
(
smem_nhw
+
PIXELS_PER_CTA_IN_SMEM
,
params
.
nhw
)
-
max
(
smem_nhw
,
0
);
if
(
pixels_in_smem
>
0
)
{
cta_count
+=
pixels_in_smem
;
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
float
is_pixel_valid
=
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
?
1.
f
:
0.
f
;
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
ldg_stream
(
x_storage_local
,
&
gmem_src
[(
is_pixel_valid
?
idx
:
0
)
*
params
.
c
]);
// The offset to store in SMEM.
const
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
// Store in SMEM.
write_to_smem
(
&
smem_storage_packed
[
offset
],
threadIdx
.
x
,
x_storage_local
);
// Update the count.
count
+=
is_pixel_valid
;
// Invert the count.
float
inv_count
=
is_pixel_valid
?
1.
f
/
count
:
0.
f
;
float
x_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
// Update the mean and m2 using deltas.
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
float
delta0
=
x_math
[
j
]
-
mean
[
j
];
mean
[
j
]
+=
delta0
*
inv_count
;
float
delta1
=
x_math
[
j
]
-
mean
[
j
];
m2
[
j
]
+=
delta0
*
delta1
*
is_pixel_valid
;
}
}
}
// We scale the mean by the number of elements. It brings more stability.
float
m1
[
ELEMENTS_PER_LDG
];
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
m1
[
i
]
=
mean
[
i
]
*
count
;
}
// Run the parallel sum accross the CTA to get the local sum.
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
m1
,
thread_in_cta_nhw
);
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
m1
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// Adjust the variance.
float
inv_cta_count
=
1.
f
/
static_cast
<
float
>
(
cta_count
);
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
float
mean_diff
=
m1
[
i
]
*
inv_cta_count
-
mean
[
i
];
m2
[
i
]
=
m2
[
i
]
+
mean_diff
*
mean_diff
*
count
;
}
// Run the parallel sum accross the CTA to get the local adjusted variance.
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
m2
,
thread_in_cta_nhw
);
// The workspace in global memory is distributed across the different CTA.
int
gmem_sums_offset
=
c_blk_index
*
gridDim
.
x
*
C_ELEMENTS_PER_CTA
*
2
;
// Write the data for the CTA to global memory.
float
*
gmem_sums
=
&
params
.
gmem_sums
[
gmem_sums_offset
];
if
(
threadIdx
.
x
<
THREADS_PER_PIXEL
)
{
const
int
idx
=
blockIdx
.
x
*
THREADS_PER_PIXEL
+
threadIdx
.
x
;
write_to_gmem
(
&
gmem_sums
[
0
],
idx
,
m1
);
write_to_gmem
(
&
gmem_sums
[
C_ELEMENTS_PER_CTA
*
gridDim
.
x
],
idx
,
m2
);
}
// The memory location to store the number of pixels per CTA.
int
*
gmem_counts
=
&
params
.
gmem_counts
[
c_blk_index
*
gridDim
.
x
];
if
(
threadIdx
.
x
==
0
)
{
gmem_counts
[
blockIdx
.
x
]
=
cta_count
;
}
// Read the bias and scale.
float
bias
[
ELEMENTS_PER_LDG
],
scale
[
ELEMENTS_PER_LDG
];
if
(
is_valid_c
)
{
read_from_gmem
(
bias
,
&
params
.
gmem_bias
[
cta_c
],
thread_in_cta_c
);
read_from_gmem
(
scale
,
&
params
.
gmem_scale
[
cta_c
],
thread_in_cta_c
);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int
*
gmem_retired_ctas
=
&
params
.
gmem_retired_ctas
[
c_blk_index
%
(
2
*
gridDim
.
y
)];
inter_block_sync
(
gmem_retired_ctas
,
gridDim
.
x
,
blockIdx
.
x
==
0
);
// Reset the mean to compute the global mean.
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
m1
[
i
]
=
0.
f
;
}
// Build the global mean.
#pragma unroll 1
for
(
int
idx
=
threadIdx
.
x
;
idx
<
THREADS_PER_PIXEL
*
gridDim
.
x
;
idx
+=
THREADS_PER_CTA
)
{
float
tmp
[
ELEMENTS_PER_LDG
];
read_from_gmem
(
tmp
,
gmem_sums
,
idx
);
add
(
m1
,
tmp
);
}
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
m1
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
3
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
m1
,
thread_in_cta_nhw
);
}
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
m1
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// Normalize the mean.
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
m1
[
i
]
=
m1
[
i
]
*
params
.
svar_inv_count
;
}
// Reset the variance.
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
m2
[
i
]
=
0.
f
;
}
// for add+relu fusion
const
uint16_t
*
gmem_src1
=
nullptr
;
if
(
USE_ADD_RELU
)
{
gmem_src1
=
&
params
.
gmem_src1
[
thread_c
];
}
// Build the global variance.
#pragma unroll 1
for
(
int
idx
=
threadIdx
.
x
;
idx
<
THREADS_PER_PIXEL
*
gridDim
.
x
;
idx
+=
THREADS_PER_CTA
)
{
// Read the means computed by different CTAs (again). Reuse tmp if we have 1 iteration.
float
tmp_mean
[
ELEMENTS_PER_LDG
],
tmp_var
[
ELEMENTS_PER_LDG
];
read_from_gmem
(
tmp_mean
,
&
gmem_sums
[
0
],
idx
);
read_from_gmem
(
tmp_var
,
&
gmem_sums
[
C_ELEMENTS_PER_CTA
*
gridDim
.
x
],
idx
);
// Read the number of pixels visited by a given CTA.
cta_count
=
__ldg
(
&
gmem_counts
[
idx
/
THREADS_PER_PIXEL
]);
// Compute the diff to update the variance.
float
mean_diff
[
ELEMENTS_PER_LDG
],
inv_cta_count
=
1.
f
/
static_cast
<
float
>
(
cta_count
);
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
mean_diff
[
i
]
=
m1
[
i
]
-
tmp_mean
[
i
]
*
inv_cta_count
;
}
// Update the variance.
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
m2
[
i
]
+=
tmp_var
[
i
]
+
mean_diff
[
i
]
*
mean_diff
[
i
]
*
static_cast
<
float
>
(
cta_count
);
}
}
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
m2
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
2
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
m2
,
thread_in_cta_nhw
);
}
__syncthreads
();
read_from_smem
(
m2
,
smem
,
thread_in_cta_c
);
// Finalize the stddev.
// becasue saved var and running var may have different denominator, we don't do it here
// scale_(m2, inv_count);
// store the saved mean/var
float
svarinv
[
ELEMENTS_PER_LDG
];
bool
is_valid_for_saving
=
is_valid_c
&&
blockIdx
.
x
==
0
&&
thread_in_cta_nhw
==
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
svarinv
[
i
]
=
rsqrtf
(
m2
[
i
]
*
params
.
svar_inv_count
+
params
.
var_eps
);
}
if
(
is_valid_for_saving
)
{
write_to_gmem
(
params
.
gmem_saved_mean
,
thread_c
/
ELEMENTS_PER_LDG
,
m1
);
write_to_gmem
(
params
.
gmem_saved_var
,
thread_c
/
ELEMENTS_PER_LDG
,
svarinv
);
}
// store the running mean/var
float
rmean
[
ELEMENTS_PER_LDG
],
rvar
[
ELEMENTS_PER_LDG
];
zero_array
(
rmean
);
zero_array
(
rvar
);
if
(
params
.
exp_avg_factor
!=
1.
f
&&
is_valid_for_saving
)
{
read_from_gmem
(
rmean
,
params
.
gmem_running_mean
,
thread_c
/
ELEMENTS_PER_LDG
);
read_from_gmem
(
rvar
,
params
.
gmem_running_var
,
thread_c
/
ELEMENTS_PER_LDG
);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
rmean
[
i
]
=
(
1.
f
-
params
.
exp_avg_factor
)
*
rmean
[
i
]
+
\
params
.
exp_avg_factor
*
m1
[
i
];
rvar
[
i
]
=
(
1.
f
-
params
.
exp_avg_factor
)
*
rvar
[
i
]
+
\
params
.
exp_avg_factor
*
(
m2
[
i
]
*
params
.
rvar_inv_count
);
}
if
(
is_valid_for_saving
)
{
write_to_gmem
(
params
.
gmem_running_mean
,
thread_c
/
ELEMENTS_PER_LDG
,
rmean
);
write_to_gmem
(
params
.
gmem_running_var
,
thread_c
/
ELEMENTS_PER_LDG
,
rvar
);
}
// Update the scale with the stddev and eps.
multiply
(
scale
,
svarinv
);
// The base pointer to write to.
uint16_t
*
const
gmem_dst
=
&
params
.
gmem_dst
[
thread_c
];
unsigned
int
*
const
gmem_relu_bitmask
=
params
.
gmem_relu_bitmask
+
((
params
.
nhw
+
31
)
&
~
31
)
*
2
*
c_blk_index
;
// Store the elements in registers.
#pragma unroll 1
for
(
int
loop_i
=
OUTER_LOOPS
-
1
;
loop_i
>=
0
;
--
loop_i
)
{
// The value for nhw.
int
out_nhw
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Normalize the elements and write to memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
const
bool
is_valid_nhw
=
static_cast
<
unsigned
int
>
(
idx
)
<
static_cast
<
unsigned
int
>
(
params
.
nhw
);
const
bool
is_valid
=
is_valid_nhw
&&
is_valid_c
;
// Convert to float.
float
x_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
// Normalize and apply activation function
normalize
(
x_math
,
bias
,
scale
,
m1
);
if
(
USE_ADD_RELU
)
{
float
x1_math
[
ELEMENTS_PER_LDG
];
ldg_stream
(
x1_math
,
&
gmem_src1
[(
is_valid
?
idx
:
0
)
*
params
.
c
]);
add
(
x_math
,
x1_math
);
unsigned
int
relu_mask
;
int
lane_id
=
threadIdx
.
x
&
31
;
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
bool
rectified
=
x_math
[
i
]
<
0.0
F
;
unsigned
int
local_relu_mask
=
__ballot_sync
(
0xFFFFFFFFU
,
rectified
);
if
(
lane_id
==
i
)
{
// Thread 0 remembers the relu_mask from the first time through this
// loop, Thread 1 the next, Thread 2 the next, and Thread 3 the last.
relu_mask
=
local_relu_mask
;
}
if
(
rectified
)
{
x_math
[
i
]
=
0.0
F
;
}
}
if
(
is_valid_nhw
&&
(
lane_id
<
ELEMENTS_PER_LDG
))
{
gmem_relu_bitmask
[
idx
*
2
+
lane_id
]
=
relu_mask
;
}
}
else
if
(
USE_RELU
)
{
relu_activation
(
x_math
);
}
// Write back.
if
(
is_valid
)
{
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
x_math
);
}
}
// The next value of nhw.
out_nhw
-=
pixels_per_iteration
;
// Read the next elements from memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
}
}
}
// Normalize the elements from SMEM and write them out.
if
(
pixels_in_smem
>
0
)
{
#pragma unroll 2
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
const
bool
is_valid_nhw
=
static_cast
<
unsigned
int
>
(
idx
)
<
static_cast
<
unsigned
int
>
(
params
.
nhw
);
const
bool
is_valid
=
is_valid_nhw
&&
is_valid_c
;
// Read from SMEM.
const
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
read_from_smem
(
x_storage_local
,
&
smem_storage_packed
[
offset
],
threadIdx
.
x
);
float
x_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
// Normalize and apply activation function
normalize
(
x_math
,
bias
,
scale
,
m1
);
if
(
USE_ADD_RELU
)
{
float
x1_math
[
ELEMENTS_PER_LDG
];
ldg_stream
(
x1_math
,
&
gmem_src1
[(
is_valid
?
idx
:
0
)
*
params
.
c
]);
add
(
x_math
,
x1_math
);
unsigned
int
relu_mask
;
int
lane_id
=
threadIdx
.
x
&
31
;
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
bool
rectified
=
x_math
[
i
]
<
0.0
F
;
unsigned
int
local_relu_mask
=
__ballot_sync
(
0xFFFFFFFFU
,
rectified
);
if
(
lane_id
==
i
)
{
relu_mask
=
local_relu_mask
;
}
if
(
rectified
)
{
x_math
[
i
]
=
0.0
F
;
}
}
if
(
is_valid_nhw
&&
(
lane_id
<
ELEMENTS_PER_LDG
))
{
gmem_relu_bitmask
[
idx
*
2
+
lane_id
]
=
relu_mask
;
}
}
else
if
(
USE_RELU
)
{
relu_activation
(
x_math
);
}
// Write back.
if
(
is_valid
)
{
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
x_math
);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads
();
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
struct NhwcBatchNormBwdParams {
  // The input/output tensors.
  uint16_t *gmem_src, *gmem_dy, *gmem_dst, *gmem_dst1;
  // dscale/dbias
  float *gmem_dscale, *gmem_dbias;
  // The scale and bias.
  float *gmem_scale, *gmem_bias;
  // The mean/inv-var saved from fwd pass
  float *gmem_saved_mean, *gmem_saved_var;
  // ReLU bitmask
  unsigned int *gmem_relu_bitmask;
  // The dimensions.
  int nhw, c;
  // factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
  float svar_inv_count;
  // The buffer to do the reduction for dscale and dbias
  float *gmem_sums;
  // The counters of retired CTAs.
  int *gmem_retired_ctas;
  // outer loop count
  int outer_loops;
  // number of CTAs along .x dimension
  int c_blks;

  void* my_data;
  void* pair_datas[4];
  int magic;
  int sync_iters;
  float wgrad_coeff;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&x)[N],
                              const float (&mean_var_scale_bias)[N], const float (&var_scale)[N],
                              bool valid_data) {
  #pragma unroll
  for (int j = 0; j < N; ++j) {
    float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];
    if ((y <= 0.f) && valid_data) {
      dy[j] = 0.f;
    }
  }
}

template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&y)[N], bool valid_data) {
  #pragma unroll
  for (int j = 0; j < N; ++j) {
    if ((y[j] <= 0.f) && valid_data) {
      dy[j] = 0.f;
    }
  }
}

template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const bool (&rectified)[N], bool valid_data) {
  #pragma unroll
  for (int j = 0; j < N; ++j) {
    if (rectified[j] && valid_data) {
      dy[j] = 0.f;
    }
  }
}

template <int N>
DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&x)[N],
                                     const float (&mean_var_scale_bias)[N], const float (&var_scale)[N]) {
  #pragma unroll
  for (int j = 0; j < N; ++j) {
    float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];
    if (y <= 0.f) {
      dy[j] = 0.f;
    }
  }
}

template <int N>
DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&y)[N]) {
  #pragma unroll
  for (int j = 0; j < N; ++j) {
    if (y[j] <= 0.f) {
      dy[j] = 0.f;
    }
  }
}

template <int N>
DEVICE_FUNCTION void bwd_update(float (&dscale)[N], float (&dbias)[N],
                                const float (&dy)[N], const float (&x)[N],
                                const float (&mean)[N], float inv_count) {
  #pragma unroll
  for (int j = 0; j < N; ++j) {
    float delta0 = dy[j] - dbias[j];
    dbias[j] += delta0 * inv_count;
    delta0 = (dy[j] * (x[j] - mean[j])) - dscale[j];
    dscale[j] += delta0 * inv_count;
  }
}

template <int N>
DEVICE_FUNCTION void bwd_dx(float (&dx)[N], const float (&dy)[N],
                            const float (&var)[N], const float (&x)[N], const float (&mean)[N],
                            const float (&dscale)[N], const float (&dbias)[N], float inv_count) {
  #pragma unroll
  for (int j = 0; j < N; ++j) {
    float tmp1 = dy[j] - (dbias[j] * inv_count);
    float tmp2 = dscale[j] * inv_count;
    float tmp3 = x[j] - mean[j];
    dx[j] = var[j] * (tmp1 - (tmp2 * tmp3));
  }
}
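// ------------------------------------------------------------------------------------------------
// Illustrative aside (not part of the original header): for reference, the textbook batch-norm
// input gradient that bwd_update and bwd_dx assemble is, per channel, with
// x_hat = (x - mu) * rstd and N = NHW:
//
//   dL/dx_i = gamma * rstd * ( dy_i - (1/N) * sum_j dy_j - x_hat_i * (1/N) * sum_j dy_j * x_hat_j )
//
// bwd_update maintains running means of dy and dy*(x - mean) (rescaled back to sums before the
// cross-CTA reduction), and bwd_dx applies the centering and per-element scaling; how the caller
// folds gamma and rstd into the `var` and `dscale` arguments is defined further down this header.
// ------------------------------------------------------------------------------------------------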
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Storage
,
int
THREADS_PER_CTA
,
int
THREADS_PER_PIXEL
,
int
PIXELS_PER_THREAD_IN_REGISTERS
,
int
PIXELS_PER_THREAD_IN_SMEM
,
int
ELEMENTS_PER_LDG
,
int
USE_ONLINE_APPROACH
,
int
OUTER_LOOPS_
,
int
DESIRED_OCCUPANCY
>
__global__
__launch_bounds__
(
THREADS_PER_CTA
,
DESIRED_OCCUPANCY
)
void
nhwc_batch_norm_bwd
(
NhwcBatchNormBwdParams
params
)
{
// The number of pixels loaded in a single LDG.
const
int
PIXELS_PER_LDG
=
THREADS_PER_CTA
/
THREADS_PER_PIXEL
;
// The number of pixels computed per CTA stored in registers.
const
int
PIXELS_PER_CTA_IN_REGISTERS
=
PIXELS_PER_THREAD_IN_REGISTERS
*
PIXELS_PER_LDG
;
// The number of pixels computed per CTA stored in SMEM.
const
int
PIXELS_PER_CTA_IN_SMEM
=
PIXELS_PER_THREAD_IN_SMEM
*
PIXELS_PER_LDG
;
// The number of C elements per CTA.
const
int
C_ELEMENTS_PER_CTA
=
THREADS_PER_PIXEL
*
ELEMENTS_PER_LDG
;
// Shared memory to do CTA-wide parallel sums.
__shared__
float
smem
[
THREADS_PER_PIXEL
*
(
THREADS_PER_CTA
/
32
)
*
ELEMENTS_PER_LDG
];
// The adapter for the storage.
typedef
PackedStorage
<
Storage
,
ELEMENTS_PER_LDG
>
PackedStorage_
;
// The data type for packed storage in SMEM.
typedef
typename
PackedStorage_
::
Type
PackedStorageType
;
// The number of elements in the packed storage.
const
int
PACKED_ELEMENTS_PER_LDG
=
PackedStorage_
::
PACKED_ELEMENTS_PER_LDG
;
// Registers to keep the data live for the persistent approach.
PackedStorageType
x_storage
[
PIXELS_PER_THREAD_IN_REGISTERS
][
PACKED_ELEMENTS_PER_LDG
];
PackedStorageType
dy_storage
[
PIXELS_PER_THREAD_IN_REGISTERS
][
PACKED_ELEMENTS_PER_LDG
];
// Shared memory buffer to store the extra pixels.
extern
__shared__
PackedStorageType
smem_storage_packed
[];
for
(
int
c_blk_index
=
blockIdx
.
y
;
c_blk_index
<
params
.
c_blks
;
c_blk_index
+=
gridDim
.
y
)
{
// The position in the NHW dimension where the CTA starts.
int
cta_nhw_regs
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_REGISTERS
;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int
cta_nhw_smem
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_SMEM
;
// Compute the NHW coordinate of the thread in the CTA.
const
int
thread_in_cta_nhw
=
threadIdx
.
x
/
THREADS_PER_PIXEL
;
// The position in the C dimension where the CTA starts.
const
int
cta_c
=
c_blk_index
*
C_ELEMENTS_PER_CTA
;
// Compute the C coordinate of the thread in the CTA.
const
int
thread_in_cta_c
=
threadIdx
.
x
%
THREADS_PER_PIXEL
;
// Compute the C coordinate of the thread.
const
int
thread_c
=
cta_c
+
thread_in_cta_c
*
ELEMENTS_PER_LDG
;
// Is the thread working on a valid C dimension?
const
int
is_valid_c
=
thread_c
<
params
.
c
;
// Registers to store the mean used for entire duration
float
mean
[
ELEMENTS_PER_LDG
];
zero_array
(
mean
);
if
(
is_valid_c
)
{
read_from_gmem
(
mean
,
params
.
gmem_saved_mean
,
thread_c
/
ELEMENTS_PER_LDG
);
}
// accumulation related registers
float
count
=
0.
f
,
dscale
[
ELEMENTS_PER_LDG
],
dbias
[
ELEMENTS_PER_LDG
];
zero_array
(
dscale
);
zero_array
(
dbias
);
// The number of elements loaded by this CTA.
int
cta_count
=
0
;
// The base pointers to load from.
const
uint16_t
*
gmem_src
=
&
params
.
gmem_src
[
thread_c
];
const
uint16_t
*
gmem_dy
=
&
params
.
gmem_dy
[
thread_c
];
// outer loops
int
OUTER_LOOPS
=
OUTER_LOOPS_
==
1
?
1
:
params
.
outer_loops
;
// Load the batch of elements. Compute sum across them
const
int
pixels_per_iteration
=
PIXELS_PER_CTA_IN_REGISTERS
*
gridDim
.
x
;
if
(
OUTER_LOOPS_
!=
1
)
{
// We cannot load everything to store persistently, so let's makes sure registers and
// smem are fully utilized
int
offset
=
params
.
nhw
-
pixels_per_iteration
*
OUTER_LOOPS
-
PIXELS_PER_CTA_IN_SMEM
*
gridDim
.
x
;
cta_nhw_regs
+=
offset
;
cta_nhw_smem
+=
offset
;
}
#pragma unroll 1
for
(
int
loop_i
=
0
;
loop_i
<
OUTER_LOOPS
;
++
loop_i
)
{
// The nhw position.
int
nhw_regs
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count
+=
max
(
0
,
min
(
PIXELS_PER_CTA_IN_REGISTERS
,
params
.
nhw
-
nhw_regs
));
// Read the elements from memory.
float
is_valid
[
PIXELS_PER_THREAD_IN_REGISTERS
];
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
nhw_regs
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
zero_array
(
x_storage
[
i
]);
zero_array
(
dy_storage
[
i
]);
is_valid
[
i
]
=
0.
f
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
if
(
loop_i
==
OUTER_LOOPS
-
1
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
else
{
ldg
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
is_valid
[
i
]
=
1.
f
;
}
}
// Do the math.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
// Convert to float and update
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
to_float
(
dy_math
,
dy_storage
[
i
]);
// Update the count.
count
+=
is_valid
[
i
];
// Invert the count.
float
inv_count
=
is_valid
[
i
]
?
1.
f
/
count
:
0.
f
;
bwd_update
(
dscale
,
dbias
,
dy_math
,
x_math
,
mean
,
inv_count
);
}
}
// The elements to load and store in SMEM.
int
smem_nhw
=
OUTER_LOOPS
*
pixels_per_iteration
+
cta_nhw_smem
;
// Load elements from SMEM, update the CTA count.
int
pixels_in_smem
=
min
(
PIXELS_PER_CTA_IN_SMEM
,
params
.
nhw
-
smem_nhw
);
if
(
pixels_in_smem
>
0
)
{
cta_count
+=
pixels_in_smem
;
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
bool
is_pixel_valid
=
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
);
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
],
dy_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
zero_array
(
x_storage_local
);
zero_array
(
dy_storage_local
);
if
(
is_pixel_valid
)
{
ldg_stream
(
x_storage_local
,
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage_local
,
&
gmem_dy
[
idx
*
params
.
c
]);
}
// The offset to store in SMEM.
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
// Store in SMEM.
write_to_smem
(
&
smem_storage_packed
[
offset
],
threadIdx
.
x
,
x_storage_local
);
offset
+=
PIXELS_PER_THREAD_IN_SMEM
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
write_to_smem
(
&
smem_storage_packed
[
offset
],
threadIdx
.
x
,
dy_storage_local
);
// Update the count.
count
+=
is_pixel_valid
;
// Invert the count.
float
inv_count
=
is_pixel_valid
?
1.
f
/
count
:
0.
f
;
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
to_float
(
dy_math
,
dy_storage_local
);
bwd_update
(
dscale
,
dbias
,
dy_math
,
x_math
,
mean
,
inv_count
);
}
}
// We scale the mean by the number of elements. It brings more stability.
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
dbias
[
i
]
*=
count
;
dscale
[
i
]
*=
count
;
}
// dscale parallel sum
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
);
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dscale
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// dbias parallel sum
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
);
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dbias
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// The workspace in global memory is distributed across the different CTA.
int
gmem_sums_offset
=
c_blk_index
*
gridDim
.
x
*
C_ELEMENTS_PER_CTA
*
2
;
// Write the data for the CTA to global memory.
float
*
gmem_sums
=
&
params
.
gmem_sums
[
gmem_sums_offset
];
if
(
threadIdx
.
x
<
THREADS_PER_PIXEL
)
{
const
int
idx
=
blockIdx
.
x
*
THREADS_PER_PIXEL
+
threadIdx
.
x
;
write_to_gmem
(
&
gmem_sums
[
0
],
idx
,
dscale
);
write_to_gmem
(
&
gmem_sums
[
C_ELEMENTS_PER_CTA
*
gridDim
.
x
],
idx
,
dbias
);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int
*
gmem_retired_ctas
=
&
params
.
gmem_retired_ctas
[
c_blk_index
%
(
2
*
gridDim
.
y
)];
inter_block_sync
(
gmem_retired_ctas
,
gridDim
.
x
,
blockIdx
.
x
==
0
);
// Reset the accumulators for global summation
zero_array
(
dscale
);
zero_array
(
dbias
);
// Build the global accumulation
#pragma unroll 1
for
(
int
idx
=
threadIdx
.
x
;
idx
<
THREADS_PER_PIXEL
*
gridDim
.
x
;
idx
+=
THREADS_PER_CTA
)
{
float
tmp1
[
ELEMENTS_PER_LDG
],
tmp2
[
ELEMENTS_PER_LDG
];
read_from_gmem
(
tmp1
,
gmem_sums
,
idx
);
read_from_gmem
(
tmp2
,
gmem_sums
+
C_ELEMENTS_PER_CTA
*
gridDim
.
x
,
idx
);
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
dscale
[
i
]
+=
tmp1
[
i
];
dbias
[
i
]
+=
tmp2
[
i
];
}
}
// dscale parallel sum
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
1
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
);
}
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dscale
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// dbias parallel sum
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
0
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
);
}
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dbias
,
smem
,
thread_in_cta_c
);
// inv-var
float
var
[
ELEMENTS_PER_LDG
];
zero_array
(
var
);
if
(
is_valid_c
)
{
read_from_gmem
(
var
,
params
.
gmem_saved_var
,
thread_c
/
ELEMENTS_PER_LDG
);
}
// Normalize the dscale.
multiply
(
dscale
,
var
);
// store dscale/dbias
bool
is_valid_for_saving
=
is_valid_c
&&
blockIdx
.
x
==
0
&&
thread_in_cta_nhw
==
0
;
if
(
is_valid_for_saving
)
{
if
(
params
.
sync_iters
>
0
)
{
scaled_write_to_gmem
(
params
.
gmem_dscale
,
thread_c
/
ELEMENTS_PER_LDG
,
dscale
,
params
.
wgrad_coeff
);
scaled_write_to_gmem
(
params
.
gmem_dbias
,
thread_c
/
ELEMENTS_PER_LDG
,
dbias
,
params
.
wgrad_coeff
);
}
else
{
write_to_gmem
(
params
.
gmem_dscale
,
thread_c
/
ELEMENTS_PER_LDG
,
dscale
);
write_to_gmem
(
params
.
gmem_dbias
,
thread_c
/
ELEMENTS_PER_LDG
,
dbias
);
}
}
// scale
float
scale
[
ELEMENTS_PER_LDG
];
zero_array
(
scale
);
if
(
is_valid_c
)
{
read_from_gmem
(
scale
,
params
.
gmem_scale
,
thread_c
/
ELEMENTS_PER_LDG
);
}
// Further normalize the dscale to be used in dx calculation
multiply
(
dscale
,
var
);
// scale the inv-var as well, afterwards
multiply
(
var
,
scale
);
// inverse count
float
inv_count
=
params
.
svar_inv_count
;
// The base pointer to write to.
uint16_t
*
const
gmem_dst
=
&
params
.
gmem_dst
[
thread_c
];
// Store the elements in registers.
#pragma unroll 1
for
(
int
loop_i
=
OUTER_LOOPS
-
1
;
loop_i
>=
0
;
--
loop_i
)
{
// The value for nhw.
int
out_nhw
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Normalize the elements and write to memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
// Convert to float.
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
to_float
(
dy_math
,
dy_storage
[
i
]);
float
dx
[
ELEMENTS_PER_LDG
];
bwd_dx
(
dx
,
dy_math
,
var
,
x_math
,
mean
,
dscale
,
dbias
,
inv_count
);
// Write back.
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
dx
);
}
}
// The next value of nhw.
out_nhw
-=
pixels_per_iteration
;
// Read the next elements from memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
}
}
// Normalize the elements from SMEM and write them out.
if
(
pixels_in_smem
>
0
)
{
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
const
bool
is_valid
=
((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
;
if
(
is_valid
)
{
// Read from SMEM.
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
],
dy_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
read_from_smem
(
x_storage_local
,
&
smem_storage_packed
[
offset
],
threadIdx
.
x
);
offset
+=
PIXELS_PER_THREAD_IN_SMEM
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
read_from_smem
(
dy_storage_local
,
&
smem_storage_packed
[
offset
],
threadIdx
.
x
);
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
to_float
(
dy_math
,
dy_storage_local
);
float
dx
[
ELEMENTS_PER_LDG
];
bwd_dx
(
dx
,
dy_math
,
var
,
x_math
,
mean
,
dscale
,
dbias
,
inv_count
);
// Write back.
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
dx
);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads
();
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Storage
,
int
THREADS_PER_CTA
,
int
THREADS_PER_PIXEL
,
int
PIXELS_PER_THREAD_IN_REGISTERS
,
int
PIXELS_PER_THREAD_IN_SMEM
,
int
ELEMENTS_PER_LDG
,
int
USE_ONLINE_APPROACH
,
int
OUTER_LOOPS_
,
int
DESIRED_OCCUPANCY
>
__global__
__launch_bounds__
(
THREADS_PER_CTA
,
DESIRED_OCCUPANCY
)
void
nhwc_batch_norm_bwd_relu
(
NhwcBatchNormBwdParams
params
)
{
// The number of pixels loaded in a single LDG.
const
int
PIXELS_PER_LDG
=
THREADS_PER_CTA
/
THREADS_PER_PIXEL
;
// The number of pixels computed per CTA stored in registers.
const
int
PIXELS_PER_CTA_IN_REGISTERS
=
PIXELS_PER_THREAD_IN_REGISTERS
*
PIXELS_PER_LDG
;
// The number of pixels computed per CTA stored in SMEM.
const
int
PIXELS_PER_CTA_IN_SMEM
=
PIXELS_PER_THREAD_IN_SMEM
*
PIXELS_PER_LDG
;
// The number of C elements per CTA.
const
int
C_ELEMENTS_PER_CTA
=
THREADS_PER_PIXEL
*
ELEMENTS_PER_LDG
;
// Shared memory to do CTA-wide parallel sums.
__shared__
float
smem
[
THREADS_PER_PIXEL
*
(
THREADS_PER_CTA
/
32
)
*
ELEMENTS_PER_LDG
];
// The adapter for the storage.
typedef
PackedStorage
<
Storage
,
ELEMENTS_PER_LDG
>
PackedStorage_
;
// The data type for packed storage in SMEM.
typedef
typename
PackedStorage_
::
Type
PackedStorageType
;
// The number of elements in the packed storage.
const
int
PACKED_ELEMENTS_PER_LDG
=
PackedStorage_
::
PACKED_ELEMENTS_PER_LDG
;
// Registers to keep the data live for the persistent approach.
PackedStorageType
x_storage
[
PIXELS_PER_THREAD_IN_REGISTERS
][
PACKED_ELEMENTS_PER_LDG
];
PackedStorageType
dy_storage
[
PIXELS_PER_THREAD_IN_REGISTERS
][
PACKED_ELEMENTS_PER_LDG
];
// Shared memory buffer to store the extra pixels.
extern
__shared__
PackedStorageType
smem_storage_packed
[];
for
(
int
c_blk_index
=
blockIdx
.
y
;
c_blk_index
<
params
.
c_blks
;
c_blk_index
+=
gridDim
.
y
)
{
// The position in the NHW dimension where the CTA starts.
int
cta_nhw_regs
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_REGISTERS
;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int
cta_nhw_smem
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_SMEM
;
// Compute the NHW coordinate of the thread in the CTA.
const
int
thread_in_cta_nhw
=
threadIdx
.
x
/
THREADS_PER_PIXEL
;
// The position in the C dimension where the CTA starts.
const
int
cta_c
=
c_blk_index
*
C_ELEMENTS_PER_CTA
;
// Compute the C coordinate of the thread in the CTA.
const
int
thread_in_cta_c
=
threadIdx
.
x
%
THREADS_PER_PIXEL
;
// Compute the C coordinate of the thread.
const
int
thread_c
=
cta_c
+
thread_in_cta_c
*
ELEMENTS_PER_LDG
;
// Is the thread working on a valid C dimension?
const
int
is_valid_c
=
thread_c
<
params
.
c
;
// Registers to store the mean/var/scale/bias used for the entire duration
// Register usage optimizations:
// 1. Can combine bias - (mean * var * scale) into a single register
// 2. Can combine var * scale into a single register
float
varscale
[
ELEMENTS_PER_LDG
];
zero_array
(
varscale
);
if
(
is_valid_c
)
{
read_from_gmem
(
varscale
,
params
.
gmem_saved_var
,
thread_c
/
ELEMENTS_PER_LDG
);
}
float
tmp
[
ELEMENTS_PER_LDG
];
zero_array
(
tmp
);
if
(
is_valid_c
)
{
read_from_gmem
(
tmp
,
params
.
gmem_scale
,
thread_c
/
ELEMENTS_PER_LDG
);
}
multiply
(
varscale
,
tmp
);
float
mean
[
ELEMENTS_PER_LDG
];
zero_array
(
mean
);
if
(
is_valid_c
)
{
read_from_gmem
(
mean
,
params
.
gmem_saved_mean
,
thread_c
/
ELEMENTS_PER_LDG
);
}
zero_array
(
tmp
);
if
(
is_valid_c
)
{
read_from_gmem
(
tmp
,
params
.
gmem_bias
,
thread_c
/
ELEMENTS_PER_LDG
);
}
float
mean_var_scale_bias
[
ELEMENTS_PER_LDG
];
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
mean_var_scale_bias
[
i
]
=
tmp
[
i
]
-
(
mean
[
i
]
*
varscale
[
i
]);
}
// accumulation related registers
float
count
=
0.
f
,
dscale
[
ELEMENTS_PER_LDG
],
dbias
[
ELEMENTS_PER_LDG
];
zero_array
(
dscale
);
zero_array
(
dbias
);
// The number of elements loaded by this CTA.
int
cta_count
=
0
;
// The base pointers to load from.
const
uint16_t
*
gmem_src
=
&
params
.
gmem_src
[
thread_c
];
const
uint16_t
*
gmem_dy
=
&
params
.
gmem_dy
[
thread_c
];
// outer loops
int
OUTER_LOOPS
=
OUTER_LOOPS_
==
1
?
1
:
params
.
outer_loops
;
// Load the batch of elements. Compute sum across them
const
int
pixels_per_iteration
=
PIXELS_PER_CTA_IN_REGISTERS
*
gridDim
.
x
;
if
(
OUTER_LOOPS_
!=
1
)
{
// We cannot load everything to store persistently, so let's make sure registers and
// smem are fully utilized
int
offset
=
params
.
nhw
-
pixels_per_iteration
*
OUTER_LOOPS
-
PIXELS_PER_CTA_IN_SMEM
*
gridDim
.
x
;
cta_nhw_regs
+=
offset
;
cta_nhw_smem
+=
offset
;
}
#pragma unroll 1
for
(
int
loop_i
=
0
;
loop_i
<
OUTER_LOOPS
;
++
loop_i
)
{
// The nhw position.
int
nhw_regs
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count
+=
max
(
0
,
min
(
PIXELS_PER_CTA_IN_REGISTERS
,
params
.
nhw
-
nhw_regs
));
// Read the elements from memory.
float
is_valid
[
PIXELS_PER_THREAD_IN_REGISTERS
];
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
nhw_regs
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
zero_array
(
x_storage
[
i
]);
zero_array
(
dy_storage
[
i
]);
is_valid
[
i
]
=
0.
f
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
if
(
loop_i
==
OUTER_LOOPS
-
1
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
else
{
ldg
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
is_valid
[
i
]
=
1.
f
;
}
}
// Do the math.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
// Convert to float and update
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
to_float
(
dy_math
,
dy_storage
[
i
]);
// Update the count.
count
+=
is_valid
[
i
];
// Invert the count.
float
inv_count
=
is_valid
[
i
]
?
1.
f
/
count
:
0.
f
;
relu_bwd
(
dy_math
,
x_math
,
mean_var_scale_bias
,
varscale
,
is_valid
[
i
]);
bwd_update
(
dscale
,
dbias
,
dy_math
,
x_math
,
mean
,
inv_count
);
}
}
// The elements to load and store in SMEM.
int
smem_nhw
=
OUTER_LOOPS
*
pixels_per_iteration
+
cta_nhw_smem
;
// Load elements from SMEM, update the CTA count.
int
pixels_in_smem
=
min
(
PIXELS_PER_CTA_IN_SMEM
,
params
.
nhw
-
smem_nhw
);
if
(
pixels_in_smem
>
0
)
{
cta_count
+=
pixels_in_smem
;
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
bool
is_pixel_valid
=
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
);
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
],
dy_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
zero_array
(
x_storage_local
);
zero_array
(
dy_storage_local
);
if
(
is_pixel_valid
)
{
ldg_stream
(
x_storage_local
,
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage_local
,
&
gmem_dy
[
idx
*
params
.
c
]);
}
// The offset to store in SMEM.
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
// Store in SMEM.
write_to_smem
(
&
smem_storage_packed
[
offset
],
threadIdx
.
x
,
x_storage_local
);
offset
+=
PIXELS_PER_THREAD_IN_SMEM
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
write_to_smem
(
&
smem_storage_packed
[
offset
],
threadIdx
.
x
,
dy_storage_local
);
// Update the count.
count
+=
is_pixel_valid
;
// Invert the count.
float
inv_count
=
is_pixel_valid
?
1.
f
/
count
:
0.
f
;
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
to_float
(
dy_math
,
dy_storage_local
);
relu_bwd
(
dy_math
,
x_math
,
mean_var_scale_bias
,
varscale
,
is_pixel_valid
);
bwd_update
(
dscale
,
dbias
,
dy_math
,
x_math
,
mean
,
inv_count
);
}
}
// We scale the mean by the number of elements. It brings more stability.
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
dbias
[
i
]
*=
count
;
dscale
[
i
]
*=
count
;
}
// dscale parallel sum
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
);
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dscale
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// dbias parallel sum
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
);
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dbias
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// The workspace in global memory is distributed across the different CTA.
int
gmem_sums_offset
=
c_blk_index
*
gridDim
.
x
*
C_ELEMENTS_PER_CTA
*
2
;
// Write the data for the CTA to global memory.
float
*
gmem_sums
=
&
params
.
gmem_sums
[
gmem_sums_offset
];
if
(
threadIdx
.
x
<
THREADS_PER_PIXEL
)
{
const
int
idx
=
blockIdx
.
x
*
THREADS_PER_PIXEL
+
threadIdx
.
x
;
write_to_gmem
(
&
gmem_sums
[
0
],
idx
,
dscale
);
write_to_gmem
(
&
gmem_sums
[
C_ELEMENTS_PER_CTA
*
gridDim
.
x
],
idx
,
dbias
);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int
*
gmem_retired_ctas
=
&
params
.
gmem_retired_ctas
[
c_blk_index
%
(
2
*
gridDim
.
y
)];
inter_block_sync
(
gmem_retired_ctas
,
gridDim
.
x
,
blockIdx
.
x
==
0
);
// Reset the accumulators for global summation
zero_array
(
dscale
);
zero_array
(
dbias
);
// Build the global accumulation
#pragma unroll 1
for
(
int
idx
=
threadIdx
.
x
;
idx
<
THREADS_PER_PIXEL
*
gridDim
.
x
;
idx
+=
THREADS_PER_CTA
)
{
float
tmp1
[
ELEMENTS_PER_LDG
],
tmp2
[
ELEMENTS_PER_LDG
];
read_from_gmem
(
tmp1
,
gmem_sums
,
idx
);
read_from_gmem
(
tmp2
,
gmem_sums
+
C_ELEMENTS_PER_CTA
*
gridDim
.
x
,
idx
);
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
dscale
[
i
]
+=
tmp1
[
i
];
dbias
[
i
]
+=
tmp2
[
i
];
}
}
// dscale parallel sum
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
1
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
);
}
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dscale
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// dbias parallel sum
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
0
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
);
}
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dbias
,
smem
,
thread_in_cta_c
);
// Normalize the dscale.
float
var
[
ELEMENTS_PER_LDG
];
zero_array
(
var
);
if
(
is_valid_c
)
{
read_from_gmem
(
var
,
params
.
gmem_saved_var
,
thread_c
/
ELEMENTS_PER_LDG
);
}
multiply
(
dscale
,
var
);
// store dscale/dbias
bool
is_valid_for_saving
=
is_valid_c
&&
blockIdx
.
x
==
0
&&
thread_in_cta_nhw
==
0
;
if
(
is_valid_for_saving
)
{
if
(
params
.
sync_iters
>
0
)
{
scaled_write_to_gmem
(
params
.
gmem_dscale
,
thread_c
/
ELEMENTS_PER_LDG
,
dscale
,
params
.
wgrad_coeff
);
scaled_write_to_gmem
(
params
.
gmem_dbias
,
thread_c
/
ELEMENTS_PER_LDG
,
dbias
,
params
.
wgrad_coeff
);
}
else
{
write_to_gmem
(
params
.
gmem_dscale
,
thread_c
/
ELEMENTS_PER_LDG
,
dscale
);
write_to_gmem
(
params
.
gmem_dbias
,
thread_c
/
ELEMENTS_PER_LDG
,
dbias
);
}
}
// Further normalize the dscale to be used in dx calculation
float
scale
[
ELEMENTS_PER_LDG
];
zero_array
(
scale
);
if
(
is_valid_c
)
{
read_from_gmem
(
scale
,
params
.
gmem_scale
,
thread_c
/
ELEMENTS_PER_LDG
);
}
multiply
(
dscale
,
var
);
// scale the inv-var as well, afterwards
multiply
(
var
,
scale
);
// inverse count
float
inv_count
=
params
.
svar_inv_count
;
// The base pointer to write to.
uint16_t
*
const
gmem_dst
=
&
params
.
gmem_dst
[
thread_c
];
// Store the elements in registers.
#pragma unroll 1
for
(
int
loop_i
=
OUTER_LOOPS
-
1
;
loop_i
>=
0
;
--
loop_i
)
{
// The value for nhw.
int
out_nhw
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Normalize the elements and write to memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
// Convert to float.
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
to_float
(
dy_math
,
dy_storage
[
i
]);
relu_bwd_for_dx
(
dy_math
,
x_math
,
mean_var_scale_bias
,
var
);
float
dx
[
ELEMENTS_PER_LDG
];
bwd_dx
(
dx
,
dy_math
,
var
,
x_math
,
mean
,
dscale
,
dbias
,
inv_count
);
// Write back.
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
dx
);
}
}
// The next value of nhw.
out_nhw
-=
pixels_per_iteration
;
// Read the next elements from memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
}
}
// Normalize the elements from SMEM and write them out.
if
(
pixels_in_smem
>
0
)
{
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
const
bool
is_valid
=
((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
;
if
(
is_valid
)
{
// Read from SMEM.
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
],
dy_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
read_from_smem
(
x_storage_local
,
&
smem_storage_packed
[
offset
],
threadIdx
.
x
);
offset
+=
PIXELS_PER_THREAD_IN_SMEM
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
read_from_smem
(
dy_storage_local
,
&
smem_storage_packed
[
offset
],
threadIdx
.
x
);
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
to_float
(
dy_math
,
dy_storage_local
);
relu_bwd_for_dx
(
dy_math
,
x_math
,
mean_var_scale_bias
,
var
);
float
dx
[
ELEMENTS_PER_LDG
];
bwd_dx
(
dx
,
dy_math
,
var
,
x_math
,
mean
,
dscale
,
dbias
,
inv_count
);
// Write back.
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
dx
);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads
();
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Storage
,
int
THREADS_PER_CTA
,
int
THREADS_PER_PIXEL
,
int
PIXELS_PER_THREAD_IN_REGISTERS
,
int
PIXELS_PER_THREAD_IN_SMEM
,
int
ELEMENTS_PER_LDG
,
int
USE_ONLINE_APPROACH
,
int
OUTER_LOOPS_
,
int
DESIRED_OCCUPANCY
>
__global__
__launch_bounds__
(
THREADS_PER_CTA
,
DESIRED_OCCUPANCY
)
void
nhwc_batch_norm_bwd_add_relu
(
NhwcBatchNormBwdParams
params
)
{
// The number of pixels loaded in a single LDG.
const
int
PIXELS_PER_LDG
=
THREADS_PER_CTA
/
THREADS_PER_PIXEL
;
// The number of pixels computed per CTA stored in registers.
const
int
PIXELS_PER_CTA_IN_REGISTERS
=
PIXELS_PER_THREAD_IN_REGISTERS
*
PIXELS_PER_LDG
;
// The number of pixels computed per CTA stored in SMEM.
const
int
PIXELS_PER_CTA_IN_SMEM
=
PIXELS_PER_THREAD_IN_SMEM
*
PIXELS_PER_LDG
;
// The number of C elements per CTA.
const
int
C_ELEMENTS_PER_CTA
=
THREADS_PER_PIXEL
*
ELEMENTS_PER_LDG
;
// Shared memory to do CTA-wide parallel sums.
__shared__
float
smem
[
THREADS_PER_PIXEL
*
(
THREADS_PER_CTA
/
32
)
*
ELEMENTS_PER_LDG
];
// The adapter for the storage.
typedef
PackedStorage
<
Storage
,
ELEMENTS_PER_LDG
>
PackedStorage_
;
// The data type for packed storage in SMEM.
typedef
typename
PackedStorage_
::
Type
PackedStorageType
;
// The number of elements in the packed storage.
const
int
PACKED_ELEMENTS_PER_LDG
=
PackedStorage_
::
PACKED_ELEMENTS_PER_LDG
;
// Registers to keep the data live for the persistent approach.
PackedStorageType
x_storage
[
PIXELS_PER_THREAD_IN_REGISTERS
][
PACKED_ELEMENTS_PER_LDG
];
PackedStorageType
dy_storage
[
PIXELS_PER_THREAD_IN_REGISTERS
][
PACKED_ELEMENTS_PER_LDG
];
// Shared memory buffer to store the extra pixels.
extern
__shared__
PackedStorageType
smem_storage_packed
[];
for
(
int
c_blk_index
=
blockIdx
.
y
;
c_blk_index
<
params
.
c_blks
;
c_blk_index
+=
gridDim
.
y
)
{
// The position in the NHW dimension where the CTA starts.
int
cta_nhw_regs
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_REGISTERS
;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int
cta_nhw_smem
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_SMEM
;
// Compute the NHW coordinate of the thread in the CTA.
const
int
thread_in_cta_nhw
=
threadIdx
.
x
/
THREADS_PER_PIXEL
;
// The position in the C dimension where the CTA starts.
const
int
cta_c
=
c_blk_index
*
C_ELEMENTS_PER_CTA
;
// Compute the C coordinate of the thread in the CTA.
const
int
thread_in_cta_c
=
threadIdx
.
x
%
THREADS_PER_PIXEL
;
// Compute the C coordinate of the thread.
const
int
thread_c
=
cta_c
+
thread_in_cta_c
*
ELEMENTS_PER_LDG
;
// Is the thread working on a valid C dimension?
const
int
is_valid_c
=
thread_c
<
params
.
c
;
float
mean
[
ELEMENTS_PER_LDG
];
zero_array
(
mean
);
if
(
is_valid_c
)
{
read_from_gmem
(
mean
,
params
.
gmem_saved_mean
,
thread_c
/
ELEMENTS_PER_LDG
);
}
// accumulation related registers
float
count
=
0.
f
,
dscale
[
ELEMENTS_PER_LDG
],
dbias
[
ELEMENTS_PER_LDG
];
zero_array
(
dscale
);
zero_array
(
dbias
);
// The number of elements loaded by this CTA.
int
cta_count
=
0
;
// The base pointers to load from.
const
uint16_t
*
gmem_src
=
&
params
.
gmem_src
[
thread_c
];
const
uint16_t
*
gmem_dy
=
&
params
.
gmem_dy
[
thread_c
];
uint16_t
*
gmem_dst1
=
&
params
.
gmem_dst1
[
thread_c
];
// outer loops
int
OUTER_LOOPS
=
OUTER_LOOPS_
==
1
?
1
:
params
.
outer_loops
;
// Load the batch of elements. Compute sum across them
const
int
pixels_per_iteration
=
PIXELS_PER_CTA_IN_REGISTERS
*
gridDim
.
x
;
if
(
OUTER_LOOPS_
!=
1
)
{
// We cannot load everything to store persistently, so let's make sure registers and
// smem are fully utilized, offset is evenly divisible by 32
int
offset
=
(
pixels_per_iteration
*
OUTER_LOOPS
+
PIXELS_PER_CTA_IN_SMEM
*
gridDim
.
x
-
params
.
nhw
)
&
~
31
;
cta_nhw_regs
-=
offset
;
cta_nhw_smem
-=
offset
;
}
const
unsigned
int
*
const
gmem_relu_bitmask
=
params
.
gmem_relu_bitmask
+
((
params
.
nhw
+
31
)
&
~
31
)
*
2
*
c_blk_index
;
#pragma unroll 1
for
(
int
loop_i
=
0
;
loop_i
<
OUTER_LOOPS
;
++
loop_i
)
{
// The nhw position.
int
nhw_regs
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count
+=
max
(
0
,
min
(
PIXELS_PER_CTA_IN_REGISTERS
,
params
.
nhw
-
nhw_regs
));
int
lane_id
=
threadIdx
.
x
&
31
;
// Read the elements from memory.
float
is_valid
[
PIXELS_PER_THREAD_IN_REGISTERS
];
unsigned
int
relu_mask
[
PIXELS_PER_THREAD_IN_REGISTERS
];
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
nhw_regs
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
zero_array
(
x_storage
[
i
]);
zero_array
(
dy_storage
[
i
]);
is_valid
[
i
]
=
0.
f
;
const
bool
is_valid_nhw
=
static_cast
<
unsigned
int
>
(
idx
)
<
static_cast
<
unsigned
int
>
(
params
.
nhw
);
if
(
is_valid_nhw
)
{
if
(
is_valid_c
)
{
if
(
loop_i
==
OUTER_LOOPS
-
1
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
else
{
ldg
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
is_valid
[
i
]
=
1.
f
;
}
if
(
lane_id
<
ELEMENTS_PER_LDG
)
{
relu_mask
[
i
]
=
gmem_relu_bitmask
[
idx
*
2
+
lane_id
];
}
}
}
// Do the math.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
nhw_regs
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
// Convert to float and update
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
bool
rectified
[
ELEMENTS_PER_LDG
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
rectified
[
j
]
=
((
__shfl_sync
(
0xFFFFFFFFU
,
relu_mask
[
i
],
j
)
&
(
1U
<<
lane_id
))
!=
0
);
}
to_float
(
x_math
,
x_storage
[
i
]);
to_float
(
dy_math
,
dy_storage
[
i
]);
// Update the count.
count
+=
is_valid
[
i
];
// Invert the count.
float
inv_count
=
is_valid
[
i
]
?
1.
f
/
count
:
0.
f
;
relu_bwd
(
dy_math
,
rectified
,
is_valid
[
i
]);
bwd_update
(
dscale
,
dbias
,
dy_math
,
x_math
,
mean
,
inv_count
);
// Lastly we need 'dy' only for BN, so store the 'relu-dgrad'ed version
from_float
(
dy_storage
[
i
],
dy_math
);
// dZ for elementwise add
if
(
is_valid
[
i
])
{
if
(
loop_i
==
OUTER_LOOPS
-
1
)
{
stg_stream
(
&
gmem_dst1
[
idx
*
params
.
c
],
dy_storage
[
i
]);
}
else
{
stg
(
&
gmem_dst1
[
idx
*
params
.
c
],
dy_storage
[
i
]);
}
}
}
}
// The elements to load and store in SMEM.
int
smem_nhw
=
OUTER_LOOPS
*
pixels_per_iteration
+
cta_nhw_smem
;
// Load elements from SMEM, update the CTA count.
int
pixels_in_smem
=
min
(
PIXELS_PER_CTA_IN_SMEM
,
params
.
nhw
-
smem_nhw
);
if
(
pixels_in_smem
>
0
)
{
cta_count
+=
pixels_in_smem
;
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
const
bool
is_pixel_valid_nhw
=
static_cast
<
unsigned
int
>
(
idx
)
<
static_cast
<
unsigned
int
>
(
params
.
nhw
);
const
bool
is_pixel_valid
=
is_pixel_valid_nhw
&&
is_valid_c
;
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
],
dy_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
unsigned
int
relu_mask
;
int
lane_id
=
threadIdx
.
x
&
31
;
zero_array
(
x_storage_local
);
zero_array
(
dy_storage_local
);
if
(
is_pixel_valid_nhw
)
{
if
(
is_valid_c
)
{
ldg_stream
(
x_storage_local
,
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage_local
,
&
gmem_dy
[
idx
*
params
.
c
]);
}
if
(
lane_id
<
ELEMENTS_PER_LDG
)
{
relu_mask
=
gmem_relu_bitmask
[
idx
*
2
+
lane_id
];
}
}
bool
rectified
[
ELEMENTS_PER_LDG
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
rectified
[
j
]
=
((
__shfl_sync
(
0xFFFFFFFFU
,
relu_mask
,
j
)
&
(
1U
<<
lane_id
))
!=
0
);
}
// The offset to store in SMEM.
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
// Store in SMEM.
write_to_smem
(
&
smem_storage_packed
[
offset
],
threadIdx
.
x
,
x_storage_local
);
offset
+=
PIXELS_PER_THREAD_IN_SMEM
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
// Update the count.
count
+=
is_pixel_valid
;
// Invert the count.
float
inv_count
=
is_pixel_valid
?
1.
f
/
count
:
0.
f
;
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
to_float
(
dy_math
,
dy_storage_local
);
relu_bwd
(
dy_math
,
rectified
,
is_pixel_valid
);
bwd_update
(
dscale
,
dbias
,
dy_math
,
x_math
,
mean
,
inv_count
);
from_float
(
dy_storage_local
,
dy_math
);
// dZ for elementwise add
if
(
is_pixel_valid
)
{
stg_stream
(
&
gmem_dst1
[
idx
*
params
.
c
],
dy_storage_local
);
}
// only store the 'relu-dgrad'ed version!
write_to_smem
(
&
smem_storage_packed
[
offset
],
threadIdx
.
x
,
dy_storage_local
);
}
}
// We scale the mean by the number of elements. It brings more stability.
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
dbias
[
i
]
*=
count
;
dscale
[
i
]
*=
count
;
}
// dscale parallel sum
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
);
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dscale
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// dbias parallel sum
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
);
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dbias
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// The workspace in global memory is distributed across the different CTA.
int
gmem_sums_offset
=
c_blk_index
*
gridDim
.
x
*
C_ELEMENTS_PER_CTA
*
2
;
// Write the data for the CTA to global memory.
float
*
gmem_sums
=
&
params
.
gmem_sums
[
gmem_sums_offset
];
if
(
threadIdx
.
x
<
THREADS_PER_PIXEL
)
{
const
int
idx
=
blockIdx
.
x
*
THREADS_PER_PIXEL
+
threadIdx
.
x
;
write_to_gmem
(
&
gmem_sums
[
0
],
idx
,
dscale
);
write_to_gmem
(
&
gmem_sums
[
C_ELEMENTS_PER_CTA
*
gridDim
.
x
],
idx
,
dbias
);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int
*
gmem_retired_ctas
=
&
params
.
gmem_retired_ctas
[
c_blk_index
%
(
2
*
gridDim
.
y
)];
inter_block_sync
(
gmem_retired_ctas
,
gridDim
.
x
,
blockIdx
.
x
==
0
);
// Reset the accumulators for global summation
zero_array
(
dscale
);
zero_array
(
dbias
);
// Build the global accumulation
#pragma unroll 1
for
(
int
idx
=
threadIdx
.
x
;
idx
<
THREADS_PER_PIXEL
*
gridDim
.
x
;
idx
+=
THREADS_PER_CTA
)
{
float
tmp1
[
ELEMENTS_PER_LDG
],
tmp2
[
ELEMENTS_PER_LDG
];
read_from_gmem
(
tmp1
,
gmem_sums
,
idx
);
read_from_gmem
(
tmp2
,
gmem_sums
+
C_ELEMENTS_PER_CTA
*
gridDim
.
x
,
idx
);
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
dscale
[
i
]
+=
tmp1
[
i
];
dbias
[
i
]
+=
tmp2
[
i
];
}
}
// dscale parallel sum
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
1
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
);
}
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dscale
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// dbias parallel sum
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
0
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
);
}
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dbias
,
smem
,
thread_in_cta_c
);
// Normalize the dscale.
float
var
[
ELEMENTS_PER_LDG
];
zero_array
(
var
);
if
(
is_valid_c
)
{
read_from_gmem
(
var
,
params
.
gmem_saved_var
,
thread_c
/
ELEMENTS_PER_LDG
);
}
multiply
(
dscale
,
var
);
// store dscale/dbias
bool
is_valid_for_saving
=
is_valid_c
&&
blockIdx
.
x
==
0
&&
thread_in_cta_nhw
==
0
;
if
(
is_valid_for_saving
)
{
if
(
params
.
sync_iters
>
0
)
{
scaled_write_to_gmem
(
params
.
gmem_dscale
,
thread_c
/
ELEMENTS_PER_LDG
,
dscale
,
params
.
wgrad_coeff
);
scaled_write_to_gmem
(
params
.
gmem_dbias
,
thread_c
/
ELEMENTS_PER_LDG
,
dbias
,
params
.
wgrad_coeff
);
}
else
{
write_to_gmem
(
params
.
gmem_dscale
,
thread_c
/
ELEMENTS_PER_LDG
,
dscale
);
write_to_gmem
(
params
.
gmem_dbias
,
thread_c
/
ELEMENTS_PER_LDG
,
dbias
);
}
}
// Further normalize the dscale to be used in dx calculation
float
scale
[
ELEMENTS_PER_LDG
];
zero_array
(
scale
);
if
(
is_valid_c
)
{
read_from_gmem
(
scale
,
params
.
gmem_scale
,
thread_c
/
ELEMENTS_PER_LDG
);
}
multiply
(
dscale
,
var
);
// scale the inv-var as well, afterwards
multiply
(
var
,
scale
);
// inverse count
float
inv_count
=
params
.
svar_inv_count
;
// The base pointer to write to.
uint16_t
*
const
gmem_dst
=
&
params
.
gmem_dst
[
thread_c
];
// Store the elements in registers.
#pragma unroll 1
for
(
int
loop_i
=
OUTER_LOOPS
-
1
;
loop_i
>=
0
;
--
loop_i
)
{
// The value for nhw.
int
out_nhw
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Normalize the elements and write to memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
const
bool
is_valid
=
((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
;
// Convert to float.
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
to_float
(
dy_math
,
dy_storage
[
i
]);
float
dx
[
ELEMENTS_PER_LDG
];
bwd_dx
(
dx
,
dy_math
,
var
,
x_math
,
mean
,
dscale
,
dbias
,
inv_count
);
// Write back.
if
(
is_valid
)
{
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
dx
);
}
}
// The next value of nhw.
out_nhw
-=
pixels_per_iteration
;
// Read the next elements from memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
float
y
[
ELEMENTS_PER_LDG
];
zero_array
(
y
);
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage
[
i
],
&
gmem_dst1
[
idx
*
params
.
c
]);
}
}
}
// Normalize the elements from SMEM and write them out.
if
(
pixels_in_smem
>
0
)
{
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
const
bool
is_valid
=
((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
;
if
(
is_valid
)
{
// Read from SMEM.
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
],
dy_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
read_from_smem
(
x_storage_local
,
&
smem_storage_packed
[
offset
],
threadIdx
.
x
);
offset
+=
PIXELS_PER_THREAD_IN_SMEM
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
read_from_smem
(
dy_storage_local
,
&
smem_storage_packed
[
offset
],
threadIdx
.
x
);
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
to_float
(
dy_math
,
dy_storage_local
);
float
dx
[
ELEMENTS_PER_LDG
];
bwd_dx
(
dx
,
dy_math
,
var
,
x_math
,
mean
,
dscale
,
dbias
,
inv_count
);
// Write back.
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
dx
);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads
();
}
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
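For reference, the per-channel quantities that the three backward kernels above accumulate correspond to the standard batch-norm gradients. This is a sketch of the math being implemented, not a restatement of bwd_update/bwd_dx, and it assumes gmem_saved_var holds the inverse standard deviation 1/\sigma_c (which is how the repeated multiply(dscale, var) and multiply(var, scale) calls read), with N the number of NHW positions:

\frac{\partial L}{\partial \beta_c} = \sum_i \mathrm{d}y_{ic},
\qquad
\frac{\partial L}{\partial \gamma_c} = \frac{1}{\sigma_c} \sum_i \mathrm{d}y_{ic}\,(x_{ic} - \mu_c),

\frac{\partial L}{\partial x_{ic}} = \frac{\gamma_c}{\sigma_c}
  \Bigl( \mathrm{d}y_{ic}
       - \frac{1}{N}\sum_j \mathrm{d}y_{jc}
       - \frac{x_{ic}-\mu_c}{\sigma_c^{2}} \cdot \frac{1}{N}\sum_j \mathrm{d}y_{jc}\,(x_{jc}-\mu_c) \Bigr).

The _relu and _add_relu variants first zero dy wherever the forward activation was clamped (relu_bwd, driven either by recomputing the forward output or by the packed relu bitmask), and the _add_relu variant additionally streams that masked dy out through gmem_dst1 as the gradient of the elementwise-add input.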
apex/contrib/csrc/xentropy/interface.cpp
0 → 100644
#include <torch/extension.h>

// CUDA forward declarations

std::vector<at::Tensor> softmax_xentropy_cuda(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing,
    const bool half_to_float);

at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &grad_loss,
    const at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing);

// C++ interface

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> softmax_xentropy_forward(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing,
    const bool half_to_float) {
  CHECK_CUDA(input);
  CHECK_INPUT(labels);

  return softmax_xentropy_cuda(input, labels, smoothing, half_to_float);
}

at::Tensor softmax_xentropy_backward(
    const at::Tensor &grad_loss,
    const at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing) {
  CHECK_CUDA(grad_loss);
  CHECK_CUDA(logits);
  CHECK_INPUT(max_log_sum_exp);
  CHECK_INPUT(labels);

  return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &softmax_xentropy_forward,
        "Softmax cross entropy loss with label smoothing forward (CUDA)");
  m.def("backward", &softmax_xentropy_backward,
        "Softmax cross entropy loss with label smoothing backward (CUDA)");
}
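A minimal Python-side sketch of how these two bindings could be exercised, assuming the extension is built ad hoc with torch.utils.cpp_extension.load; the module name and source paths below are illustrative, only the forward/backward entry points and their argument order come from the bindings above.

# Illustrative only: module name and source paths are assumptions.
import torch
from torch.utils.cpp_extension import load

xentropy_cuda = load(
    name="xentropy_cuda",
    sources=["interface.cpp", "xentropy_kernel.cu"],  # adjust paths as needed
)

logits = torch.randn(32, 1000, device="cuda")  # fp32 example
labels = torch.randint(0, 1000, (32,), device="cuda", dtype=torch.long)
smoothing = 0.1

# forward returns the per-example losses and the saved max + log-sum-exp
losses, max_log_sum_exp = xentropy_cuda.forward(logits, labels, smoothing, False)

# backward turns d(total loss)/d(loss_i) into d(total loss)/d(logits)
grad_logits = xentropy_cuda.backward(
    torch.ones_like(losses), logits, max_log_sum_exp, labels, smoothing)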
apex/contrib/csrc/xentropy/xentropy_kernel.cu
0 → 100644
/**
* From PyTorch:
*
* Copyright (c) 2016- Facebook, Inc (Adam Paszke)
* Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
* Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
* Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
* Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
* Copyright (c) 2011-2013 NYU (Clement Farabet)
* Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
* Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
* Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
*
* From Caffe2:
*
* Copyright (c) 2016-present, Facebook Inc. All rights reserved.
*
* All contributions by Facebook:
* Copyright (c) 2016 Facebook Inc.
*
* All contributions by Google:
* Copyright (c) 2015 Google Inc.
* All rights reserved.
*
* All contributions by Yangqing Jia:
* Copyright (c) 2015 Yangqing Jia
* All rights reserved.
*
* All contributions from Caffe:
* Copyright(c) 2013, 2014, 2015, the respective contributors
* All rights reserved.
*
* All other contributions:
* Copyright(c) 2015, 2016 the respective contributors
* All rights reserved.
*
* Caffe2 uses a copyright model similar to Caffe: each contributor holds
* copyright over their contributions to Caffe2. The project versioning records
* all such contribution and copyright details. If a contributor wants to further
* mark their specific copyright on a particular contribution, they should
* indicate their copyright solely in the commit message of the change when it is
* committed.
*
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
* and IDIAP Research Institute nor the names of its contributors may be
* used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <THC/THC.h>
#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>
#include "type_shim.h"
using Tensor = at::Tensor;
using TensorList = at::TensorList;
using ScalarType = at::ScalarType;
using at::acc_type;

template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxForwardEpilogue {
  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
    : logsum(max_input + std::log(sum)) {}

  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp)
    : logsum(max_log_sum_exp) {}

  __device__ __forceinline__ OutT operator()(T input) const {
    return static_cast<OutT>(input - logsum);
  }

  const AccumT logsum;
};

template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxBackwardEpilogue {
  __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)
    : sum(sum) {}

  __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
    return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);
  }

  const AccumT sum;
};

const int max_threads = 1024;

inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
  uint64_t block_size = 1;
  uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
  while (block_size < max_block_size) block_size *= 2;
  // Launch at least a single warp - the kernel assumes that.
  block_size = std::max(block_size, static_cast<uint64_t>(32));
  return dim3(block_size);
}
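A quick worked example of this block-size rule, using the ILP = 2 that the host launchers later in this file pass in:

- dim_size = 1000: min(1000 / 2, 1024) = 500, doubled up to 512, so 512 threads per block;
- dim_size = 50000: min(25000, 1024) = 1024, so 1024 threads per block;
- dim_size = 40: min(20, 1024) = 20, doubled up to 32, and the warp floor keeps it at 32 threads.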
template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a + b;
  }
};

template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a < b ? b : a;
  }
};

////////////////////////////////////////////////////////////////////////////////
// Regular kernel (fast when dim_size is large; requires inner_size == 1)
////////////////////////////////////////////////////////////////////////////////

template <typename T, typename AccumT>
struct MaxFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
    return ::max(max, (AccumT)v);
  }
};

template <typename T, typename AccumT>
struct AddFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
    return sum + v;
  }
};

template <typename T, typename AccumT>
struct SumExpFloat
{
  __device__ __forceinline__ SumExpFloat(AccumT v)
    : max_k(v) {}

  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
    return sum + std::exp(v - max_k);
  }

  const AccumT max_k;
};
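SumExpFloat is the second half of a numerically stable log-sum-exp: the forward kernel below first reduces the row maximum m with MaxFloat, then accumulates \sum_j e^{x_j - m}, so that

\log \sum_j e^{x_j} = m + \log \sum_j e^{x_j - m}, \qquad m = \max_j x_j,

and every exponent is \le 0, which avoids overflow regardless of the magnitude of the logits.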
template
<
template
<
typename
>
class
Reduction
,
typename
AccumT
>
__device__
__forceinline__
AccumT
blockReduce
(
AccumT
*
smem
,
AccumT
val
,
const
Reduction
<
AccumT
>&
r
,
AccumT
defaultVal
)
{
// To avoid RaW races from chaining blockReduce calls together, we need a sync here
__syncthreads
();
smem
[
threadIdx
.
x
]
=
val
;
__syncthreads
();
AccumT
warpVal
=
defaultVal
;
// First warp will perform per-warp reductions for the remaining warps
uint32_t
mask
=
(((
uint64_t
)
1
)
<<
(
blockDim
.
x
/
32
))
-
1
;
if
(
threadIdx
.
x
<
32
)
{
int
lane
=
threadIdx
.
x
%
32
;
if
(
lane
<
blockDim
.
x
/
32
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
32
;
++
i
)
{
warpVal
=
r
(
warpVal
,
smem
[
lane
*
32
+
i
]);
}
__syncwarp
(
mask
);
smem
[
lane
]
=
warpVal
;
}
}
__syncthreads
();
// First thread will perform a reduction of the above per-warp reductions
AccumT
blockVal
=
defaultVal
;
if
(
threadIdx
.
x
==
0
)
{
for
(
int
i
=
0
;
i
<
blockDim
.
x
/
32
;
++
i
)
{
blockVal
=
r
(
blockVal
,
smem
[
i
]);
}
smem
[
0
]
=
blockVal
;
}
// Sync and broadcast
__syncthreads
();
return
smem
[
0
];
}
template
<
template
<
typename
>
class
Reduction1
,
template
<
typename
>
class
Reduction2
,
typename
AccumT
>
__device__
__forceinline__
void
blockReduce
(
AccumT
*
smem
,
AccumT
*
reducVal1
,
AccumT
val1
,
const
Reduction1
<
AccumT
>&
r1
,
AccumT
defaultVal1
,
AccumT
*
reducVal2
,
AccumT
val2
,
const
Reduction2
<
AccumT
>&
r2
,
AccumT
defaultVal2
)
{
// To avoid RaW races from chaining blockReduce calls together, we need a sync here
__syncthreads
();
smem
[
threadIdx
.
x
]
=
val1
;
smem
[
blockDim
.
x
+
threadIdx
.
x
]
=
val2
;
__syncthreads
();
AccumT
warpVal1
=
defaultVal1
;
AccumT
warpVal2
=
defaultVal2
;
// First warp will perform per-warp reductions for the remaining warps
uint32_t
mask
=
(((
uint64_t
)
1
)
<<
(
blockDim
.
x
/
32
))
-
1
;
if
(
threadIdx
.
x
<
32
)
{
int
lane
=
threadIdx
.
x
%
32
;
if
(
lane
<
blockDim
.
x
/
32
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
32
;
++
i
)
{
warpVal1
=
r1
(
warpVal1
,
smem
[
lane
*
32
+
i
]);
warpVal2
=
r2
(
warpVal2
,
smem
[
lane
*
32
+
i
+
blockDim
.
x
]);
}
__syncwarp
(
mask
);
smem
[
lane
]
=
warpVal1
;
smem
[
lane
+
blockDim
.
x
]
=
warpVal2
;
}
}
__syncthreads
();
// First thread will perform a reduction of the above per-warp reductions
AccumT
blockVal1
=
defaultVal1
;
AccumT
blockVal2
=
defaultVal2
;
if
(
threadIdx
.
x
==
0
)
{
for
(
int
i
=
0
;
i
<
blockDim
.
x
/
32
;
++
i
)
{
blockVal1
=
r1
(
blockVal1
,
smem
[
i
]);
blockVal2
=
r2
(
blockVal2
,
smem
[
i
+
blockDim
.
x
]);
}
smem
[
0
]
=
blockVal1
;
smem
[
blockDim
.
x
]
=
blockVal2
;
}
// Sync and broadcast
__syncthreads
();
*
reducVal1
=
smem
[
0
];
*
reducVal2
=
smem
[
blockDim
.
x
];
__syncthreads
();
}
template
<
template
<
typename
,
typename
>
class
Reduction
,
int
ILP
,
typename
T
,
typename
AccumT
>
__device__
__forceinline__
AccumT
ilpReduce
(
T
*
data
,
int
size
,
const
Reduction
<
T
,
AccumT
>&
r
,
AccumT
defaultVal
)
{
AccumT
threadVal
=
defaultVal
;
int
offset
=
threadIdx
.
x
;
int
last
=
size
%
(
ILP
*
blockDim
.
x
);
// Body (unroll by ILP times)
for
(;
offset
<
size
-
last
;
offset
+=
blockDim
.
x
*
ILP
)
{
T
tmp
[
ILP
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ILP
;
++
j
)
tmp
[
j
]
=
data
[
offset
+
j
*
blockDim
.
x
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ILP
;
++
j
)
threadVal
=
r
(
threadVal
,
tmp
[
j
]);
}
// Epilogue
for
(;
offset
<
size
;
offset
+=
blockDim
.
x
)
threadVal
=
r
(
threadVal
,
data
[
offset
]);
return
threadVal
;
}
template
<
template
<
typename
,
typename
>
class
Reduction1
,
template
<
typename
,
typename
>
class
Reduction2
,
int
ILP
,
typename
T
,
typename
AccumT
>
__device__
__forceinline__
void
ilpReduce
(
T
*
data
,
int
size
,
AccumT
*
reducVal1
,
const
Reduction1
<
T
,
AccumT
>&
r1
,
AccumT
defaultVal1
,
AccumT
*
reducVal2
,
const
Reduction2
<
T
,
AccumT
>&
r2
,
AccumT
defaultVal2
)
{
AccumT
threadVal1
=
defaultVal1
;
AccumT
threadVal2
=
defaultVal2
;
int
offset
=
threadIdx
.
x
;
int
last
=
size
%
(
ILP
*
blockDim
.
x
);
// Body (unroll by ILP times)
for
(;
offset
<
size
-
last
;
offset
+=
blockDim
.
x
*
ILP
)
{
T
tmp
[
ILP
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ILP
;
++
j
)
tmp
[
j
]
=
data
[
offset
+
j
*
blockDim
.
x
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ILP
;
++
j
)
{
threadVal1
=
r1
(
threadVal1
,
tmp
[
j
]);
threadVal2
=
r2
(
threadVal2
,
tmp
[
j
]);
}
}
// Epilogue
for
(;
offset
<
size
;
offset
+=
blockDim
.
x
)
{
threadVal1
=
r1
(
threadVal1
,
data
[
offset
]);
threadVal2
=
r2
(
threadVal2
,
data
[
offset
]);
}
*
reducVal1
=
threadVal1
;
*
reducVal2
=
threadVal2
;
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
          template <typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyForward(
    accscalar_t *losses,
    outscalar_t *max_log_sum_exp,
    scalar_t *input,
    int64_t *labels,
    int64_t classes,
    const float smoothing)
{
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accscalar_t*>(smem);
  // forward pointers to batch[blockIdx.x]
  // each block handles a sample in the mini-batch
  input += blockIdx.x * classes;
  //output += blockIdx.x * classes;
  int64_t label = labels[blockIdx.x];

  // find the max and sum
  accscalar_t threadMax, threadSum, max_k, sum_k;
  ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(
      input, classes,
      &threadMax, MaxFloat<scalar_t, accscalar_t>(), -at::numeric_limits<accscalar_t>::max(),
      &threadSum, AddFloat<scalar_t, accscalar_t>(), static_cast<accscalar_t>(0));
  blockReduce<Max, Add, accscalar_t>(
      sdata,
      &max_k, threadMax, Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max(),
      &sum_k, threadSum, Add<accscalar_t>(), static_cast<accscalar_t>(0));

  // reduce all values
  accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
      input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k), static_cast<accscalar_t>(0));
  accscalar_t sumAll = blockReduce<Add, accscalar_t>(
      sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));

  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);

  // calculate per element loss with label smoothing
  // reserve max + log_sum_exp for bprop
  if (threadIdx.x == 0) {
    accscalar_t log_prob = epilogue(static_cast<accscalar_t>(input[label]));
    losses[blockIdx.x] = (max_k + std::log(sumAll) - sum_k / classes) \
        * smoothing - log_prob * (1 - smoothing);
    max_log_sum_exp[blockIdx.x] = max_k + std::log(sumAll);
  }
}
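Written out, the loss stored by the thread-0 branch above is the label-smoothed cross entropy (\varepsilon is the smoothing argument, C the number of classes, y the target index):

\ell = (1-\varepsilon)\bigl(\log Z - x_y\bigr) + \varepsilon\Bigl(\log Z - \frac{1}{C}\sum_{j=1}^{C} x_j\Bigr),
\qquad \log Z = m + \log\sum_{j=1}^{C} e^{x_j - m},\;\; m = \max_j x_j,

which is exactly -\sum_j q_j \log p_j against the smoothed target q_j = (1-\varepsilon)\,\mathbb{1}[j=y] + \varepsilon/C. The kernel keeps \log Z (max_log_sum_exp) around so the backward pass can rebuild the softmax without re-reducing.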
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
          template<typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyBackward(
    scalar_t *gradInput,
    scalar_t *logits,
    outscalar_t *max_log_sum_exp,
    outscalar_t *gradOutput,
    int64_t *labels,
    const float smoothing,
    int classes)
{
  gradInput += blockIdx.x * classes;
  logits += blockIdx.x * classes;

  accscalar_t smooth_positives = 1.0 - smoothing;
  accscalar_t smooth_negatives = smoothing / classes;
  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
  int64_t label = labels[blockIdx.x];
  accscalar_t coeff = max_log_sum_exp[blockIdx.x];

  int offset = threadIdx.x;
  int last = classes % (ILP * blockDim.x);
  for (; offset < classes - last; offset += blockDim.x * ILP) {
    accscalar_t tmpLogits[ILP];

#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      tmpLogits[j] = static_cast<accscalar_t>(logits[offset + j * blockDim.x]);
    }

#pragma unroll
    for (int j = 0; j < ILP; ++j)
      gradInput[offset + j * blockDim.x] = tmpGradOutput * (
          std::exp(tmpLogits[j] - coeff) - static_cast<accscalar_t>(
          (offset + j * blockDim.x == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
  }

  for (; offset < classes; offset += blockDim.x)
    gradInput[offset] = tmpGradOutput * (std::exp(
        static_cast<accscalar_t>(logits[offset]) - coeff) -
        static_cast<accscalar_t>((offset == label) ? 1 : 0) *
        smooth_positives - smooth_negatives);
}
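The expression in both loops above is the closed-form gradient of that smoothed loss, scaled by the incoming gradOutput g for the row:

\frac{\partial \ell}{\partial x_j} = g\,\bigl(p_j - q_j\bigr),
\qquad p_j = e^{\,x_j - \log Z},\quad q_j = (1-\varepsilon)\,\mathbb{1}[j=y] + \frac{\varepsilon}{C},

with \log Z read back from max_log_sum_exp, so smooth_positives = 1-\varepsilon and smooth_negatives = \varepsilon/C are just the two pieces of q_j.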
template<template<typename, typename, typename> class Epilogue>
std::vector<Tensor> host_softmax_xentropy(
    const Tensor & input_,
    const Tensor & labels_,
    const float smoothing,
    const bool half_to_float){
  if (half_to_float) AT_ASSERTM(input_.type().scalarType() == ScalarType::Half, "conversion is supported for Half type only");
  AT_ASSERTM(labels_.type().scalarType() == ScalarType::Long, "Label type should be CUDA Long");

  auto input = input_.contiguous();
  Tensor max_log_sum_exp = at::empty_like(labels_, half_to_float ? input.options().dtype(ScalarType::Float) : input.options());
  Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float));

  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
                std::is_same<acc_type<at::Half, true>, double>::value,
                "accscalar_t for half should be float or double");
  AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported");
  AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional");
  AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples");
  AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0");

  const int64_t dim = 1;
  int64_t outer_size = 1;
  int64_t dim_size = input.size(dim);
  int64_t inner_size = 1;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  for (int64_t i = 0; i < dim; ++i)
    outer_size *= input.size(i);
  for (int64_t i = dim + 1; i < input.dim(); ++i)
    inner_size *= input.size(i);
  // This kernel spawns a block per each element in the batch.
  // XXX: it assumes that inner_size == 1
  AT_CHECK(inner_size == 1, "Currently only inner size 1 supported");

  const int ILP = 2;
  dim3 grid(outer_size);
  dim3 block = SoftMax_getBlockSize(ILP, dim_size);

  using namespace at;
  DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "host_softmax_xentropy",
    using accscalar_t = at::acc_type<scalar_t_0, true>;
    if (!half_to_float) {
      cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>
        <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
          losses.data<accscalar_t>(), max_log_sum_exp.data<scalar_t_0>(),
          input.data<scalar_t_0>(), labels_.data<int64_t>(),
          dim_size, smoothing
      );
    } else {
      cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
        <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
          losses.data<accscalar_t>(), max_log_sum_exp.data<accscalar_t>(),
          input.data<scalar_t_0>(), labels_.data<int64_t>(),
          dim_size, smoothing
      );
    }
  );

  THCudaCheck(cudaGetLastError());

  std::vector<at::Tensor> ret = {losses, max_log_sum_exp};
  return ret;
}
template<template<typename, typename, typename> class Epilogue>
Tensor host_softmax_xentropy_backward(
    const at::Tensor &grad_loss,
    const at::Tensor &logits_,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    bool half_to_float) {
  const int64_t dim = 1;
  Tensor gI = at::empty_like(logits_);
  if (grad_loss.numel() == 0) {
    return gI;
  }

  auto grad = grad_loss.contiguous();
  auto logits = logits_.contiguous();

  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
                std::is_same<acc_type<at::Half, true>, double>::value,
                "accscalar_t for half should be float or double");
  if (grad.dim() == 0) grad = grad.view(1);

  AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported");
  AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional");
  AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0");
  AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples");
  AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples");

  int64_t outer_size = 1;
  int64_t dim_size = logits.size(dim);
  int64_t inner_size = 1;
  for (int64_t i = 0; i < dim; ++i)
    outer_size *= logits.size(i);
  for (int64_t i = dim + 1; i < logits.dim(); ++i)
    inner_size *= logits.size(i);
  // See descriptions of kernels above.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_CHECK(inner_size == 1, "Currently only inner size 1 supported");

  const int ILP = 2;
  dim3 grid(outer_size);
  dim3 block = SoftMax_getBlockSize(ILP, dim_size);

  DISPATCH_FLOAT_AND_HALF(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
    using accscalar_t = acc_type<scalar_t_0, true>;
    if (!half_to_float) {
      cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>
        <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
          gI.data<scalar_t_0>(), logits.data<scalar_t_0>(),
          max_log_sum_exp.data<scalar_t_0>(),
          grad.data<scalar_t_0>(), labels.data<int64_t>(),
          smoothing, dim_size
      );
    } else {
      cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
        <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
          gI.data<scalar_t_0>(), logits.data<scalar_t_0>(),
          max_log_sum_exp.data<accscalar_t>(),
          grad.data<accscalar_t>(), labels.data<int64_t>(),
          smoothing, dim_size
      );
    }
  );

  THCudaCheck(cudaGetLastError());
  return gI;
}
std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const bool half_to_float){
  return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, half_to_float);
}

at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &grad_loss,
    const at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing) {
  bool half_to_float = grad_loss.type().scalarType() != logits.type().scalarType();
  if (half_to_float) {
    AT_ASSERTM((grad_loss.type().scalarType() == ScalarType::Float && logits.type().scalarType() == ScalarType::Half),
               "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
  }
  return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, half_to_float);
}
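Taken together, the forward kernel's per-row loss is the standard label-smoothed cross entropy. Below is a minimal PyTorch restatement of that math for orientation only; the function name is illustrative (not part of the extension's API) and it omits the padding_idx masking that is handled at the Python level in the test further down.

import torch
import torch.nn.functional as F

def smoothed_xentropy_reference(logits, labels, smoothing=0.1):
    # loss_i = -((1 - smoothing) * logp[i, y_i] + smoothing * mean_j logp[i, j]),
    # which is algebraically the same as
    # (logsumexp(x_i) - mean(x_i)) * smoothing - logp[i, y_i] * (1 - smoothing)
    # as written in the CUDA forward kernel above.
    logprobs = F.log_softmax(logits.float(), dim=-1)
    nll = -logprobs.gather(dim=-1, index=labels.unsqueeze(1)).squeeze(1)
    smooth = -logprobs.mean(dim=-1)
    return (1.0 - smoothing) * nll + smoothing * smooth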
apex/contrib/groupbn/__init__.py
0 → 100644
View file @
4d6ed501
try:
    import torch
    import bnp
    from .batch_norm import BatchNorm2d_NHWC
    del torch
    del bnp
    del batch_norm
except ImportError as err:
    print("apex was installed without --bnp flag, contrib.groupbn is not available")
apex/contrib/groupbn/batch_norm.py
0 → 100644
View file @
4d6ed501
import torch
import numpy as np
from torch.nn.modules.batchnorm import _BatchNorm

import bnp


class bn_NHWC_impl(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta,
                mom, epsilon, fuse_relu, is_train,
                bn_group, my_data, pair_data, magic, pair_data2, pair_data3,
                fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
        if is_train:
            ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv)
            ctx.epsilon = epsilon
            ctx.momentum = mom
            ctx.ret_cta = ret_cta
            ctx.fuse_relu = fuse_relu
            ctx.my_data = my_data
            ctx.pair_data = pair_data
            ctx.magic = magic
            ctx.pair_data2 = pair_data2
            ctx.pair_data3 = pair_data3
            ctx.bn_group = bn_group
            ctx.bwd_occup = bwd_occup
            ctx.bwd_grid_x = bwd_grid_x
            ctx.multi_stream = multi_stream

            res = bnp.bn_fwd_nhwc(x, s, b, rm, riv, mini_m, mini_riv, ret_cta,
                                  mom, epsilon, fuse_relu,
                                  my_data, pair_data, pair_data2, pair_data3,
                                  bn_group, magic, fwd_occup, fwd_grid_x, multi_stream)
            return res
        else:
            return bnp.bn_fwd_eval_nhwc(x, s, b, rm, riv, ret_cta, bn_group, mom, epsilon, fuse_relu)

    @staticmethod
    def backward(ctx, grad_y):
        x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables
        epsilon = ctx.epsilon
        mom = ctx.momentum
        ret_cta = ctx.ret_cta
        fuse_relu = ctx.fuse_relu
        my_data = ctx.my_data
        pair_data = ctx.pair_data
        magic = ctx.magic
        pair_data2 = ctx.pair_data2
        pair_data3 = ctx.pair_data3
        bn_group = ctx.bn_group
        bwd_occup = ctx.bwd_occup
        bwd_grid_x = ctx.bwd_grid_x
        multi_stream = ctx.multi_stream

        dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, ret_cta,
                                            mom, epsilon, fuse_relu,
                                            my_data, pair_data, pair_data2, pair_data3,
                                            bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)

        return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


class bn_addrelu_NHWC_impl(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta,
                mom, epsilon, is_train,
                bn_group, my_data, pair_data, magic, pair_data2, pair_data3,
                fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
        if is_train:
            bitmask = torch.cuda.IntTensor(((x.numel() + 31) // 32) * 2 * grid_dim_y)
            ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)
            ctx.epsilon = epsilon
            ctx.momentum = mom
            ctx.ret_cta = ret_cta
            ctx.my_data = my_data
            ctx.pair_data = pair_data
            ctx.magic = magic
            ctx.pair_data2 = pair_data2
            ctx.pair_data3 = pair_data3
            ctx.bn_group = bn_group
            ctx.bwd_occup = bwd_occup
            ctx.bwd_grid_x = bwd_grid_x
            ctx.multi_stream = multi_stream

            res = bnp.bn_addrelu_fwd_nhwc(x, z, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta,
                                          mom, epsilon,
                                          my_data, pair_data, pair_data2, pair_data3,
                                          bn_group, magic, fwd_occup, fwd_grid_x, multi_stream)
            return res
        else:
            return bnp.bn_addrelu_fwd_eval_nhwc(x, z, s, b, rm, riv, ret_cta, bn_group, mom, epsilon)

    @staticmethod
    def backward(ctx, grad_y):
        x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables
        epsilon = ctx.epsilon
        mom = ctx.momentum
        ret_cta = ctx.ret_cta
        my_data = ctx.my_data
        pair_data = ctx.pair_data
        magic = ctx.magic
        pair_data2 = ctx.pair_data2
        pair_data3 = ctx.pair_data3
        bn_group = ctx.bn_group
        bwd_occup = ctx.bwd_occup
        bwd_grid_x = ctx.bwd_grid_x
        multi_stream = ctx.multi_stream

        dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta,
                                                        mom, epsilon,
                                                        my_data, pair_data, pair_data2, pair_data3,
                                                        bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)

        return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


class BatchNorm2d_NHWC(_BatchNorm):
    # if using BatchNorm2d_NHWC simultaneously with multiple streams set multi_stream to True
    def __init__(self, num_features, fuse_relu=False, bn_group=1,
                 max_cta_per_sm=2, cta_launch_margin=12, multi_stream=False):
        super(BatchNorm2d_NHWC, self).__init__(num_features)

        self.fuse_relu = fuse_relu
        self.multi_stream = multi_stream

        self.minibatch_mean = torch.cuda.FloatTensor(num_features)
        self.minibatch_riv = torch.cuda.FloatTensor(num_features)

        # default to distributed bn disabled
        self.bn_group = bn_group
        self.max_cta_per_sm = max_cta_per_sm        # used only in training fwd and bwd
        self.cta_launch_margin = cta_launch_margin  # used only in training fwd and bwd
        self.my_data = None
        self.pair_data = None
        self.pair_data2 = None
        self.pair_data3 = None
        self.local_rank = 0
        self.magic = torch.IntTensor([0])

        # calculate cta per sm occupancies
        assert (max_cta_per_sm > 0)  # won't be able to do much with 0 CTAs :)
        self.fwd_occupancy = min(bnp.bn_fwd_nhwc_occupancy(), max_cta_per_sm)
        self.bwd_occupancy = min(bnp.bn_bwd_nhwc_occupancy(), max_cta_per_sm)
        self.addrelu_fwd_occupancy = min(bnp.bn_addrelu_fwd_nhwc_occupancy(), max_cta_per_sm)
        self.addrelu_bwd_occupancy = min(bnp.bn_addrelu_bwd_nhwc_occupancy(), max_cta_per_sm)

        # calculate grid dimensions based on occupancy numbers
        mp_count = torch.cuda.get_device_properties(None).multi_processor_count
        self.fwd_grid_dim_x = max(mp_count * self.fwd_occupancy - cta_launch_margin, 1)
        self.bwd_grid_dim_x = max(mp_count * self.bwd_occupancy - cta_launch_margin, 1)
        self.addrelu_fwd_grid_dim_x = max(mp_count * self.addrelu_fwd_occupancy - cta_launch_margin, 1)
        self.addrelu_bwd_grid_dim_x = max(mp_count * self.addrelu_bwd_occupancy - cta_launch_margin, 1)
        self.grid_dim_y = (num_features + 63) // 64

        # allocate scratch space used by implementation
        # TODO: scratch space that is not supposed to be exposed at user code. We only need one time initialization, the
        # same buffer could be reused in future iterations. Currently we exposed it here instead of requesting new
        # buffer from cache allocator to avoid unnecessary initialization at future iterations.
        self.ret_cta = torch.cuda.ByteTensor(8192).fill_(0)

        # FIXME: turn pair handles into an array
        if bn_group > 1:
            local_rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
            assert (world_size >= bn_group)
            assert (world_size % bn_group == 0)

            bn_sync_steps = 1
            if (bn_group == 4):
                bn_sync_steps = 2
            if (bn_group == 8):
                bn_sync_steps = 3

            self.ipc_buffer = torch.cuda.ByteTensor(bnp.get_buffer_size(bn_sync_steps))
            self.my_data = bnp.get_data_ptr(self.ipc_buffer)
            # we are walking on very thin ice here by utilizing internal `_share_cuda_()`
            self.storage = self.ipc_buffer.storage()
            self.share_cuda = self.storage._share_cuda_()
            internal_cuda_mem = self.share_cuda
            # internal_cuda_mem[1]: ipc_mem_handle
            my_handle = torch.cuda.ByteTensor(np.frombuffer(internal_cuda_mem[1], dtype=np.uint8))
            # internal_cuda_mem[3]: offset
            my_offset = torch.cuda.IntTensor([internal_cuda_mem[3]])

            handles_all = torch.empty(world_size, my_handle.size(0), dtype=my_handle.dtype, device=my_handle.device)
            handles_l = list(handles_all.unbind(0))
            torch.distributed.all_gather(handles_l, my_handle)

            offsets_all = torch.empty(world_size, my_offset.size(0), dtype=my_offset.dtype, device=my_offset.device)
            offsets_l = list(offsets_all.unbind(0))
            torch.distributed.all_gather(offsets_l, my_offset)

            # whom do I actually care about? that would be local_rank XOR 1
            self.pair_handle = handles_l[local_rank ^ 1].cpu().contiguous()
            pair_offset = offsets_l[local_rank ^ 1].cpu()
            self.pair_data = bnp.get_remote_data_ptr(self.pair_handle, pair_offset)

            if bn_group > 2:
                self.pair_handle2 = handles_l[local_rank ^ 2].cpu().contiguous()
                pair_offset2 = offsets_l[local_rank ^ 2].cpu()
                self.pair_data2 = bnp.get_remote_data_ptr(self.pair_handle2, pair_offset2)

            if bn_group > 4:
                self.pair_handle3 = handles_l[local_rank ^ 4].cpu().contiguous()
                pair_offset3 = offsets_l[local_rank ^ 4].cpu()
                self.pair_data3 = bnp.get_remote_data_ptr(self.pair_handle3, pair_offset3)

            # FIXME: get magic value into C code and eliminate from here
            self.magic = torch.IntTensor([2])
            self.local_rank = local_rank

    def forward(self, x, z=None):
        if z is not None:
            assert (self.fuse_relu == True)
            return bn_addrelu_NHWC_impl.apply(
                x, z,
                self.weight, self.bias,
                self.running_mean, self.running_var,
                self.minibatch_mean, self.minibatch_riv, self.grid_dim_y, self.ret_cta,
                self.momentum, self.eps, self.training,
                self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
                self.addrelu_fwd_occupancy, self.addrelu_fwd_grid_dim_x,
                self.addrelu_bwd_occupancy, self.addrelu_bwd_grid_dim_x,
                self.multi_stream)
        else:
            return bn_NHWC_impl.apply(
                x,
                self.weight, self.bias,
                self.running_mean, self.running_var,
                self.minibatch_mean, self.minibatch_riv, self.ret_cta,
                self.momentum, self.eps, self.fuse_relu, self.training,
                self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
                self.fwd_occupancy, self.fwd_grid_dim_x,
                self.bwd_occupancy, self.bwd_grid_dim_x,
                self.multi_stream)

    def __del__(self):
        if self.bn_group > 1:
            bnp.close_remote_data(self.pair_handle)
            if self.bn_group > 2:
                bnp.close_remote_data(self.pair_handle2)
                if self.bn_group > 4:
                    bnp.close_remote_data(self.pair_handle3)
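A minimal single-process usage sketch of the module above (bn_group=1). It assumes apex was built with the --bnp flag; the tensor shape, half-precision dtype, and NHWC storage layout shown here are illustrative assumptions for the sketch, not requirements stated in this file.

import torch
from apex.contrib.groupbn import BatchNorm2d_NHWC

# module allocates its CUDA scratch buffers at construction time
bn = BatchNorm2d_NHWC(64, fuse_relu=True)

# activations assumed to be stored as N, H, W, C (hypothetical shape)
x = torch.randn(32, 56, 56, 64, device='cuda', dtype=torch.half)

y = bn(x)                            # fused BN + ReLU
y_res = bn(x, z=torch.randn_like(x)) # fused BN + residual add + ReLU (requires fuse_relu=True)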
apex/contrib/test/test_label_smoothing.py
0 → 100644
View file @
4d6ed501
import torch
from apex.contrib import xentropy as label_smoothing
import unittest

import warnings
import random
import numpy as np
import time


def label_smoothing_raw(x, target, padding_idx, smoothing):
    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)

    non_pad_mask = (target != padding_idx)
    nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
    nll_loss = nll_loss.squeeze(1)[non_pad_mask]
    smooth_loss = -logprobs.mean(dim=-1)[non_pad_mask]
    loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss
    return loss


def label_smoothing_opt_1(x, target, padding_idx, smoothing):
    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)

    pad_mask = (target == padding_idx)
    ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
    smooth_loss = logprobs.mean(dim=-1)
    loss = (smoothing - 1.0) * ll_loss - smoothing * smooth_loss
    loss.masked_fill_(pad_mask, 0)
    return loss


class LabelSmoothingTest(unittest.TestCase):
    def setUp(self, seed=1234):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # Set pytorch print precision
        torch.set_printoptions(precision=10)

    def gen_test_inputs(self, N, T, H, smoothing, padding_idx):
        logits = torch.randn((N * T, H), dtype=torch.half, device='cuda',
                             requires_grad=True)
        labels = torch.randint(0, H, [N * T], device='cuda')
        for i in random.sample(range(N * T), N * T // 6):
            labels[i] = padding_idx
        half_to_float = (logits.dtype == torch.half)

        return logits, labels, half_to_float

    def print_max_diff_elem(self, ref, tst):
        ref, tst = ref.flatten(), tst.flatten()
        diff = (ref - tst).abs().max()
        idx = (ref - tst).abs().argmax()
        print("Max atol idx: {}, diff: {:.6f}, ref: {:.6f}, tst: {:.6f}".format(
            idx, diff, ref[idx], tst[idx]))

    def test_label_smoothing_function(self):
        # Set label smoothing configuration
        smoothing, padding_idx = 0.1, 0
        N, T, H = 128, 74, 32320
        iters = 10
        loss_func = label_smoothing.SoftmaxCrossEntropyLoss.apply

        for i in range(iters):
            logits, labels, half_to_float = self.gen_test_inputs(
                N, T, H, smoothing, padding_idx)

            # Run original softmax cross entropy with label smoothing
            logits.grad = None
            losses = label_smoothing_raw(logits, labels, padding_idx, smoothing)
            loss = losses.sum()
            loss.backward()

            ref_loss = loss.clone().detach()
            ref_grad = logits.grad.clone().detach()

            # Run optimized softmax cross entropy with label smoothing
            logits.grad = None
            losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float)
            loss = losses.sum()
            loss.backward()

            val_loss = loss.clone().detach()
            val_grad = logits.grad.clone().detach()

            # Validate
            self.print_max_diff_elem(ref_grad, val_grad)
            self.assertTrue(torch.allclose(ref_loss, val_loss, atol=1e-5, rtol=1e-5))
            self.assertTrue(torch.allclose(ref_grad, val_grad, atol=1e-5, rtol=1e-5))

    def test_label_smoothing_perf(self):
        # Set label smoothing configuration
        smoothing, padding_idx = 0.1, 0
        N, T, H = 128, 74, 32320
        iters = 1000
        loss_func = label_smoothing.SoftmaxCrossEntropyLoss.apply
        print()

        logits, labels, half_to_float = self.gen_test_inputs(
            N, T, H, smoothing, padding_idx)

        # Run original softmax cross entropy with label smoothing
        torch.cuda.synchronize()
        ts = time.time()
        for i in range(iters):
            logits.grad = None
            losses = label_smoothing_raw(logits, labels, padding_idx, smoothing)
            loss = losses.sum() / N
            loss.backward()
        torch.cuda.synchronize()
        print("Raw time {:.2f} s elapsed for {} iterations, norm {:.4f}".format(
            time.time() - ts, iters, logits.grad.norm()))

        # Run optimized softmax cross entropy with label smoothing
        torch.cuda.synchronize()
        ts = time.time()
        for i in range(iters):
            logits.grad = None
            losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float)
            loss = losses.sum() / N
            loss.backward()
        torch.cuda.synchronize()
        print("Opt time {:.2f} s elapsed for {} iterations, norm {:.4f}".format(
            time.time() - ts, iters, logits.grad.norm()))


if __name__ == '__main__':
    unittest.main()
apex/contrib/xentropy/__init__.py
0 → 100644
View file @
4d6ed501
try:
    import torch
    import xentropy_cuda
    from .softmax_xentropy import SoftmaxCrossEntropyLoss
    del torch
    del xentropy_cuda
    del softmax_xentropy
except ImportError as err:
    print("apex was installed without --xentropy flag, contrib.xentropy is not available")
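For completeness, a minimal usage sketch mirroring the call pattern exercised in the test above (it assumes apex was built with the --xentropy flag; the tensor sizes are illustrative only).

import torch
from apex.contrib import xentropy as label_smoothing

logits = torch.randn(128, 32320, dtype=torch.half, device='cuda', requires_grad=True)
labels = torch.randint(0, 32320, (128,), device='cuda')

# argument order, as in the test: logits, labels, smoothing, padding_idx, half_to_float
losses = label_smoothing.SoftmaxCrossEntropyLoss.apply(
    logits, labels, 0.1, 0, logits.dtype == torch.half)
losses.sum().backward()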