OpenDAS / apex · Commits · 438f6f9f

Unverified commit 438f6f9f, authored Jan 17, 2019 by mcarilli, committed by GitHub on Jan 17, 2019

Merge pull request #125 from NVIDIA/nhwc_sbn_pr

[sync BN nhwc]

Parents: 3c7a0e44, a62b87ea
Showing 7 changed files with 934 additions and 137 deletions (+934, -137).
apex/parallel/__init__.py                          +5    -3
apex/parallel/optimized_sync_batchnorm.py          +21   -4
apex/parallel/optimized_sync_batchnorm_kernel.py   +36   -13
csrc/syncbn.cpp                                    +57   -16
csrc/welford.cu                                    +774  -87
tests/synced_batchnorm/single_gpu_unit_test.py     +34   -7
tests/synced_batchnorm/two_gpu_unit_test.py        +7    -7
apex/parallel/__init__.py (view file @ 438f6f9f)

@@ -19,7 +19,7 @@ except ImportError:
         warned_syncbn = True
     from .sync_batchnorm import SyncBatchNorm

-def convert_syncbn_model(module, process_group=None):
+def convert_syncbn_model(module, process_group=None, channel_last=False):
     '''
     Recursively traverse module and its children to replace all
     `torch.nn.modules.batchnorm._BatchNorm` with `apex.parallel.SyncBatchNorm`
...
@@ -38,14 +38,16 @@ def convert_syncbn_model(module, process_group=None):
     '''
     mod = module
     if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
-        mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group)
+        mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last)
         mod.running_mean = module.running_mean
         mod.running_var = module.running_var
         if module.affine:
             mod.weight.data = module.weight.data.clone().detach()
             mod.bias.data = module.bias.data.clone().detach()
     for name, child in module.named_children():
-        mod.add_module(name, convert_syncbn_model(child))
+        mod.add_module(name, convert_syncbn_model(child, process_group=process_group, channel_last=channel_last))
     # TODO(jie) should I delete model explicitly?
     del module
     return mod
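A minimal usage sketch of the converted API (not part of the diff; the toy model and CUDA device are assumptions): convert_syncbn_model now forwards both process_group and channel_last to every _BatchNorm child it replaces.

# Sketch: convert an existing NCHW-defined model so every BatchNorm layer becomes
# apex.parallel.SyncBatchNorm with the new NHWC (channel-last) kernels enabled.
import torch
import apex.parallel

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, 3, padding=1),
    torch.nn.BatchNorm2d(64),
    torch.nn.ReLU(),
).cuda()

# process_group=None falls back to the default (WORLD) group when torch.distributed
# is initialized; channel_last is propagated recursively to every converted child.
model = apex.parallel.convert_syncbn_model(model, process_group=None, channel_last=True)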
apex/parallel/optimized_sync_batchnorm.py (view file @ 438f6f9f)

@@ -38,26 +38,43 @@ class SyncBatchNorm(_BatchNorm):
         process_group: pass in a process group within which the stats of the
             mini-batch is being synchronized. ``None`` for using default process
             group
+        channel_last: a boolean value that when set to ``True``, this module
+            take the last dimension of the input tensor to be the channel
+            dimension. Default: False

     Examples::
+        >>> # channel first tensor
         >>> sbn = apex.parallel.SyncBatchNorm(100).cuda()
         >>> inp = torch.randn(10, 100, 14, 14).cuda()
         >>> out = sbn(inp)
         >>> inp = torch.randn(3, 100, 20).cuda()
         >>> out = sbn(inp)
+        >>> # channel last tensor
+        >>> sbn = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda()
+        >>> inp = torch.randn(10, 14, 14, 100).cuda()
     """

-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False):
         super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
         self.process_group = process_group
+        self.channel_last = channel_last

     def _specify_process_group(self, process_group):
         self.process_group = process_group

+    def _specify_channel_last(self, channel_last):
+        self.channel_last = channel_last

     def forward(self, input):
-        if not self.training and self.track_running_stats:
+        if not self.training and self.track_running_stats and not self.channel_last:
             # fall back to pytorch implementation for inference
             return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
         else:
-            self.num_batches_tracked += 1
-            return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.track_running_stats, self.momentum, self.process_group)
+            exponential_average_factor = 0.0
+            if self.training and self.track_running_stats:
+                self.num_batches_tracked += 1
+                if self.momentum is None:
+                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
+                else:
+                    exponential_average_factor = self.momentum
+            return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last)
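A short training sketch under assumed shapes (it mirrors the channel-last docstring example added above; single process, so no process group is required):

# Sketch: channel-last SyncBatchNorm on an NHWC activation.
import torch
import apex.parallel

sbn = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda()
inp = torch.randn(10, 14, 14, 100, device="cuda", requires_grad=True)  # N, H, W, C
out = sbn(inp)          # dispatches to the *_c_last CUDA kernels in training mode
out.mean().backward()   # gradients flow through SyncBatchnormFunction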
apex/parallel/optimized_sync_batchnorm_kernel.py (view file @ 438f6f9f)

@@ -7,26 +7,40 @@ from apex.parallel import ReduceOp

 class SyncBatchnormFunction(Function):

     @staticmethod
-    def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats=True, momentum=1.0, process_group=None):
+    def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats=True, momentum=1.0, process_group=None, channel_last=False):
         torch.cuda.nvtx.range_push("sync_BN_fw")
         input = input.contiguous()
         world_size = 0

         mean = None
         var_biased = None
         inv_std = None
         var = None
         out = None
         count = None
         if track_running_stats:
-            mean, var, var_biased = syncbn.welford_mean_var(input)
+            if channel_last:
+                count = int(input.numel()/input.size(-1))
+                mean, var_biased = syncbn.welford_mean_var_c_last(input)
+            else:
+                count = int(input.numel()/input.size(1))
+                mean, var_biased = syncbn.welford_mean_var(input)

             if torch.distributed.is_initialized():
                 if not process_group:
                     process_group = torch.distributed.group.WORLD
                 world_size = torch.distributed.get_world_size(process_group)
                 mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
-                var_all = torch.empty(world_size, var.size(0), dtype=var.dtype, device=var.device)
+                var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=var_biased.device)
                 mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
                 var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
                 torch.distributed.all_gather(mean_l, mean, process_group)
                 torch.distributed.all_gather(var_l, var_biased, process_group)
-                mean, var, var_biased = syncbn.welford_parallel(mean_all.transpose(1, 0).contiguous(), var_all.transpose(1, 0).contiguous(), int(input.numel()/input.size(1)))
+                mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count, eps)
                 # TODO(Jie): should do fp32 math instead!
+            else:
+                inv_std = 1.0 / torch.sqrt(var_biased + eps)
+                var = var_biased * (count) / (count-1)

             r_m_inc = mean if running_mean.dtype != torch.float16 else mean.half()
             r_v_inc = var if running_variance.dtype != torch.float16 else var.half()
...
@@ -34,14 +48,17 @@ class SyncBatchnormFunction(Function):
             running_variance.data = running_variance.data * (1-momentum) + momentum*r_v_inc
         else:
             mean = running_mean.data
-            var_biased = running_var.data
+            inv_std = 1.0 / torch.sqrt(running_var.data + eps)

-        ctx.save_for_backward(input, weight, mean, var_biased)
-        ctx.eps = eps
+        ctx.save_for_backward(input, weight, mean, inv_std)
         ctx.process_group = process_group
+        ctx.channel_last = channel_last
         ctx.world_size = world_size

-        out = syncbn.batchnorm_forward(input, mean, var_biased, weight, bias, eps)
+        if channel_last:
+            out = syncbn.batchnorm_forward_c_last(input, mean, inv_std, weight, bias)
+        else:
+            out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)

         torch.cuda.nvtx.range_pop()
         return out
...
@@ -53,14 +70,17 @@ class SyncBatchnormFunction(Function):
         # mini batch mean & var are calculated by forward path.
         # mu = 1./N*np.sum(h, axis = 0)
         # var = 1./N*np.sum((h-mu)**2, axis = 0)
-        saved_input, weight, running_mean, running_variance = ctx.saved_tensors
-        eps = ctx.eps
+        saved_input, weight, mean, inv_std = ctx.saved_tensors
         process_group = ctx.process_group
+        channel_last = ctx.channel_last
         world_size = ctx.world_size
         grad_input = grad_weight = grad_bias = None

         # TODO(jie): why do I have to clone here? life time of grad_output?
-        mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, running_mean, running_variance, weight, eps)
+        if channel_last:
+            mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
+        else:
+            mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)

         # calculate grad_input
         if ctx.needs_input_grad[0]:
...
@@ -72,7 +92,10 @@ class SyncBatchnormFunction(Function):
                 torch.distributed.all_reduce(mean_dy_xmu, ReduceOp.SUM, process_group)
                 mean_dy_xmu = mean_dy_xmu / world_size
-            grad_input = syncbn.batchnorm_backward(grad_output, saved_input, running_mean, running_variance, weight, mean_dy, mean_dy_xmu, eps)
+            if channel_last:
+                grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
+            else:
+                grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)

         if weight is None or not ctx.needs_input_grad[1]:
             grad_weight = None
...
@@ -81,4 +104,4 @@ class SyncBatchnormFunction(Function):
             grad_bias = None

         torch.cuda.nvtx.range_pop()
-        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
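For reference, the per-channel merge rule that syncbn.welford_parallel applies can be written in a few lines of plain Python. This is a hedged sketch (the helper names are illustrative, not part of the extension), useful for checking the fused kernel against a scalar reference:

# Sketch: combine per-rank (mean, biased_var) pairs, each computed over `numel`
# elements, into a global mean, unbiased variance and inv_std (Welford/Chan merge).
import math

def welford_merge(count, mean, m2n, num_new, mean_new, m2n_new):
    factor = 1.0 / max(1, count + num_new)
    delta = mean - mean_new
    merged_mean = (mean_new * num_new + mean * count) * factor
    merged_m2n = m2n + m2n_new + delta * delta * num_new * count * factor
    return count + num_new, merged_mean, merged_m2n

def combine_ranks(means, biased_vars, numel, eps=1e-5):
    count, mean, m2n = 0, 0.0, 0.0
    for m, v in zip(means, biased_vars):
        # a biased variance over numel elements corresponds to M2 = var * numel
        count, mean, m2n = welford_merge(count, mean, m2n, numel, m, v * numel)
    var_unbiased = m2n / (count - 1)
    inv_std = 1.0 / math.sqrt(m2n / count + eps)
    return mean, var_unbiased, inv_std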
csrc/syncbn.cpp (view file @ 438f6f9f)

@@ -3,52 +3,93 @@

 #include <vector>

-// returns {mean, unbiased_var, biased_var}
+// returns {mean, biased_var}
 // implemented using welford
 std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);

 // reduces array of mean/var across processes
-// returns global {mean, unbiased_var, biased_var}
+// returns global {mean, inv_std, biased_var}
 // implemented using welford
-std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased_feature_nodes, int numel);
+std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased_feature_nodes, int numel, const float eps);

 // elementwise BN operation, returns output
 // input/weight/shift should have identical data type;
-// mean/var have promoted data type (dtype==fp16?fp32:dtype)
+// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
 at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
                                   const at::Tensor mean,
-                                  const at::Tensor var,
+                                  const at::Tensor inv_std,
                                   const at::Tensor weight,
-                                  const at::Tensor shift,
-                                  const float eps);
+                                  const at::Tensor shift);

 // backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
 // grad_output/input should have identical data type;
-// mean/var have promoted data type (dtype==fp16?fp32:dtype)
+// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
 // implemented using kahan summation
 std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
                                        const at::Tensor input,
                                        const at::Tensor mean,
-                                       const at::Tensor var,
-                                       const at::Tensor weight,
-                                       const float eps);
+                                       const at::Tensor inv_std,
+                                       const at::Tensor weight);

 // elementwise backward BN operation, returns grad_input
 // grad_output/input/weight precision could be fp16/fp32;
-// mean/var/mean_dy/mean_dy_xmu precision is fp32
+// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
 at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
                                    const at::Tensor input,
                                    const at::Tensor mean,
-                                   const at::Tensor var,
+                                   const at::Tensor inv_std,
                                    const at::Tensor weight,
                                    const at::Tensor mean_dy,
-                                   const at::Tensor mean_dy_xmu,
-                                   const float eps);
+                                   const at::Tensor mean_dy_xmu);

+// returns {mean, biased_var}
+// implemented using welford
+// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
+std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);

+// elementwise BN operation, returns output
+// input/weight/shift should have identical data type;
+// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
+// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
+at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
+                                         const at::Tensor mean,
+                                         const at::Tensor inv_std,
+                                         const at::Tensor weight,
+                                         const at::Tensor shift);

+// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
+// grad_output/input should have identical data type;
+// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
+// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
+std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
+                                              const at::Tensor input,
+                                              const at::Tensor mean,
+                                              const at::Tensor inv_std,
+                                              const at::Tensor weight);

+// elementwise backward BN operation, returns grad_input
+// grad_output/input/weight precision could be fp16/fp32;
+// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
+// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
+at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
+                                          const at::Tensor input,
+                                          const at::Tensor mean,
+                                          const at::Tensor inv_std,
+                                          const at::Tensor weight,
+                                          const at::Tensor mean_dy,
+                                          const at::Tensor mean_dy_xmu);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
   m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
   m.def("batchnorm_forward", &batchnorm_forward_CUDA, "batchnorm forward");
-  m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight gradient");
+  m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight grad");
   m.def("batchnorm_backward", &batchnorm_backward_CUDA, "batchnorm backward dgrad");
+  m.def("welford_mean_var_c_last", &welford_mean_var_c_last_CUDA, "welford mean variance nhwc");
+  m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
+  m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
+  m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
 }
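The bindings above are consumed from Python as the `syncbn` extension module; here is a hedged sketch of the single-process NHWC forward path (the helper name is illustrative, while the call sequence mirrors SyncBatchnormFunction.forward in this PR):

# Sketch: drive the NHWC entry points directly, without autograd.
import torch
import syncbn  # compiled extension defined by the PYBIND11_MODULE above

def nhwc_batchnorm_forward(x, weight, bias, eps=1e-5):
    # x is channel-last (N, H, W, C); stats are reduced over N*H*W per channel
    mean, var_biased = syncbn.welford_mean_var_c_last(x)
    inv_std = 1.0 / torch.sqrt(var_biased + eps)
    return syncbn.batchnorm_forward_c_last(x, mean, inv_std, weight, bias)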
csrc/welford.cu (view file @ 438f6f9f)
...
@@ -71,8 +71,57 @@ __device__ __forceinline__ T reduce_block(T *x, T val)
   return val;
 }

-#define TILE_W 32
-#define MAX_BLOCK_SIZE 1024
+#define ELEMENTS_PER_ITER 4 // enables concurrency within each thread to hide latency
+#define ELEMENTS_PER_THREAD 16
+#define OPTIMAL_TILE_W 32
+#define MAX_H_BLOCK 128
+#define MAX_BLOCK_SIZE 512

+__host__ int div_ru(int x, int y) {
+  return h_last_pow2(1 + (x-1)/y);
+}

+__host__ void flexible_launch_configs(
+      const int reduction,
+      const int stride,
+      dim3 &block,
+      dim3 &grid,
+      const bool coop_flag = false) {
+  int block_x = std::min(h_last_pow2(stride), OPTIMAL_TILE_W);
+  int block_y = std::min(h_last_pow2(div_ru(reduction, ELEMENTS_PER_THREAD)), MAX_BLOCK_SIZE / block_x);
+  if (block_x * block_y != MAX_BLOCK_SIZE) {
+    block_x = std::min(h_last_pow2(stride), MAX_BLOCK_SIZE / block_y);
+  }
+
+  int grid_x = div_ru(stride, block_x);
+  int grid_y = std::min(div_ru(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
+  if (coop_flag) {
+    // it's not worth having a grid reduction if the reduction dimension is not big enough
+    grid_y = grid_y < 8 ? 1 : grid_y;
+  }
+
+  block.x = block_x;
+  block.y = block_y;
+  block.z = 1;
+  grid.x = grid_x;
+  grid.y = grid_y;
+  grid.z = 1;
+}

+template<typename T, typename C>
+__device__ __forceinline__ void welford_merge_element(C& count,
+                                                      T& mean,
+                                                      T& m2n,
+                                                      const C& num_new,
+                                                      const T& mean_new,
+                                                      const T& m2n_new) {
+  T factor = T(1.0) / max(1, (count + num_new));
+  T delta0 = mean - mean_new;
+  mean = (mean_new * num_new + mean * count) * factor;
+  m2n += m2n_new + delta0 * delta0 * num_new * count * factor;
+  count += num_new;
+}

 template<typename T>
 __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
...
@@ -82,11 +131,7 @@ __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
     auto num_new = __shfl_down_sync(0xffffffff, num, i);
     auto mean_new = __shfl_down_sync(0xffffffff, mean, i);
     auto m2n_new = __shfl_down_sync(0xffffffff, m2n, i);
-    T factor = 1.0 / max(1, (num+num_new));
-    auto dif_mean = mean - mean_new;
-    mean = (mean_new * num_new + mean * num) * factor;
-    m2n += m2n_new + dif_mean * dif_mean * num * num_new * factor;
-    num += num_new;
+    welford_merge_element(num, mean, m2n, num_new, mean_new, m2n_new);
   }
 }
...
@@ -148,13 +193,71 @@ __host__ size_t get_element_data_size(const at::Tensor& input, bool accumulation)
   return at::elementSize(scalar_type);
 }

+template<typename T, typename C>
+__device__ __forceinline__ void welford_merge_block_vertical(C& count,
+                                                             T& mean,
+                                                             T& m2n,
+                                                             C* shmem_count,
+                                                             T* shmem_mean,
+                                                             T* shmem_m2n) {
+  // write to shared memory
+  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
+  shmem_mean[address_base] = mean;
+  shmem_m2n[address_base] = m2n;
+  shmem_count[address_base] = count;
+
+#pragma unroll
+  for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
+    __syncthreads();
+    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
+      auto address = address_base + offset * blockDim.x;
+      // read shared memory back to register for reduction
+      auto num_new = shmem_count[address];
+      auto mean_new = shmem_mean[address];
+      auto m2n_new = shmem_m2n[address];
+
+      welford_merge_element(count, mean, m2n, num_new, mean_new, m2n_new);
+
+      // last write is not necessary
+      shmem_mean[address_base] = mean;
+      shmem_m2n[address_base] = m2n;
+      shmem_count[address_base] = count;
+    }
+  }
+}

+template<typename T>
+__device__ __forceinline__ void merge_block_vertical(T& sum_dy,
+                                                     T& sum_dy_xmu,
+                                                     T* shmem_sum_dy,
+                                                     T* shmem_sum_dy_xmu) {
+  // write to shared memory
+  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
+  shmem_sum_dy[address_base] = sum_dy;
+  shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
+
+#pragma unroll
+  for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
+    __syncthreads();
+    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
+      auto address = address_base + offset * blockDim.x;
+      sum_dy += shmem_sum_dy[address];
+      sum_dy_xmu += shmem_sum_dy_xmu[address];
+      // last write is not necessary
+      shmem_sum_dy[address_base] = sum_dy;
+      shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
+    }
+  }
+}

 // welford kernel calculating mean/biased_variance/unbiased_variance
 template <typename scalar_t, typename accscalar_t, typename outscalar_t>
 __global__ void welford_kernel(
       const scalar_t* __restrict__ input,
       outscalar_t* __restrict__ out_mean,
-      outscalar_t* __restrict__ out_var,
       outscalar_t* __restrict__ out_var_biased,
       const int bs,
       const int fs,
...
@@ -185,7 +288,6 @@ __global__ void welford_kernel(
   if (thread_id == 0) {
     out_mean[blockIdx.x] = static_cast<outscalar_t>(x_mean);
-    out_var[blockIdx.x] = static_cast<outscalar_t>(m_2_n/(count-1));
     out_var_biased[blockIdx.x] = static_cast<outscalar_t>(m_2_n/count);
   }
 }
...
@@ -195,15 +297,14 @@ template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
 __global__ void batchnorm_forward_kernel(
       const scalar_t* __restrict__ input,
       const accscalar_t* __restrict__ mean,
-      const accscalar_t* __restrict__ var,
+      const accscalar_t* __restrict__ inv_std,
       const layerscalar_t* __restrict__ weight,
       const layerscalar_t* __restrict__ shift,
       scalar_t* __restrict__ out,
       const int ss,
-      const int bs,
-      const float eps) {
+      const int bs) {
   auto m_c = mean[blockIdx.x];
-  auto inv_std_c = static_cast<accscalar_t>(rsqrt(var[blockIdx.x] + eps));
+  auto inv_std_c = inv_std[blockIdx.x];
   auto w_c = static_cast<accscalar_t>(weight[blockIdx.x]);
   auto s_c = static_cast<accscalar_t>(shift[blockIdx.x]);
...
@@ -224,22 +325,21 @@ __global__ void reduce_bn_kernel(
       const scalar_t* __restrict__ input,
       const scalar_t* __restrict__ grad_output,
       const accscalar_t* __restrict__ mean,
-      const accscalar_t* __restrict__ var,
+      const accscalar_t* __restrict__ inv_std,
       accscalar_t* __restrict__ mean_dy,
       accscalar_t* __restrict__ mean_dy_xmu,
       layerscalar_t* __restrict__ grad_weight,
       layerscalar_t* __restrict__ grad_bias,
       const int bs,
       const int fs,
-      const int ss,
-      const float eps) {
+      const int ss) {
   static __shared__ int s_mem[64];
   int total_item_num = bs * ss;

   int thread_id = threadIdx.y*blockDim.x + threadIdx.x;

   auto r_mean = mean[blockIdx.x];
-  auto factor = accscalar_t(1.0) / (accscalar_t)sqrt(var[blockIdx.x] + eps);
+  auto factor = inv_std[blockIdx.x];

   // Kahan sum
   accscalar_t sum_dy = 0.0;
...
@@ -283,64 +383,435 @@ __global__ void batchnorm_backward_kernel(
       const scalar_t* __restrict__ grad_output,
       const scalar_t* __restrict__ input,
       const accscalar_t* __restrict__ mean,
-      const accscalar_t* __restrict__ var,
+      const accscalar_t* __restrict__ inv_std,
       const layerscalar_t* __restrict__ weight,
       const accscalar_t* __restrict__ mean_dy,
       const accscalar_t* __restrict__ mean_dy_xmu,
       scalar_t* __restrict__ grad_input,
       const int ss,
-      const int bs,
-      const float eps) {
+      const int bs) {
   auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
   auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
-  auto factor_1_c = static_cast<accscalar_t>(var[blockIdx.x]) + eps;
-  auto factor_2_c = static_cast<accscalar_t>(weight[blockIdx.x]) / sqrt(factor_1_c);
-  factor_1_c /= static_cast<accscalar_t>(mean_dy_xmu[blockIdx.x]);
+  auto factor_1_c = inv_std[blockIdx.x];
+  auto factor_2_c = static_cast<accscalar_t>(weight[blockIdx.x]) * factor_1_c;
+  factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];

   for (int batch_offset = blockIdx.y*blockDim.y + threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
     int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
     for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset += gridDim.z*blockDim.x) {
-      grad_input[address_base+offset] = (static_cast<accscalar_t>(grad_output[address_base+offset]) - m_dy_c - (static_cast<accscalar_t>(input[address_base+offset]) - m_c) / factor_1_c) * factor_2_c;
+      grad_input[address_base+offset] = (static_cast<accscalar_t>(grad_output[address_base+offset]) - m_dy_c - (static_cast<accscalar_t>(input[address_base+offset]) - m_c) * factor_1_c) * factor_2_c;
     }
   }
 }

-// parallel welford kernel to further reduce mean / biased_var / unbiased_var
-// across multiple processes.
-template <typename scalar_t, typename accscalar_t>
-__global__ void welford_kernel_parallel(
-      const scalar_t* __restrict__ mean,
-      const scalar_t* __restrict__ var_biased,
-      scalar_t* __restrict__ out_mean,
-      scalar_t* __restrict__ out_var,
-      scalar_t* __restrict__ out_var_biased,
-      const int ns,
-      const int fs,
-      const int numel) {
-  static __shared__ int s_mem[160];
-  int block_size = blockDim.x;
-  accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
-  int input_base = blockIdx.x*ns + threadIdx.x;
-  int thread_id = threadIdx.x;
-  // load data;
-  auto x_mean = static_cast<accscalar_t>(mean[input_base]);
-  auto m_2_n = static_cast<accscalar_t>(var_biased[input_base]) * numel;
-  auto count = numel;
-  welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id);
-  if (thread_id == 0) {
-    out_mean[blockIdx.x] = static_cast<scalar_t>(x_mean);
-    out_var[blockIdx.x] = static_cast<scalar_t>(m_2_n/(count-1));
-    out_var_biased[blockIdx.x] = static_cast<scalar_t>(m_2_n/count);
-  }
-}

+// welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance
+template
+   <typename scalar_t,
+    typename accscalar_t,
+    typename outscalar_t,
+    int PARALLEL_LOADS>
+__global__ void welford_kernel_c_last(
+      const scalar_t* __restrict__ input,
+      outscalar_t* __restrict__ out_mean,
+      outscalar_t* __restrict__ out_var_biased,
+      volatile accscalar_t* staging_data,
+      int* semaphores,
+      const int reduction_size,
+      const int stride) {
+  // hide latency with concurrency
+  accscalar_t x_mean[PARALLEL_LOADS];
+  accscalar_t m_2_n[PARALLEL_LOADS];
+  int count[PARALLEL_LOADS];
+
+#pragma unroll
+  for (int i = 0; i < PARALLEL_LOADS; i++) {
+    x_mean[i] = accscalar_t(0);
+    m_2_n[i] = accscalar_t(0);
+    count[i] = accscalar_t(0);
+  }
+  // tensor dimension (m,c)
+
+  // loop along m dimension
+  int inner_loop_stride = blockDim.y * gridDim.y;
+
+  // offset along m dimension
+  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
+  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
+
+  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
+  int address_base = m_offset * stride + c_offset;
+  int address_increment = inner_loop_stride * stride;
+
+  for (int i = 0; i < loop_count; i++) {
+    accscalar_t x_math[PARALLEL_LOADS];
+    accscalar_t x_count_inv[PARALLEL_LOADS];
+    accscalar_t is_valid[PARALLEL_LOADS];
+
+    // load multiple data in
+#pragma unroll
+    for (int j = 0; j < PARALLEL_LOADS; j++) {
+      if (c_offset < stride && m_offset < reduction_size) {
+        x_math[j] = input[address_base];
+        count[j]++;
+        x_count_inv[j] = accscalar_t(1) / count[j];
+        is_valid[j] = accscalar_t(1);
+      } else {
+        x_math[j] = accscalar_t(0);
+        x_count_inv[j] = accscalar_t(0);
+        is_valid[j] = accscalar_t(0);
+      }
+      m_offset += inner_loop_stride;
+      address_base += address_increment;
+    }
+
+    // calculate mean/m2n with welford
+#pragma unroll
+    for (int j = 0; j < PARALLEL_LOADS; j++) {
+      accscalar_t delta0 = x_math[j] - x_mean[j];
+      x_mean[j] += delta0 * x_count_inv[j];
+      accscalar_t delta1 = x_math[j] - x_mean[j];
+      m_2_n[j] += delta0 * delta1 * is_valid[j];
+    }
+  }
+
+  // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS
+#pragma unroll
+  for (int j = 1; j < PARALLEL_LOADS; j++) {
+    welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);
+  }
+
+  // release x_mean / m_2_n
+  auto mean_th = x_mean[0];
+  auto m2_th = m_2_n[0];
+  auto count_th = count[0];
+
+  // block-wise reduction with shared memory (since reduction cannot be done within a warp)
+  static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];
+  static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];
+  static __shared__ int shmem_count[MAX_BLOCK_SIZE];
+
+  welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
+
+  // grid reduction if needed (coop launch used at the first place)
+  if (gridDim.y > 1) {
+    volatile accscalar_t* staging_mean = staging_data;
+    volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y];
+    volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride*gridDim.y]);
+
+    address_base = c_offset + blockIdx.y * stride;
+    // write data to staging_data;
+    if (threadIdx.y == 0 && c_offset < stride) {
+      staging_mean[address_base] = mean_th;
+      staging_m2n[address_base] = m2_th;
+      staging_count[address_base] = count_th;
+    }
+
+    __threadfence();
+    __shared__ bool is_last_block_done;
+    // mark block done
+    if (threadIdx.x == 0 && threadIdx.y == 0) {
+      int old = atomicAdd(&semaphores[blockIdx.x], 1);
+      is_last_block_done = (old == (gridDim.y-1));
+    }
+
+    __syncthreads();
+
+    // check that all data is now available in global memory
+    if (is_last_block_done) {
+      count_th = 0;
+      mean_th = accscalar_t(0.0);
+      m2_th = accscalar_t(0.0);
+
+      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
+        address_base = c_offset + y * stride;
+        int num_new = c_offset < stride ? staging_count[address_base] : 0;
+        accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);
+        accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0);
+
+        welford_merge_element(count_th, mean_th, m2_th, num_new, mean_new, m2n_new);
+      }
+
+      welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
+      if (threadIdx.y == 0 && c_offset < stride) {
+        out_mean[c_offset] = static_cast<outscalar_t>(mean_th);
+        out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);
+      }
+    }
+  } else {
+    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
+      out_mean[c_offset] = static_cast<outscalar_t>(mean_th);
+      out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);
+    }
+  }
+}

+// parallel welford kernel to further reduce mean / biased_var
+// into mean / unbiased_var / inv_std across multiple processes.
+template <typename scalar_t>
+__global__ void welford_kernel_parallel(
+      const scalar_t* __restrict__ mean,
+      const scalar_t* __restrict__ var_biased,
+      scalar_t* __restrict__ out_mean,
+      scalar_t* __restrict__ out_var,
+      scalar_t* __restrict__ inv_std,
+      const int world_size,
+      const int feature_size,
+      const float eps,
+      const int numel) {
+
+  for (int i = blockIdx.x*blockDim.x + threadIdx.x; i < feature_size; i += gridDim.x*blockDim.x) {
+    // load data;
+    int address = i;
+    scalar_t x_mean = 0;
+    scalar_t m_2_n = 0;
+    int count = 0;
+    for (int j = 0; j < world_size; j++) {
+      welford_merge_element(count, x_mean, m_2_n, numel, mean[address], var_biased[address]*numel);
+      address += feature_size;
+    }
+    out_mean[i] = x_mean;
+    out_var[i] = m_2_n / (count - 1);
+    inv_std[i] = scalar_t(1) / sqrt(m_2_n/count + eps);
+  }
+}

+// elementwise BN kernel
+template <
+    typename scalar_t,
+    typename accscalar_t,
+    typename layerscalar_t,
+    int PARALLEL_LOADS>
+__global__ void batchnorm_forward_c_last_kernel(
+      const scalar_t* __restrict__ input,
+      const accscalar_t* __restrict__ mean,
+      const accscalar_t* __restrict__ inv_std,
+      const layerscalar_t* __restrict__ weight,
+      const layerscalar_t* __restrict__ shift,
+      scalar_t* __restrict__ out,
+      const int reduction_size,
+      const int stride) {
+  // tensor dimension (m,c)
+  // loop along m dimension
+  int inner_loop_stride = blockDim.y * gridDim.y;
+
+  // offset along m dimension
+  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
+  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
+
+  auto m_c = mean[c_offset];
+  auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
+  auto w_c = static_cast<accscalar_t>(weight[c_offset]);
+  auto s_c = static_cast<accscalar_t>(shift[c_offset]);
+
+  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
+  int address_base = m_offset * stride + c_offset;
+  int address_increment = inner_loop_stride * stride;
+
+  for (int i = 0; i < loop_count; i++) {
+#pragma unroll
+    for (int j = 0; j < PARALLEL_LOADS; j++) {
+      if (c_offset < stride && m_offset < reduction_size) {
+        out[address_base] = static_cast<scalar_t>(
+            w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c);
+      }
+      m_offset += inner_loop_stride;
+      address_base += address_increment;
+    }
+  }
+}

+// batchnorm backward kernel for c last tensor
+template
+   <typename scalar_t,
+    typename accscalar_t,
+    typename layerscalar_t,
+    int PARALLEL_LOADS>
+__global__ void reduce_bn_c_last_kernel(
+      const scalar_t* __restrict__ input,
+      const scalar_t* __restrict__ grad_output,
+      const accscalar_t* __restrict__ mean,
+      const accscalar_t* __restrict__ inv_std,
+      accscalar_t* __restrict__ mean_dy,
+      accscalar_t* __restrict__ mean_dy_xmu,
+      layerscalar_t* __restrict__ grad_weight,
+      layerscalar_t* __restrict__ grad_bias,
+      volatile accscalar_t* staging_data,
+      int* semaphores,
+      const int reduction_size,
+      const int stride) {
+
+  // hide latency with concurrency
+  accscalar_t sum_dy[PARALLEL_LOADS];
+  accscalar_t sum_dy_xmu[PARALLEL_LOADS];
+
+#pragma unroll
+  for (int i = 0; i < PARALLEL_LOADS; i++) {
+    sum_dy[i] = accscalar_t(0);
+    sum_dy_xmu[i] = accscalar_t(0);
+  }
+  // tensor dimension (m,c)
+
+  // loop along m dimension
+  int inner_loop_stride = blockDim.y * gridDim.y;
+
+  // offset along m dimension
+  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
+  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
+
+  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
+  int address_base = m_offset * stride + c_offset;
+  int address_increment = inner_loop_stride * stride;
+
+  auto r_mean = mean[c_offset];
+  auto factor = inv_std[c_offset];
+
+  for (int i = 0; i < loop_count; i++) {
+    accscalar_t x_input[PARALLEL_LOADS];
+    accscalar_t x_grad_output[PARALLEL_LOADS];
+
+    // load multiple data in
+#pragma unroll
+    for (int j = 0; j < PARALLEL_LOADS; j++) {
+      if (c_offset < stride && m_offset < reduction_size) {
+        x_input[j] = input[address_base];
+        x_grad_output[j] = grad_output[address_base];
+      } else {
+        x_input[j] = accscalar_t(0);
+        x_grad_output[j] = accscalar_t(0);
+      }
+      m_offset += inner_loop_stride;
+      address_base += address_increment;
+    }
+
+    __syncthreads();
+
+    // calculate sum_dy / sum_dy_xmu
+#pragma unroll
+    for (int j = 0; j < PARALLEL_LOADS; j++) {
+      sum_dy[j] += x_grad_output[j];
+      sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);
+    }
+  }
+
+  // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS
+#pragma unroll
+  for (int j = 1; j < PARALLEL_LOADS; j++) {
+    sum_dy[0] += sum_dy[j];
+    sum_dy_xmu[0] += sum_dy_xmu[j];
+  }
+
+  // release array of registers
+  auto sum_dy_th = sum_dy[0];
+  auto sum_dy_xmu_th = sum_dy_xmu[0];
+
+  // block-wise reduction with shared memory (since reduction cannot be done within a warp)
+  static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];
+  static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];
+
+  merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
+
+  // grid reduction if needed (coop launch used at the first place)
+  if (gridDim.y > 1) {
+    volatile accscalar_t* staging_sum_dy = staging_data;
+    volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y];
+
+    address_base = c_offset + blockIdx.y * stride;
+    // write data to staging_data;
+    if (threadIdx.y == 0 && c_offset < stride) {
+      staging_sum_dy[address_base] = sum_dy_th;
+      staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;
+    }
+
+    __threadfence();
+    __shared__ bool is_last_block_done;
+    // mark block done
+    if (threadIdx.x == 0 && threadIdx.y == 0) {
+      int old = atomicAdd(&semaphores[blockIdx.x], 1);
+      is_last_block_done = (old == (gridDim.y-1));
+    }
+
+    __syncthreads();
+
+    // check that all data is now available in global memory
+    if (is_last_block_done) {
+      sum_dy_th = accscalar_t(0.0);
+      sum_dy_xmu_th = accscalar_t(0.0);
+
+      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
+        address_base = c_offset + y * stride;
+        sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));
+        sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0));
+      }
+
+      merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
+      if (threadIdx.y == 0 && c_offset < stride) {
+        grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
+        grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
+        mean_dy[c_offset] = sum_dy_th / reduction_size;
+        mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
+      }
+    }
+  } else {
+    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
+      grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
+      grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
+      mean_dy[c_offset] = sum_dy_th / reduction_size;
+      mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
+    }
+  }
+}

+// elementwise BN kernel
+template <
+    typename scalar_t,
+    typename accscalar_t,
+    typename layerscalar_t,
+    int PARALLEL_LOADS>
+__global__ void batchnorm_backward_c_last_kernel(
+      const scalar_t* __restrict__ grad_output,
+      const scalar_t* __restrict__ input,
+      const accscalar_t* __restrict__ mean,
+      const accscalar_t* __restrict__ inv_std,
+      const layerscalar_t* __restrict__ weight,
+      const accscalar_t* __restrict__ mean_dy,
+      const accscalar_t* __restrict__ mean_dy_xmu,
+      scalar_t* __restrict__ grad_input,
+      const int reduction_size,
+      const int stride) {
+  // tensor dimension (m,c)
+  // loop along m dimension
+  int inner_loop_stride = blockDim.y * gridDim.y;
+
+  // offset along m dimension
+  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
+  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
+
+  auto m_c = mean[c_offset];
+  auto m_dy_c = mean_dy[c_offset];
+  auto factor_1_c = inv_std[c_offset];
+  auto factor_2_c = static_cast<accscalar_t>(weight[c_offset]) * factor_1_c;
+  factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[c_offset];
+
+  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
+  int address_base = m_offset * stride + c_offset;
+  int address_increment = inner_loop_stride * stride;
+
+  for (int i = 0; i < loop_count; i++) {
+#pragma unroll
+    for (int j = 0; j < PARALLEL_LOADS; j++) {
+      if (c_offset < stride && m_offset < reduction_size) {
+        grad_input[address_base] = static_cast<scalar_t>(
+            (static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c -
+            (static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c)
+            * factor_2_c);
+      }
+      m_offset += inner_loop_stride;
+      address_base += address_increment;
+    }
+  }
+}

 std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
   const auto batch_size = input.size(0);
...
@@ -349,7 +820,6 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
   auto space_size = get_tensor_spatial_size(input);
   auto scalar_type = promote_scalartype(input);

-  at::Tensor out_var = at::empty({feature_size}, input.options().dtype(scalar_type));
   at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));
   at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));
...
@@ -358,7 +828,6 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
   const dim3 block(block_x, block_y);
   const dim3 grid(feature_size);
   // shared memory used for reduce on mean, var, num_elements;
   auto stream = at::cuda::getCurrentCUDAStream();

   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_kernel", ([&] {
...
@@ -366,23 +835,21 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
     welford_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
         input.data<scalar_t>(),
         out_mean.data<accscalar_t>(),
-        out_var.data<accscalar_t>(),
         out_var_biased.data<accscalar_t>(),
         batch_size,
         feature_size,
         space_size);
   }));

-  return {out_mean, out_var, out_var_biased};
+  return {out_mean, out_var_biased};
 }

 at::Tensor batchnorm_forward_CUDA(
     const at::Tensor input,
     const at::Tensor mean,
-    const at::Tensor var,
+    const at::Tensor inv_std,
     const at::Tensor weight,
-    const at::Tensor shift,
-    const float eps) {
+    const at::Tensor shift) {
   const auto batch_size = input.size(0);
   const auto feature_size = input.size(1);

   at::Tensor out = at::empty_like(input);
...
@@ -403,13 +870,12 @@ at::Tensor batchnorm_forward_CUDA(
       batchnorm_forward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
           input.data<scalar_t>(),
           mean.data<accscalar_t>(),
-          var.data<accscalar_t>(),
+          inv_std.data<accscalar_t>(),
           weight.data<accscalar_t>(),
           shift.data<accscalar_t>(),
           out.data<scalar_t>(),
           space_size,
-          batch_size,
-          eps);
+          batch_size);
     }));
   } else {
     AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
...
@@ -418,13 +884,12 @@ at::Tensor batchnorm_forward_CUDA(
       batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
           input.data<scalar_t>(),
           mean.data<accscalar_t>(),
-          var.data<accscalar_t>(),
+          inv_std.data<accscalar_t>(),
           weight.data<scalar_t>(),
           shift.data<scalar_t>(),
           out.data<scalar_t>(),
           space_size,
-          batch_size,
-          eps);
+          batch_size);
     }));
   }

   return out;
...
@@ -434,9 +899,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
     const at::Tensor grad_output,
     const at::Tensor input,
     const at::Tensor mean,
-    const at::Tensor var,
-    const at::Tensor weight,
-    const float eps)
+    const at::Tensor inv_std,
+    const at::Tensor weight)
 {
   const auto batch_size = input.size(0);
   const auto feature_size = input.size(1);
...
@@ -463,15 +927,14 @@ std::vector<at::Tensor> reduce_bn_CUDA(
           input.data<scalar_t>(),
           grad_output.data<scalar_t>(),
           mean.data<accscalar_t>(),
-          var.data<accscalar_t>(),
+          inv_std.data<accscalar_t>(),
           mean_dy.data<accscalar_t>(),
           mean_dy_xmu.data<accscalar_t>(),
           grad_weight.data<accscalar_t>(),
           grad_bias.data<accscalar_t>(),
           batch_size,
           feature_size,
-          space_size,
-          eps);
+          space_size);
     }));
   } else {
     AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
...
@@ -481,15 +944,14 @@ std::vector<at::Tensor> reduce_bn_CUDA(
           input.data<scalar_t>(),
           grad_output.data<scalar_t>(),
           mean.data<accscalar_t>(),
-          var.data<accscalar_t>(),
+          inv_std.data<accscalar_t>(),
           mean_dy.data<accscalar_t>(),
           mean_dy_xmu.data<accscalar_t>(),
           grad_weight.data<scalar_t>(),
           grad_bias.data<scalar_t>(),
           batch_size,
           feature_size,
-          space_size,
-          eps);
+          space_size);
     }));
   }
...
@@ -500,11 +962,10 @@ at::Tensor batchnorm_backward_CUDA(
     const at::Tensor grad_output,
     const at::Tensor input,
     const at::Tensor mean,
-    const at::Tensor var,
+    const at::Tensor inv_std,
     const at::Tensor weight,
     const at::Tensor mean_dy,
-    const at::Tensor mean_dy_xmu,
-    const float eps) {
+    const at::Tensor mean_dy_xmu) {
   const auto batch_size = input.size(0);
   const auto feature_size = input.size(1);
...
@@ -528,14 +989,13 @@ at::Tensor batchnorm_backward_CUDA(
           grad_output.data<scalar_t>(),
           input.data<scalar_t>(),
           mean.data<accscalar_t>(),
-          var.data<accscalar_t>(),
+          inv_std.data<accscalar_t>(),
           weight.data<accscalar_t>(),
           mean_dy.data<accscalar_t>(),
           mean_dy_xmu.data<accscalar_t>(),
           grad_input.data<scalar_t>(),
           space_size,
-          batch_size,
-          eps);
+          batch_size);
     }));
   } else {
     AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
...
@@ -545,46 +1005,273 @@ at::Tensor batchnorm_backward_CUDA(
           grad_output.data<scalar_t>(),
           input.data<scalar_t>(),
           mean.data<accscalar_t>(),
-          var.data<accscalar_t>(),
+          inv_std.data<accscalar_t>(),
           weight.data<scalar_t>(),
           mean_dy.data<accscalar_t>(),
           mean_dy_xmu.data<accscalar_t>(),
           grad_input.data<scalar_t>(),
           space_size,
-          batch_size,
-          eps);
+          batch_size);
     }));
   }

   return grad_input;
 }

-std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased, int numel) {
-  const auto feature_size = mean_feature_nodes.size(0);
-  const auto world_size = mean_feature_nodes.size(1);
+std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased, int numel, const float eps) {
+  const auto world_size = mean_feature_nodes.size(0);
+  const auto feature_size = mean_feature_nodes.size(1);

   at::Tensor out_var = at::empty({feature_size}, var_biased.options());
-  at::Tensor out_var_biased = at::empty_like(out_var);
+  at::Tensor inv_std = at::empty_like(out_var);
   at::Tensor out_mean = at::empty_like(out_var);

-  // TODO(jie): tile this for memory coalescing!
-  const dim3 block(world_size);
-  const dim3 grid(feature_size);
-  // shared memory used for reduce on mean, var, num_elements;
+  const int block = std::min(h_last_pow2(feature_size), MAX_BLOCK_SIZE);
+  const int grid = std::max<int>(1, feature_size / block);

   auto stream = at::cuda::getCurrentCUDAStream();

   AT_DISPATCH_FLOATING_TYPES_AND_HALF(mean_feature_nodes.type(), "welford_parallel_kernel", ([&] {
-    using accscalar_t = at::acc_type<scalar_t, true>;
-    welford_kernel_parallel<scalar_t, accscalar_t><<<grid, block, 0, stream>>>(
+    welford_kernel_parallel<scalar_t><<<grid, block, 0, stream>>>(
         mean_feature_nodes.data<scalar_t>(),
         var_biased.data<scalar_t>(),
         out_mean.data<scalar_t>(),
         out_var.data<scalar_t>(),
-        out_var_biased.data<scalar_t>(),
+        inv_std.data<scalar_t>(),
         world_size,
         feature_size,
+        eps,
         numel);
   }));

-  return {out_mean, out_var, out_var_biased};
+  return {out_mean, out_var, inv_std};
 }

+std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
+  const auto stride = input.size(input.ndimension()-1);
+  const auto reduction_size = input.numel() / stride;
+
+  auto scalar_type = promote_scalartype(input);
+  auto option = input.options().dtype(scalar_type);
+
+  at::Tensor out_var_biased = at::empty({stride}, option);
+  at::Tensor out_mean = at::empty({stride}, option);
+
+  dim3 block;
+  dim3 grid;
+  flexible_launch_configs(reduction_size, stride, block, grid, true);
+
+  at::Tensor staging_data;
+  at::Tensor semaphores;
+  if (grid.y > 1) {
+    staging_data = at::empty({4*stride*grid.y}, option);
+    semaphores = at::zeros({grid.x}, input.options().dtype(at::ScalarType::Int));
+  }
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_c_last", ([&] {
+    using accscalar_t = at::acc_type<scalar_t, true>;
+    accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
+    int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
+    welford_kernel_c_last<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
+        <<<grid, block, 0, stream>>>(
+        input.data<scalar_t>(),
+        out_mean.data<accscalar_t>(),
+        out_var_biased.data<accscalar_t>(),
+        staging_data_ptr,
+        semaphores_ptr,
+        reduction_size,
+        stride);
+  }));
+
+  return {out_mean, out_var_biased};
+}

+at::Tensor batchnorm_forward_c_last_CUDA(
+    const at::Tensor input,
+    const at::Tensor mean,
+    const at::Tensor inv_std,
+    const at::Tensor weight,
+    const at::Tensor shift) {
+  const auto stride = input.size(input.ndimension()-1);
+  const auto reduction_size = input.numel() / stride;
+
+  at::Tensor out = at::empty_like(input);
+
+  dim3 block;
+  dim3 grid;
+  flexible_launch_configs(reduction_size, stride, block, grid);
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  if (input.type().scalarType() == at::ScalarType::Half
+      && weight.type().scalarType() == at::ScalarType::Float) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
+          <<<grid, block, 0, stream>>>(
+          input.data<scalar_t>(),
+          mean.data<accscalar_t>(),
+          inv_std.data<accscalar_t>(),
+          weight.data<accscalar_t>(),
+          shift.data<accscalar_t>(),
+          out.data<scalar_t>(),
+          reduction_size,
+          stride);
+    }));
+  } else {
+    AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
+          <<<grid, block, 0, stream>>>(
+          input.data<scalar_t>(),
+          mean.data<accscalar_t>(),
+          inv_std.data<accscalar_t>(),
+          weight.data<scalar_t>(),
+          shift.data<scalar_t>(),
+          out.data<scalar_t>(),
+          reduction_size,
+          stride);
+    }));
+  }
+  return out;
+}

+std::vector<at::Tensor> reduce_bn_c_last_CUDA(
+    const at::Tensor grad_output,
+    const at::Tensor input,
+    const at::Tensor mean,
+    const at::Tensor inv_std,
+    const at::Tensor weight) {
+  const auto stride = input.size(input.ndimension()-1);
+  const auto reduction_size = input.numel() / stride;
+
+  at::Tensor mean_dy = at::empty({stride}, mean.options());
+  at::Tensor mean_dy_xmu = at::empty({stride}, mean.options());
+  at::Tensor grad_weight = at::empty({stride}, weight.options());
+  at::Tensor grad_bias = at::empty({stride}, weight.options());
+
+  dim3 block;
+  dim3 grid;
+  flexible_launch_configs(reduction_size, stride, block, grid, true);
+
+  at::Tensor staging_data;
+  at::Tensor semaphores;
+  if (grid.y > 1) {
+    staging_data = at::empty({2*stride*grid.y}, mean.options());
+    semaphores = at::zeros({grid.x}, input.options().dtype(at::ScalarType::Int));
+  }
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  if (input.type().scalarType() == at::ScalarType::Half
+      && weight.type().scalarType() == at::ScalarType::Float) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
+      int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
+      reduce_bn_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
+          <<<grid, block, 0, stream>>>(
+          input.data<scalar_t>(),
+          grad_output.data<scalar_t>(),
+          mean.data<accscalar_t>(),
+          inv_std.data<accscalar_t>(),
+          mean_dy.data<accscalar_t>(),
+          mean_dy_xmu.data<accscalar_t>(),
+          grad_weight.data<accscalar_t>(),
+          grad_bias.data<accscalar_t>(),
+          staging_data_ptr,
+          semaphores_ptr,
+          reduction_size,
+          stride);
+    }));
+  } else {
+    AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
+      int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
+      reduce_bn_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
+          <<<grid, block, 0, stream>>>(
+          input.data<scalar_t>(),
+          grad_output.data<scalar_t>(),
+          mean.data<accscalar_t>(),
+          inv_std.data<accscalar_t>(),
+          mean_dy.data<accscalar_t>(),
+          mean_dy_xmu.data<accscalar_t>(),
+          grad_weight.data<scalar_t>(),
+          grad_bias.data<scalar_t>(),
+          staging_data_ptr,
+          semaphores_ptr,
+          reduction_size,
+          stride);
+    }));
+  }
+
+  return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
+}

+at::Tensor batchnorm_backward_c_last_CUDA(
+    const at::Tensor grad_output,
+    const at::Tensor input,
+    const at::Tensor mean,
+    const at::Tensor inv_std,
+    const at::Tensor weight,
+    const at::Tensor mean_dy,
+    const at::Tensor mean_dy_xmu) {
+  const auto stride = input.size(input.ndimension()-1);
+  const auto reduction_size = input.numel() / stride;
+
+  at::Tensor grad_input = at::empty_like(input);
+
+  dim3 block;
+  dim3 grid;
+  flexible_launch_configs(reduction_size, stride, block, grid);
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  if (input.type().scalarType() == at::ScalarType::Half
+      && weight.type().scalarType() == at::ScalarType::Float) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
+          <<<grid, block, 0, stream>>>(
+          grad_output.data<scalar_t>(),
+          input.data<scalar_t>(),
+          mean.data<accscalar_t>(),
+          inv_std.data<accscalar_t>(),
+          weight.data<accscalar_t>(),
+          mean_dy.data<accscalar_t>(),
+          mean_dy_xmu.data<accscalar_t>(),
+          grad_input.data<scalar_t>(),
+          reduction_size,
+          stride);
+    }));
+  } else {
+    AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
+          <<<grid, block, 0, stream>>>(
+          grad_output.data<scalar_t>(),
+          input.data<scalar_t>(),
+          mean.data<accscalar_t>(),
+          inv_std.data<accscalar_t>(),
+          weight.data<scalar_t>(),
+          mean_dy.data<accscalar_t>(),
+          mean_dy_xmu.data<accscalar_t>(),
+          grad_input.data<scalar_t>(),
+          reduction_size,
+          stride);
+    }));
+  }
+
+  return grad_input;
+}
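A hedged sketch of the launch-config arithmetic added in flexible_launch_configs, redone in Python for one concrete NHWC case (h_last_pow2 is assumed to return the largest power of two not exceeding x, as elsewhere in this file; the example shapes are illustrative):

# Sketch: compute (block, grid) the way flexible_launch_configs does.
OPTIMAL_TILE_W, MAX_BLOCK_SIZE, MAX_H_BLOCK, ELEMENTS_PER_THREAD = 32, 512, 128, 16

def h_last_pow2(x):
    p = 1
    while p * 2 <= x:
        p *= 2
    return p

def div_ru(x, y):
    return h_last_pow2(1 + (x - 1) // y)

def flexible_launch_configs(reduction, stride, coop_flag=False):
    block_x = min(h_last_pow2(stride), OPTIMAL_TILE_W)
    block_y = min(h_last_pow2(div_ru(reduction, ELEMENTS_PER_THREAD)), MAX_BLOCK_SIZE // block_x)
    if block_x * block_y != MAX_BLOCK_SIZE:
        block_x = min(h_last_pow2(stride), MAX_BLOCK_SIZE // block_y)
    grid_x = div_ru(stride, block_x)
    grid_y = min(div_ru(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK)
    if coop_flag and grid_y < 8:
        grid_y = 1  # a grid reduction is not worth it for small reductions
    return (block_x, block_y), (grid_x, grid_y)

# e.g. a (32, 56, 56, 64) NHWC activation: reduction = 32*56*56, stride = 64
print(flexible_launch_configs(32 * 56 * 56, 64, coop_flag=True))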
tests/synced_batchnorm/single_gpu_unit_test.py (view file @ 438f6f9f)

@@ -54,7 +54,11 @@ m = inp_r.mean(1)
 b_v = inp_r.var(1, unbiased=False)
 unb_v = inp_r.var(1, unbiased=True)

-mean, var, var_biased = syncbn.welford_mean_var(inp_t)
+eps = 1e-5
+
+#mean, var, var_biased = syncbn.welford_mean_var(inp_t)
+mean, var_biased = syncbn.welford_mean_var(inp_t)
+inv_std = 1.0 / torch.sqrt(var_biased + eps)

 bn = torch.nn.BatchNorm2d(feature_size).cuda()
 bn.momentum = 1.0
...
@@ -74,16 +78,25 @@ grad_sbn = grad_output_t.clone().detach()
 out_sbn = sbn(inp_sbn)
 out_sbn.backward(grad_sbn)

+sbn_c_last = apex.parallel.SyncBatchNorm(feature_size, channel_last=True).cuda()
+sbn_c_last.momentum = 1.0
+sbn_c_last.weight.data = weight_t.clone()
+sbn_c_last.bias.data = bias_t.clone()
+inp_sbn_c_last = inp_t.clone().transpose(-1, 1).contiguous().requires_grad_()
+grad_sbn_c_last = grad_output_t.clone().transpose(-1, 1).contiguous().detach()
+out_sbn_c_last = sbn_c_last(inp_sbn_c_last)
+out_sbn_c_last.backward(grad_sbn_c_last)

 sbn_result = True
+sbn_result_c_last = True
 bn_result = True

 sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
-sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
+#sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
 sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result

-eps = 1e-5

-out = syncbn.batchnorm_forward(inp_t, mean, var_biased, weight_t, bias_t, eps)
+out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
 out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) + bias_r

 sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result
...
@@ -102,8 +115,8 @@ mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).co
 grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1, 1, 1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) * weight_r.view(-1, 1, 1)

-mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, var_biased, weight_t, eps)
-grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, var_biased, weight_t, mean_dy, mean_dy_xmu, eps)
+mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
+grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)

 sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
 sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
 sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
...
@@ -112,7 +125,7 @@ sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error)
 compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
 sbn_result = compare("comparing sbn input grad: ", inp_sbn.grad, grad_input_r, error) and sbn_result

-compare("comparing output: ", out_bn, out_sbn, error)
+compare("comparing bn/sbn output: ", out_bn, out_sbn, error)
 sbn_result = compare("comparing running_mean: ", bn.running_mean.data, sbn.running_mean.data, error) and sbn_result
 sbn_result = compare("comparing running_variance: ", bn.running_var.data, sbn.running_var.data, error) and sbn_result
 compare("comparing grad_input: ", inp_bn.grad, inp_sbn.grad, error)
...
@@ -123,7 +136,21 @@ compare("comparing grad_weight: ", bn.weight.grad, sbn.weight.grad, error)
 compare("comparing grad_weight bn to ref: ", bn.weight.grad, grad_weight_r, error)
 sbn_result = compare("comparing grad_weight sbn to ref: ", sbn.weight.grad, grad_weight_r, error) and sbn_result

+compare("comparing channel last bn/sbn output: ", out_bn, out_sbn_c_last.transpose(-1, 1).contiguous(), error)
+sbn_result_c_last = compare("comparing channel last running_mean: ", bn.running_mean.data, sbn_c_last.running_mean.data, error) and sbn_result_c_last
+sbn_result_c_last = compare("comparing channel last running_variance: ", bn.running_var.data, sbn_c_last.running_var.data, error) and sbn_result_c_last
+compare("comparing channel last grad_input: ", inp_bn.grad, inp_sbn_c_last.grad.transpose(-1, 1).contiguous(), error)
+compare("comparing channel last grad_bias: ", bn.bias.grad, sbn_c_last.bias.grad, error)
+sbn_result_c_last = compare("comparing channel last grad_bias sbn to ref: ", sbn_c_last.bias.grad, grad_bias_r, error) and sbn_result_c_last
+compare("comparing channel last grad_weight: ", bn.weight.grad, sbn_c_last.weight.grad, error)
+sbn_result_c_last = compare("comparing channel last grad_weight sbn to ref: ", sbn_c_last.weight.grad, grad_weight_r, error) and sbn_result_c_last

 if sbn_result:
     print("====SBN single gpu passed tests")
 else:
     print("*SBN single gpu failed*")
+
+if sbn_result_c_last:
+    print("====SBN channel last single gpu passed tests")
+else:
+    print("*SBN channel last single gpu failed*")
tests/synced_batchnorm/two_gpu_unit_test.py (view file @ 438f6f9f)

@@ -75,7 +75,10 @@ m = inp_r.mean(1)
 b_v = inp_r.var(1, unbiased=False)
 unb_v = inp_r.var(1, unbiased=True)

-mean, var, var_biased = syncbn.welford_mean_var(inp_t)
+eps = 1e-5
+
+mean, var_biased = syncbn.welford_mean_var(inp_t)
+inv_std = 1.0 / torch.sqrt(var_biased + eps)

 bn = torch.nn.BatchNorm2d(feature_size).cuda()
 bn.momentum = 1.0
...
@@ -111,12 +114,9 @@ bn_result = True

 if args.local_rank == 0:
     sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
-    sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
     sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result

-eps = 1e-5

-out = syncbn.batchnorm_forward(inp_t, mean, var_biased, weight_t, bias_t, eps)
+out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
 out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) + bias_r

 if args.local_rank == 0:
...
@@ -136,8 +136,8 @@ mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).co
 grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1, 1, 1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) * weight_r.view(-1, 1, 1)

-mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, var_biased, weight_t, eps)
-grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, var_biased, weight_t, mean_dy, mean_dy_xmu, eps)
+mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
+grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)

 if args.local_rank == 0:
     sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
     sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
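A hedged usage note (not part of the diff): the two-GPU test reads `args.local_rank`, which follows the `torch.distributed.launch` convention of that era, so it is typically run as `python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py`; the single-GPU test needs no launcher.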