OpenDAS / apex

Commit 0ef439b6, authored Jun 13, 2019 by Evgeni Krimer
Parent: d68ec712

    update gbn

Showing 9 changed files with 466 additions and 249 deletions.
apex/contrib/csrc/groupbn/batch_norm.cu              +37   -19
apex/contrib/csrc/groupbn/batch_norm.h               +92   -82
apex/contrib/csrc/groupbn/batch_norm_add_relu.cu     +37   -19
apex/contrib/csrc/groupbn/batch_norm_add_relu.h      +67   -67
apex/contrib/csrc/groupbn/cuda_utils.h                +0   -11
apex/contrib/csrc/groupbn/interface.cpp              +33    -8
apex/contrib/csrc/groupbn/ipc.cu                      +2    -1
apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h  +128   -15
apex/contrib/groupbn/batch_norm.py                   +70   -27
apex/contrib/csrc/groupbn/batch_norm.cu

@@ -52,16 +52,19 @@ at::Tensor nhwc_bn_fwd_train(
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
+                      const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       const bool fuse_relu,
                       void* my_data,
                       void* pair_data,
                       void* pair_data2,
+                      void* pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
-                      const int max_cta_per_sm,
-                      const int cta_launch_margin) {
+                      const int occupancy,
+                      const int grid_dim_x,
+                      const bool coop) {

   const int N = x.size(0);
   const int H = x.size(1);

@@ -116,8 +119,8 @@ at::Tensor nhwc_bn_fwd_train(
   auto stream = at::cuda::getCurrentCUDAStream().stream();

   const int retired_cta_bytes = workspace_bytes[2];
-  void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
-  cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream);
+  //FIXME: is this legit?
+  void* retired_ctas = ret_cta.data<uint8_t>();
+  assert(ret_cta.size(0) >= retired_cta_bytes);
   workspace.push_back(retired_ctas);

   for (auto index = 3; index < workspace_bytes.size(); ++index) {

@@ -127,12 +130,9 @@ at::Tensor nhwc_bn_fwd_train(
   bn->setWorkspacePointers(workspace, workspace_bytes);

-  int device_id;
-  cudaGetDevice(&device_id);
-
   // Don't fuse in ReLU for now at least
-  bn->fwd(stream, fuse_relu, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
+  bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);

-  THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
   return y;
 }

@@ -142,6 +142,7 @@ at::Tensor nhwc_bn_fwd_eval(
                      const at::Tensor& bias,
                      const at::Tensor& running_mean,
                      const at::Tensor& running_inv_var,
+                     const at::Tensor& ret_cta,
                      const int bn_group,
                      const float momentum,
                      const float epsilon,

@@ -196,8 +197,8 @@ at::Tensor nhwc_bn_fwd_eval(
   auto stream = at::cuda::getCurrentCUDAStream().stream();

   const int retired_cta_bytes = workspace_bytes[2];
-  void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
-  cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream);
+  //FIXME: is this legit?
+  void* retired_ctas = ret_cta.data<uint8_t>();
+  assert(ret_cta.size(0) >= retired_cta_bytes);
   workspace.push_back(retired_ctas);

   for (auto index = 3; index < workspace_bytes.size(); ++index) {

@@ -210,7 +211,6 @@ at::Tensor nhwc_bn_fwd_eval(
   // Don't fuse in ReLU for now at least
   bn->fwdInference(stream, fuse_relu);

-  THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
   return y;
 }

@@ -224,16 +224,19 @@ std::vector<at::Tensor> nhwc_bn_bwd(
                      const at::Tensor& running_inv_var,
                      const at::Tensor& minibatch_mean,
                      const at::Tensor& minibatch_inv_var,
+                     const at::Tensor& ret_cta,
                      const float momentum,
                      const float epsilon,
                      const bool fuse_relu,
                      void* my_data,
                      void* pair_data,
                      void* pair_data2,
+                     void* pair_data3,
                      const int bn_group,
                      const at::Tensor& magic_tensor,
-                     const int max_cta_per_sm,
-                     const int cta_launch_margin) {
+                     const int occupancy,
+                     const int grid_dim_x,
+                     const bool coop) {
   // shape
   const int N = x.size(0);
   const int H = x.size(1);

@@ -293,8 +296,8 @@ std::vector<at::Tensor> nhwc_bn_bwd(
   auto stream = at::cuda::getCurrentCUDAStream().stream();

   const int retired_cta_bytes = workspace_bytes[2];
-  void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
-  cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream);
+  //FIXME: is this legit?
+  void* retired_ctas = ret_cta.data<uint8_t>();
+  assert(ret_cta.size(0) >= retired_cta_bytes);
   workspace.push_back(retired_ctas);

   for (auto index = 3; index < workspace_bytes.size(); ++index) {

@@ -304,10 +307,25 @@ std::vector<at::Tensor> nhwc_bn_bwd(
   bn->setWorkspacePointers(workspace, workspace_bytes);

-  int device_id;
-  cudaGetDevice(&device_id);
-
-  bn->dgrad(stream, fuse_relu, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
+  bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);

-  THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
   return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};
 }
+
+int nhwc_bn_fwd_occupancy() {
+  int device_id=-1;
+  cudaGetDevice(&device_id);
+
+  //max occupancy supported by the code is 2
+  return NhwcBatchNorm::smem_driven_fwd_occupancy(device_id, 2);
+}
+
+int nhwc_bn_bwd_occupancy() {
+  int device_id=-1;
+  cudaGetDevice(&device_id);
+
+  //max occupancy supported by the code is 2
+  return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2);
+}
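The two functions added above, nhwc_bn_fwd_occupancy() and nhwc_bn_bwd_occupancy(), exist so the caller can size the launch grid once, up front, instead of the extension re-deriving it on every call from max_cta_per_sm and cta_launch_margin. A minimal host-side sketch of that computation, mirroring the arithmetic the Python wrapper in this commit performs (SM count times occupancy, minus a launch margin, clamped to at least one CTA); the helper name, the sm_count argument and the margin default of 12 are illustrative assumptions, not part of the extension:

    // Sketch: derive the grid_dim_x value now passed into nhwc_bn_fwd_train()/nhwc_bn_bwd().
    #include <algorithm>

    int nhwc_bn_fwd_occupancy();  // exported above by batch_norm.cu

    int pick_fwd_grid_dim_x(int sm_count, int max_cta_per_sm, int cta_launch_margin = 12) {
      // Cap the kernel's shared-memory-driven occupancy by the user-requested maximum.
      int occupancy = std::min(nhwc_bn_fwd_occupancy(), max_cta_per_sm);
      // One CTA per (SM, occupancy slot), leaving a small margin free, but never below one CTA.
      return std::max(sm_count * occupancy - cta_launch_margin, 1);
    }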
apex/contrib/csrc/groupbn/batch_norm.h

@@ -56,11 +56,11 @@ class NhwcBatchNorm {
     exit(-1);
   }

-  void fwd(cudaStream_t stream, bool use_relu, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin);
-  void dgrad(cudaStream_t stream, bool use_relu, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin);
+  void fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
+  void dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
   void fwdInference(cudaStream_t stream, bool use_relu);
-  dim3 calc_fwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin);
-  dim3 calc_bwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin);
+  dim3 calc_fwd_grid(int *loop, const int grid_dim_x);
+  dim3 calc_bwd_grid(int *loop, const int grid_dim_x);

   void setInputDescriptor(const cudnnTensorFormat_t format, const cudnnDataType_t data_type,

@@ -256,8 +256,9 @@ class NhwcBatchNorm {
   // version that was compiled with that occupancy in its launch bounds. This way, we avoid
   // needless register spills.
   void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
-                          dim3 grid_dim, int outer_loops, bool use_relu, int device_id, const int max_cta_per_sm) {
+                          dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {

-#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY) \
+#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
   do { \
     CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
     auto fwd_func = nhwc_batch_norm_fwd< \

@@ -289,37 +290,45 @@ class NhwcBatchNorm {
                             USE_RELU, \
                             USE_ADD_RELU, \
                             COMPILED_FOR_OCCUPANCY>); \
-    cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
-                                          grid_dim, \
-                                          THREADS_PER_CTA, \
-                                          &params_ptr, \
-                                          SMEM_SIZE_FWD, \
-                                          stream); \
+    if (COOP) { \
+      cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
+                                            grid_dim, \
+                                            THREADS_PER_CTA, \
+                                            &params_ptr, \
+                                            SMEM_SIZE_FWD, \
+                                            stream); \
+    } else { \
+      cudaLaunchKernel<FWD_FUNC>(fwd_func, \
+                                 grid_dim, \
+                                 THREADS_PER_CTA, \
+                                 &params_ptr, \
+                                 SMEM_SIZE_FWD, \
+                                 stream); \
+    } \
     checkCudaStatus(name_ + " fwd ser coop kernel"); \
   } while (0)

     // Don't try for an occupancy > 2 as this will squeeze register use and create spills.
-    int occupancy = smem_driven_fwd_occupancy(device_id, max_cta_per_sm);
     if (outer_loops == 1 && use_relu) {
       if (occupancy >= 2)
-        LAUNCH_FWD_KERNEL(1, true, false, 2);
+        LAUNCH_FWD_KERNEL(1, true, false, 2, coop);
       else
-        LAUNCH_FWD_KERNEL(1, true, false, 1);
+        LAUNCH_FWD_KERNEL(1, true, false, 1, coop);
     } else if (outer_loops == 1 && !use_relu) {
       if (occupancy >= 2)
-        LAUNCH_FWD_KERNEL(1, false, false, 2);
+        LAUNCH_FWD_KERNEL(1, false, false, 2, coop);
       else
-        LAUNCH_FWD_KERNEL(1, false, false, 1);
+        LAUNCH_FWD_KERNEL(1, false, false, 1, coop);
     } else if (use_relu) {
       if (occupancy >= 2)
-        LAUNCH_FWD_KERNEL(0, true, false, 2);
+        LAUNCH_FWD_KERNEL(0, true, false, 2, coop);
       else
-        LAUNCH_FWD_KERNEL(0, true, false, 1);
+        LAUNCH_FWD_KERNEL(0, true, false, 1, coop);
     } else {
       if (occupancy >= 2)
-        LAUNCH_FWD_KERNEL(0, false, false, 2);
+        LAUNCH_FWD_KERNEL(0, false, false, 2, coop);
       else
-        LAUNCH_FWD_KERNEL(0, false, false, 1);
+        LAUNCH_FWD_KERNEL(0, false, false, 1, coop);
     }
 #undef LAUNCH_FWD_KERNEL
   }
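The only thing the new COOP argument changes in LAUNCH_FWD_KERNEL is which launch API is used: cudaLaunchCooperativeKernel, which allows grid-wide synchronization but requires the whole grid to be resident at once, or a plain cudaLaunchKernel. A stand-alone sketch of the same toggle, assuming a trivial kernel; the kernel and function names here are illustrative, not taken from the diff:

    // Sketch: runtime choice between cooperative and regular launch, as the macro above now makes.
    #include <cuda_runtime.h>

    __global__ void dummy_kernel(int* out) { out[blockIdx.x] = blockIdx.x; }

    cudaError_t launch(bool coop, dim3 grid_dim, int* out, cudaStream_t stream) {
      void* args[] = { &out };
      if (coop) {
        // Cooperative launch: every block must be co-resident, enabling grid-wide sync.
        return cudaLaunchCooperativeKernel(reinterpret_cast<void*>(dummy_kernel),
                                           grid_dim, dim3(1), args, 0, stream);
      } else {
        // Regular launch: no co-residency requirement, no grid-wide sync.
        return cudaLaunchKernel(reinterpret_cast<void*>(dummy_kernel),
                                grid_dim, dim3(1), args, 0, stream);
      }
    }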
@@ -327,8 +336,8 @@ class NhwcBatchNorm {
   // Helper function to launch the backward kernel.
   void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
-                          dim3 grid_dim, int outer_loops, bool use_relu, int device_id, const int max_cta_per_sm) {
+                          dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {

-#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY) \
+#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
   do { \
     CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
     auto bwd_func = nhwc_batch_norm_bwd< \

@@ -356,16 +365,25 @@ class NhwcBatchNorm {
                             USE_ONLINE_APPROACH, \
                             OUTER_LOOPS, \
                             COMPILED_FOR_OCCUPANCY>); \
-    cudaLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \
-                                          grid_dim, \
-                                          THREADS_PER_CTA, \
-                                          &params_ptr, \
-                                          SMEM_SIZE_BWD, \
-                                          stream); \
+    if (COOP) { \
+      cudaLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \
+                                            grid_dim, \
+                                            THREADS_PER_CTA, \
+                                            &params_ptr, \
+                                            SMEM_SIZE_BWD, \
+                                            stream); \
+    } else { \
+      cudaLaunchKernel<BWD_FUNC>(bwd_func, \
+                                 grid_dim, \
+                                 THREADS_PER_CTA, \
+                                 &params_ptr, \
+                                 SMEM_SIZE_BWD, \
+                                 stream); \
+    } \
     checkCudaStatus(name_ + " bwd coop serial kernel"); \
   } while (0)

-#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY) \
+#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
   do { \
     CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
     auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \

@@ -393,62 +411,50 @@ class NhwcBatchNorm {
                             USE_ONLINE_APPROACH, \
                             OUTER_LOOPS, \
                             COMPILED_FOR_OCCUPANCY>); \
-    cudaLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \
-                                               grid_dim, \
-                                               THREADS_PER_CTA, \
-                                               &params_ptr, \
-                                               SMEM_SIZE_BWD, \
-                                               stream); \
+    if (COOP) { \
+      cudaLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \
+                                                 grid_dim, \
+                                                 THREADS_PER_CTA, \
+                                                 &params_ptr, \
+                                                 SMEM_SIZE_BWD, \
+                                                 stream); \
+    } else { \
+      cudaLaunchKernel<BWD_RELU_FUNC>(bwd_relu_func, \
+                                      grid_dim, \
+                                      THREADS_PER_CTA, \
+                                      &params_ptr, \
+                                      SMEM_SIZE_BWD, \
+                                      stream); \
+    } \
     checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \
   } while (0)

     // Don't try for an occupancy > 2 as this will squeeze register use and create spills.
-    int occupancy = smem_driven_bwd_occupancy(device_id, max_cta_per_sm);
     if (outer_loops == 1 && use_relu) {
       if (occupancy >= 2)
-        LAUNCH_BWD_RELU_KERNEL(1, 2);
+        LAUNCH_BWD_RELU_KERNEL(1, 2, coop);
       else
-        LAUNCH_BWD_RELU_KERNEL(1, 1);
+        LAUNCH_BWD_RELU_KERNEL(1, 1, coop);
     } else if (outer_loops == 1 && !use_relu) {
       if (occupancy >= 2)
-        LAUNCH_BWD_KERNEL(1, 2);
+        LAUNCH_BWD_KERNEL(1, 2, coop);
       else
-        LAUNCH_BWD_KERNEL(1, 1);
+        LAUNCH_BWD_KERNEL(1, 1, coop);
     } else if (use_relu) {
       if (occupancy >= 2)
-        LAUNCH_BWD_RELU_KERNEL(0, 2);
+        LAUNCH_BWD_RELU_KERNEL(0, 2, coop);
       else
-        LAUNCH_BWD_RELU_KERNEL(0, 1);
+        LAUNCH_BWD_RELU_KERNEL(0, 1, coop);
     } else {
       if (occupancy >= 2)
-        LAUNCH_BWD_KERNEL(0, 2);
+        LAUNCH_BWD_KERNEL(0, 2, coop);
       else
-        LAUNCH_BWD_KERNEL(0, 1);
+        LAUNCH_BWD_KERNEL(0, 1, coop);
     }
 #undef LAUNCH_BWD_KERNEL
   }

  private:
-  // Calculate the max number of CTAs allowed in the grid for the fwd kernel.
-  static size_t max_fwd_grid_x(int device_id, const int max_cta_per_sm, const int cta_launch_margin) {
-    using namespace at::cuda::utils;
-    int answer = MultiprocessorCount(device_id) * smem_driven_fwd_occupancy(device_id, max_cta_per_sm);
-    if (SMArch(device_id) >= 70)
-      answer -= cta_launch_margin;
-    answer = std::max(1, answer);  // we need at least one CTA to operate
-    return static_cast<size_t>(answer);
-  }
-
-  // Calculate the max number of CTAs allowed in the grid for the bwd kernel.
-  static size_t max_bwd_grid_x(int device_id, const int max_cta_per_sm, const int cta_launch_margin) {
-    using namespace at::cuda::utils;
-    int answer = MultiprocessorCount(device_id) * smem_driven_bwd_occupancy(device_id, max_cta_per_sm);
-    if (SMArch(device_id) >= 70)
-      answer -= cta_launch_margin;
-    answer = std::max(1, answer);  // we need at least one CTA to operate
-    return static_cast<size_t>(answer);
-  }

  public:
   // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
   static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {

@@ -603,11 +609,11 @@ void NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) {
   }
 }

-dim3 NhwcBatchNorm::calc_fwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin) {
+dim3 NhwcBatchNorm::calc_fwd_grid(int *loop, const int grid_dim_x) {
   dim3 grid_dim;
   grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);
   int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
-  unsigned int max_grid_x = max_fwd_grid_x(device_id, max_cta_per_sm, cta_launch_margin);
+  unsigned int max_grid_x = grid_dim_x;
   if (grid_dim.x <= max_grid_x) {
     *loop = 1;
     if (max_grid_x / grid_dim.x > 1) {

@@ -626,11 +632,11 @@ dim3 NhwcBatchNorm::calc_fwd_grid(int device_id, int *loop, const int max_cta_pe
   return grid_dim;
 }

-dim3 NhwcBatchNorm::calc_bwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin) {
+dim3 NhwcBatchNorm::calc_bwd_grid(int *loop, const int grid_dim_x) {
   dim3 grid_dim;
   grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);
   int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
-  unsigned int max_grid_x = max_bwd_grid_x(device_id, max_cta_per_sm, cta_launch_margin);
+  unsigned int max_grid_x = grid_dim_x;
   if (grid_dim.x <= max_grid_x) {
     *loop = 1;
     if (max_grid_x / grid_dim.x > 1) {

@@ -649,7 +655,8 @@ dim3 NhwcBatchNorm::calc_bwd_grid(int device_id, int *loop, const int max_cta_pe
   return grid_dim;
 }

-void NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin) {
+void NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
   bool ptrs_are_set =
       X_tensor_desc_ != nullptr
       && Y_tensor_desc_ != nullptr

@@ -677,16 +684,18 @@ void NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, int device_id, void*
   NhwcBatchNormFwdParams params;
   _setFwdParams(&params);
   params.my_data = my_data;
-  params.pair_data = pair_data;
-  params.pair_data2 = pair_data2;
+  params.pair_datas[0] = pair_data;
+  params.pair_datas[1] = pair_data2;
+  params.pair_datas[2] = pair_data3;
   params.magic = magic;
-  params.sync_iters = bn_group >> 1;
+  params.sync_iters = (bn_group==8)?3:(bn_group >> 1);

-  dim3 grid_dim = calc_fwd_grid(device_id, &params.outer_loops, max_cta_per_sm, cta_launch_margin);
-  _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, device_id, max_cta_per_sm);
+  dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);
+  _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);
 }

-void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin) {
+void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
   bool ptrs_are_set =
       X_tensor_desc_ != nullptr
       && Y_tensor_desc_ != nullptr

@@ -711,14 +720,15 @@ void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, int device_id, voi
   NhwcBatchNormBwdParams params;
   _setBwdParams(&params);
   params.my_data = my_data;
-  params.pair_data = pair_data;
-  params.pair_data2 = pair_data2;
+  params.pair_datas[0] = pair_data;
+  params.pair_datas[1] = pair_data2;
+  params.pair_datas[2] = pair_data3;
   params.magic = magic;
-  params.sync_iters = bn_group >> 1;
+  params.sync_iters = (bn_group==8)?3:(bn_group >> 1);
   params.wgrad_coeff = 1.0 / bn_group;

-  dim3 grid_dim = calc_bwd_grid(device_id, &params.outer_loops, max_cta_per_sm, cta_launch_margin);
-  _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, device_id, max_cta_per_sm);
+  dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);
+  _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);
 }

 #endif  // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
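Both NhwcBatchNorm::fwd() and dgrad() now clamp the number of inter-GPU exchange iterations: a group of 8 GPUs needs only 3 sync steps (matching the bn_sync_steps = 3 the Python wrapper picks for bn_group == 8), where the old bn_group >> 1 formula would have asked for 4. A tiny sketch of the mapping, assuming bn_group is one of the power-of-two sizes the wrapper supports; the helper name is illustrative:

    // Sketch: sync_iters as now computed in fwd()/dgrad().
    // For the supported group sizes this is log2(bn_group): 2 -> 1, 4 -> 2, 8 -> 3.
    #include <cassert>

    int sync_iters_for_group(int bn_group) {
      assert(bn_group == 1 || bn_group == 2 || bn_group == 4 || bn_group == 8);
      return (bn_group == 8) ? 3 : (bn_group >> 1);
    }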
apex/contrib/csrc/groupbn/batch_norm_add_relu.cu

@@ -55,15 +55,18 @@ at::Tensor nhwc_bn_addrelu_fwd_train(
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& bitmask,
+                      const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       void* my_data,
                       void* pair_data,
                       void* pair_data2,
+                      void* pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
-                      const int max_cta_per_sm,
-                      const int cta_launch_margin) {
+                      const int occupancy,
+                      const int grid_dim_x,
+                      const bool coop) {

   const int N = x.size(0);
   const int H = x.size(1);

@@ -121,8 +124,9 @@ at::Tensor nhwc_bn_addrelu_fwd_train(
   auto stream = at::cuda::getCurrentCUDAStream().stream();

   const int retired_cta_bytes = workspace_bytes[3];
-  void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
-  cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream);
+  //FIXME: is this legit?
+  void* retired_ctas = ret_cta.data<uint8_t>();
+  assert(ret_cta.size(0) >= retired_cta_bytes);
   workspace.push_back(retired_ctas);

   for (auto index = 4; index < workspace_bytes.size(); ++index) {

@@ -132,12 +136,9 @@ at::Tensor nhwc_bn_addrelu_fwd_train(
   bn->setWorkspacePointers(workspace, workspace_bytes);

-  int device_id;
-  cudaGetDevice(&device_id);
-
   // Don't fuse in ReLU for now at least
-  bn->fwd(stream, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
+  bn->fwd(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);

-  THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
   return y;
 }

@@ -148,6 +149,7 @@ at::Tensor nhwc_bn_addrelu_fwd_eval(
                      const at::Tensor& bias,
                      const at::Tensor& running_mean,
                      const at::Tensor& running_inv_var,
+                     const at::Tensor& ret_cta,
                      const int bn_group,
                      const float momentum,
                      const float epsilon) {

@@ -204,8 +206,8 @@ at::Tensor nhwc_bn_addrelu_fwd_eval(
   auto stream = at::cuda::getCurrentCUDAStream().stream();

   const int retired_cta_bytes = workspace_bytes[3];
-  void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
-  cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream);
+  //FIXME: is this legit?
+  void* retired_ctas = ret_cta.data<uint8_t>();
+  assert(ret_cta.size(0) >= retired_cta_bytes);
   workspace.push_back(retired_ctas);

   for (auto index = 4; index < workspace_bytes.size(); ++index) {

@@ -218,7 +220,6 @@ at::Tensor nhwc_bn_addrelu_fwd_eval(
   // Don't fuse in ReLU for now at least
   bn->fwdInference(stream);

-  THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
   return y;
 }

@@ -233,15 +234,18 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
                      const at::Tensor& minibatch_mean,
                      const at::Tensor& minibatch_inv_var,
                      const at::Tensor& bitmask,
+                     const at::Tensor& ret_cta,
                      const float momentum,
                      const float epsilon,
                      void* my_data,
                      void* pair_data,
                      void* pair_data2,
+                     void* pair_data3,
                      const int bn_group,
                      const at::Tensor& magic_tensor,
-                     const int max_cta_per_sm,
-                     const int cta_launch_margin) {
+                     const int occupancy,
+                     const int grid_dim_x,
+                     const bool coop) {
   // shape
   const int N = x.size(0);
   const int H = x.size(1);

@@ -305,8 +309,8 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
   auto stream = at::cuda::getCurrentCUDAStream().stream();

   const int retired_cta_bytes = workspace_bytes[3];
-  void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
-  cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream);
+  //FIXME: is this legit?
+  void* retired_ctas = ret_cta.data<uint8_t>();
+  assert(ret_cta.size(0) >= retired_cta_bytes);
   workspace.push_back(retired_ctas);

   for (auto index = 4; index < workspace_bytes.size(); ++index) {

@@ -316,10 +320,24 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
   bn->setWorkspacePointers(workspace, workspace_bytes);

-  int device_id;
-  cudaGetDevice(&device_id);
-
-  bn->dgrad(stream, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
+  bn->dgrad(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);

-  THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
   return std::vector<at::Tensor>{x_grad, z_grad, scale_grad, bias_grad};
 }
+
+int nhwc_bn_addrelu_fwd_occupancy() {
+  int device_id=-1;
+  cudaGetDevice(&device_id);
+
+  //max occupancy supported by the code is 2
+  return NhwcBatchNormAddRelu::smem_driven_fwd_occupancy(device_id, 2);
+}
+
+int nhwc_bn_addrelu_bwd_occupancy() {
+  int device_id=-1;
+  cudaGetDevice(&device_id);
+
+  //max occupancy supported by the code is 2
+  return NhwcBatchNormAddRelu::smem_driven_bwd_occupancy(device_id, 2);
+}
apex/contrib/csrc/groupbn/batch_norm_add_relu.h

@@ -56,11 +56,11 @@ class NhwcBatchNormAddRelu {
     exit(-1);
   }

-  void fwd(cudaStream_t stream, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin);
-  void dgrad(cudaStream_t stream, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin);
+  void fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
+  void dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
   void fwdInference(cudaStream_t stream);
-  dim3 calc_fwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin);
-  dim3 calc_bwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin);
+  dim3 calc_fwd_grid(int *loop, const int grid_dim_x);
+  dim3 calc_bwd_grid(int *loop, const int grid_dim_x);

   void setInputDescriptor(const cudnnTensorFormat_t format, const cudnnDataType_t data_type,

@@ -260,8 +260,8 @@ class NhwcBatchNormAddRelu {
   // version that was compiled with that occupancy in its launch bounds. This way, we avoid
   // needless register spills.
   void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
-                          dim3 grid_dim, int outer_loops, int device_id, const int max_cta_per_sm) {
+                          dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {

-#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY) \
+#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
   do { \
     CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
         "Nhwc batchnormaddrelu kernel smem too big."; \

@@ -294,27 +294,35 @@ class NhwcBatchNormAddRelu {
                             USE_RELU, \
                             USE_ADD_RELU, \
                             COMPILED_FOR_OCCUPANCY>); \
-    cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
-                                          grid_dim, \
-                                          THREADS_PER_CTA, \
-                                          &params_ptr, \
-                                          SMEM_SIZE_FWD, \
-                                          stream); \
+    if (COOP) { \
+      cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
+                                            grid_dim, \
+                                            THREADS_PER_CTA, \
+                                            &params_ptr, \
+                                            SMEM_SIZE_FWD, \
+                                            stream); \
+    } else { \
+      cudaLaunchKernel<FWD_FUNC>(fwd_func, \
+                                 grid_dim, \
+                                 THREADS_PER_CTA, \
+                                 &params_ptr, \
+                                 SMEM_SIZE_FWD, \
+                                 stream); \
+    } \
     checkCudaStatus(name_ + " fwd ser coop kernel"); \
   } while (0)

     // Don't try for an occupancy > 2 as this will squeeze register use and create spills.
-    int occupancy = smem_driven_fwd_occupancy(device_id, max_cta_per_sm);
     if (outer_loops == 1) {
       if (occupancy >= 2)
-        LAUNCH_FWD_KERNEL(1, false, true, 2);
+        LAUNCH_FWD_KERNEL(1, false, true, 2, coop);
       else
-        LAUNCH_FWD_KERNEL(1, false, true, 1);
+        LAUNCH_FWD_KERNEL(1, false, true, 1, coop);
     } else {
       if (occupancy >= 2)
-        LAUNCH_FWD_KERNEL(0, false, true, 2);
+        LAUNCH_FWD_KERNEL(0, false, true, 2, coop);
       else
-        LAUNCH_FWD_KERNEL(0, false, true, 1);
+        LAUNCH_FWD_KERNEL(0, false, true, 1, coop);
     }
 #undef LAUNCH_FWD_KERNEL
   }

@@ -322,8 +330,8 @@ class NhwcBatchNormAddRelu {
   // Helper function to launch the backward kernel.
   void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
-                          dim3 grid_dim, int outer_loops, int device_id, const int max_cta_per_sm) {
+                          dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {

-#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY) \
+#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
   do { \
     CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
         "Nhwc batchnormaddrelu kernel smem too big."; \

@@ -354,52 +362,40 @@ class NhwcBatchNormAddRelu {
                             USE_ONLINE_APPROACH, \
                             OUTER_LOOPS, \
                             COMPILED_FOR_OCCUPANCY>); \
-    cudaLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
-                                                   grid_dim, \
-                                                   THREADS_PER_CTA, \
-                                                   &params_ptr, \
-                                                   SMEM_SIZE_BWD, \
-                                                   stream); \
+    if (COOP) { \
+      cudaLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
+                                                     grid_dim, \
+                                                     THREADS_PER_CTA, \
+                                                     &params_ptr, \
+                                                     SMEM_SIZE_BWD, \
+                                                     stream); \
+    } else { \
+      cudaLaunchKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
+                                          grid_dim, \
+                                          THREADS_PER_CTA, \
+                                          &params_ptr, \
+                                          SMEM_SIZE_BWD, \
+                                          stream); \
+    } \
     checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \
   } while (0)

     // Don't try for an occupancy > 2 as this will squeeze register use and create spills.
-    int occupancy = smem_driven_bwd_occupancy(device_id, max_cta_per_sm);
     if (outer_loops == 1) {
       if (occupancy >= 2)
-        LAUNCH_BWD_ADD_RELU_KERNEL(1, 2);
+        LAUNCH_BWD_ADD_RELU_KERNEL(1, 2, coop);
       else
-        LAUNCH_BWD_ADD_RELU_KERNEL(1, 1);
+        LAUNCH_BWD_ADD_RELU_KERNEL(1, 1, coop);
     } else {
       if (occupancy >= 2)
-        LAUNCH_BWD_ADD_RELU_KERNEL(0, 2);
+        LAUNCH_BWD_ADD_RELU_KERNEL(0, 2, coop);
       else
-        LAUNCH_BWD_ADD_RELU_KERNEL(0, 1);
+        LAUNCH_BWD_ADD_RELU_KERNEL(0, 1, coop);
     }
 #undef LAUNCH_BWD_KERNEL
   }

  private:
-  // Calculate the max number of CTAs allowed in the grid for the fwd kernel.
-  static size_t max_fwd_grid_x(int device_id, const int max_cta_per_sm, const int cta_launch_margin) {
-    using namespace at::cuda::utils;
-    int answer = MultiprocessorCount(device_id) * smem_driven_fwd_occupancy(device_id, max_cta_per_sm);
-    if (SMArch(device_id) >= 70)
-      answer -= cta_launch_margin;
-    answer = std::max(1, answer);  // we need at least one CTA to operate
-    return static_cast<size_t>(answer);
-  }
-
-  // Calculate the max number of CTAs allowed in the grid for the bwd kernel.
-  static size_t max_bwd_grid_x(int device_id, const int max_cta_per_sm, const int cta_launch_margin) {
-    using namespace at::cuda::utils;
-    int answer = MultiprocessorCount(device_id) * smem_driven_bwd_occupancy(device_id, max_cta_per_sm);
-    if (SMArch(device_id) >= 70)
-      answer -= cta_launch_margin;
-    answer = std::max(1, answer);  // we need at least one CTA to operate
-    return static_cast<size_t>(answer);
-  }

  public:
   // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
   static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
     using namespace at::cuda::utils;

@@ -553,11 +549,11 @@ void NhwcBatchNormAddRelu::fwdInference(cudaStream_t stream) {
   checkCudaStatus(name_ + " fwd_inference-relu kernel");
 }

-dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin) {
+dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int *loop, const int grid_dim_x) {
   dim3 grid_dim;
   grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);
   int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
-  unsigned int max_grid_x = max_fwd_grid_x(device_id, max_cta_per_sm, cta_launch_margin);
+  unsigned int max_grid_x = grid_dim_x;
   if (grid_dim.x <= max_grid_x) {
     *loop = 1;
     if (max_grid_x / grid_dim.x > 1) {

@@ -576,11 +572,11 @@ dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int device_id, int *loop, const int max
   return grid_dim;
 }

-dim3 NhwcBatchNormAddRelu::calc_bwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin) {
+dim3 NhwcBatchNormAddRelu::calc_bwd_grid(int *loop, const int grid_dim_x) {
   dim3 grid_dim;
   grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);
   int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
-  unsigned int max_grid_x = max_bwd_grid_x(device_id, max_cta_per_sm, cta_launch_margin);
+  unsigned int max_grid_x = grid_dim_x;
   if (grid_dim.x <= max_grid_x) {
     *loop = 1;
     if (max_grid_x / grid_dim.x > 1) {

@@ -599,7 +595,8 @@ dim3 NhwcBatchNormAddRelu::calc_bwd_grid(int device_id, int *loop, const int max
   return grid_dim;
 }

-void NhwcBatchNormAddRelu::fwd(cudaStream_t stream, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin) {
+void NhwcBatchNormAddRelu::fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
   bool ptrs_are_set =
       X_tensor_desc_ != nullptr
       && Y_tensor_desc_ != nullptr

@@ -630,16 +627,18 @@ void NhwcBatchNormAddRelu::fwd(cudaStream_t stream, int device_id, void* my_data
   _setFwdParams(&params);
   params.my_data = my_data;
-  params.pair_data = pair_data;
-  params.pair_data2 = pair_data2;
+  params.pair_datas[0] = pair_data;
+  params.pair_datas[1] = pair_data2;
+  params.pair_datas[2] = pair_data3;
   params.magic = magic;
-  params.sync_iters = bn_group >> 1;
+  params.sync_iters = (bn_group==8)?3:(bn_group >> 1);

-  dim3 grid_dim = calc_fwd_grid(device_id, &params.outer_loops, max_cta_per_sm, cta_launch_margin);
-  _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, device_id, max_cta_per_sm);
+  dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);
+  _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);
 }

-void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin) {
+void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
   bool ptrs_are_set =
       X_tensor_desc_ != nullptr
       && Y_tensor_desc_ != nullptr

@@ -668,14 +667,15 @@ void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, int device_id, void* my_da
   _setBwdParams(&params);
   params.my_data = my_data;
-  params.pair_data = pair_data;
-  params.pair_data2 = pair_data2;
+  params.pair_datas[0] = pair_data;
+  params.pair_datas[1] = pair_data2;
+  params.pair_datas[2] = pair_data3;
   params.magic = magic;
-  params.sync_iters = bn_group >> 1;
+  params.sync_iters = (bn_group==8)?3:(bn_group >> 1);
   params.wgrad_coeff = 1.0 / bn_group;

-  dim3 grid_dim = calc_bwd_grid(device_id, &params.outer_loops, max_cta_per_sm, cta_launch_margin);
-  _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, device_id, max_cta_per_sm);
+  dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);
+  _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);
 }

 #endif  // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
apex/contrib/csrc/groupbn/cuda_utils.h

@@ -7,17 +7,6 @@ namespace cuda {
 namespace utils {

-//eventually should be replaced by real query functions
-static inline int MultiprocessorCount(int device_id) {
-  return getDeviceProperties(device_id)->multiProcessorCount;
-}
-
-static inline int SMArch(int device_id) {
-  auto device_property = getDeviceProperties(device_id);
-  int cc = device_property->major * 10 + device_property->minor;
-  return cc;
-}
-
 static inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
   return getDeviceProperties(device_id)->sharedMemPerMultiprocessor;
 }
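With MultiprocessorCount() and SMArch() removed from cuda_utils.h, the multiprocessor count is now supplied by the caller (the Python wrapper reads torch.cuda.get_device_properties(...).multi_processor_count). A minimal sketch of an equivalent query on the C++ side via the CUDA runtime, in case a native caller needs the same number; this helper is not part of the extension:

    // Sketch: query the SM count directly from the CUDA runtime.
    #include <cuda_runtime.h>

    int multiprocessor_count(int device_id) {
      int sm_count = 0;
      cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id);
      return sm_count;
    }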
apex/contrib/csrc/groupbn/interface.cpp

@@ -36,16 +36,19 @@ at::Tensor nhwc_bn_fwd_train(
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
+                      const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       const bool fuse_relu,
                       void* my_data,
                       void* pair_data,
                       void* pair_data2,
+                      void* pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
-                      const int max_cta_per_sm,
-                      const int cta_launch_margin);
+                      const int occupancy,
+                      const int grid_dim_x,
+                      const bool coop);

 at::Tensor nhwc_bn_fwd_eval(
                       const at::Tensor& x,

@@ -53,6 +56,7 @@ at::Tensor nhwc_bn_fwd_eval(
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
+                      const at::Tensor& ret_cta,
                       const int bn_group,
                       const float momentum,
                       const float epsilon,

@@ -67,16 +71,19 @@ std::vector<at::Tensor> nhwc_bn_bwd(
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
+                      const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       const bool fuse_relu,
                       void* my_data,
                       void* pair_data,
                       void* pair_data2,
+                      void* pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
-                      const int max_cta_per_sm,
-                      const int cta_launch_margin);
+                      const int occupancy,
+                      const int grid_dim_x,
+                      const bool coop);

 at::Tensor nhwc_bn_addrelu_fwd_train(
                       const at::Tensor& x,

@@ -88,15 +95,18 @@ at::Tensor nhwc_bn_addrelu_fwd_train(
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& bitmask,
+                      const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       void* my_data,
                       void* pair_data,
                       void* pair_data2,
+                      void* pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
-                      const int max_cta_per_sm,
-                      const int cta_launch_margin);
+                      const int occupancy,
+                      const int grid_dim_x,
+                      const bool coop);

 at::Tensor nhwc_bn_addrelu_fwd_eval(
                       const at::Tensor& x,

@@ -105,6 +115,7 @@ at::Tensor nhwc_bn_addrelu_fwd_eval(
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
+                      const at::Tensor& ret_cta,
                       const int bn_group,
                       const float momentum,
                       const float epsilon);

@@ -119,16 +130,24 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& bitmask,
+                      const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       void* my_data,
                       void* pair_data,
                       void* pair_data2,
+                      void* pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
-                      const int max_cta_per_sm,
-                      const int cta_launch_margin);
+                      const int occupancy,
+                      const int grid_dim_x,
+                      const bool coop);
+
+int nhwc_bn_fwd_occupancy();
+int nhwc_bn_bwd_occupancy();
+
+int nhwc_bn_addrelu_fwd_occupancy();
+int nhwc_bn_addrelu_bwd_occupancy();

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

@@ -141,8 +160,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("bn_fwd_eval_nhwc", &nhwc_bn_fwd_eval, "bn_fwd_eval_nhwc");
   m.def("bn_bwd_nhwc", &nhwc_bn_bwd, "bn_bwd_nhwc");
+  m.def("bn_fwd_nhwc_occupancy", &nhwc_bn_fwd_occupancy, "bn_fwd_nhwc_occupancy");
+  m.def("bn_bwd_nhwc_occupancy", &nhwc_bn_bwd_occupancy, "bn_bwd_nhwc_occupancy");
+
   m.def("bn_addrelu_fwd_nhwc", &nhwc_bn_addrelu_fwd_train, "bn_addrelu_fwd_nhwc");
   m.def("bn_addrelu_fwd_eval_nhwc", &nhwc_bn_addrelu_fwd_eval, "bn_addrelu_fwd_eval_nhwc");
   m.def("bn_addrelu_bwd_nhwc", &nhwc_bn_addrelu_bwd, "bn_addrelu_bwd_nhwc");
+  m.def("bn_addrelu_fwd_nhwc_occupancy", &nhwc_bn_addrelu_fwd_occupancy, "bn_addrelu_fwd_nhwc_occupancy");
+  m.def("bn_addrelu_bwd_nhwc_occupancy", &nhwc_bn_addrelu_bwd_occupancy, "bn_addrelu_bwd_nhwc_occupancy");
 }
apex/contrib/csrc/groupbn/ipc.cu

@@ -58,7 +58,8 @@ const int MAX_BLOCK_Y = 256;
 const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;
 const int BYTES_PER_ELEM = 4;
 // Buffer size per sync step
-const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET*THREADS_PER_PIXEL*(1+ELEMENTS_PER_LDG)*BYTES_PER_ELEM;
+//const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET*THREADS_PER_PIXEL*(1+ELEMENTS_PER_LDG)*BYTES_PER_ELEM;
+const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET*THREADS_PER_PIXEL*2*ELEMENTS_PER_LDG*BYTES_PER_ELEM;
 };

 class IpcMemHandleRegistry {
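The per-sync-step IPC buffer grows because each exchanged value is now stored as a (value, flag) pair written with a single vectorized store, rather than the previous 1 + ELEMENTS_PER_LDG packing. A short constexpr sketch of the arithmetic, using the constants that appear in ipc.cu and nhwc_batch_norm_kernel.h in this commit:

    // Sketch: old vs. new SINGLE_SYNC_BUFFER_BYTES with the constants from this commit.
    constexpr int REDUCE_OPS        = 4;                        // mean, var, dscale, dbias
    constexpr int MAX_BLOCK_Y       = 256;
    constexpr int MAX_OFFSET        = REDUCE_OPS * MAX_BLOCK_Y; // 1024
    constexpr int THREADS_PER_PIXEL = 16;
    constexpr int ELEMENTS_PER_LDG  = 4;
    constexpr int BYTES_PER_ELEM    = 4;

    // Old layout: ELEMENTS_PER_LDG data words plus one flag word per slot.
    constexpr int OLD_BYTES = MAX_OFFSET * THREADS_PER_PIXEL * (1 + ELEMENTS_PER_LDG) * BYTES_PER_ELEM;
    // New layout: every data word is paired with its own flag word.
    constexpr int NEW_BYTES = MAX_OFFSET * THREADS_PER_PIXEL * 2 * ELEMENTS_PER_LDG * BYTES_PER_ELEM;

    static_assert(OLD_BYTES == 327680, "320 KiB per sync step before this change");
    static_assert(NEW_BYTES == 524288, "512 KiB per sync step after this change");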
apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h

@@ -364,7 +364,121 @@ DEVICE_FUNCTION void relu_activation(float (&x)[N]) {
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////

+template <int THREADS_PER_CTA>
+DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
+                                        void* params_my_data, void** params_pair_datas, int off,
+                                        const int magic, const int sync_iters) {
+  // The size of a warp.
+  const int THREADS_PER_WARP = 32;
+  // The number of warps in a CTA.
+  const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
+  // The number of threads per pixel.
+  const int THREADS_PER_PIXEL = 16;
+  // The number of elements per ldg.
+  const int ELEMENTS_PER_LDG = 4;
+  // The number of reducing ops, each uses its own space : mean, var, dscale, dbias
+  const int REDUCE_OPS = 4;
+  // Maximum block.y supported - limited due to buffer allocation
+  const int MAX_BLOCK_Y = 256;
+  const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;
+  // The warp decomposition.
+  const int warp_id = threadIdx.x / THREADS_PER_WARP;
+  const int lane_id = threadIdx.x % THREADS_PER_WARP;
+  // total size of data per sync iter
+  const int data_total = MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG*2;
+
+  #pragma unroll
+  for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
+    x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
+  }
+
+  // The warp leaders, write to SMEM.
+  if (lane_id < THREADS_PER_PIXEL) {
+    write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);
+  }
+
+  // The data is in SMEM. Do the final reduction.
+  __syncthreads();
+
+  // The 1st warp does all the work.
+  // We do the final reduction each half-warp sequentially reduces the final values.
+  if (warp_id == 0) {
+    read_from_smem(x, smem, threadIdx.x);
+
+    #pragma unroll
+    for (int offset = 1; offset < WARPS_PER_CTA/(THREADS_PER_WARP/THREADS_PER_PIXEL); ++offset) {
+      float y[ELEMENTS_PER_LDG];
+      // Read the mean and variance from the other pixel.
+      read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
+      // Compute the updated sum.
+      add(x, y);
+    }
+
+    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
+      x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
+    }
+
+    // Make sure the data was read from SMEM.
+    __syncwarp();
+
+    // Store the final values.
+    if (threadIdx.x < THREADS_PER_PIXEL) {
+      // probably could do it earlier, before sync
+      for (int sync_iter = 0; sync_iter < sync_iters; ++sync_iter) {
+        //float* params_pair_data = (reinterpret_cast<float**>(params_pair_datas))[sync_iter];
+        void* params_pair_data = params_pair_datas[sync_iter];
+
+        // skip the space consumed by previous sync iterations
+        const int xbuf_offset = sync_iter*data_total;
+        // data starts after flags, but have to skip previous
+        const int data_offset = xbuf_offset
+                                + off*ELEMENTS_PER_LDG*THREADS_PER_PIXEL*2
+                                + ELEMENTS_PER_LDG*threadIdx.x*2;
+
+        // after sums for this GPU were computed, let CTA0 broadcast the sum to over GPU
+        if (blockIdx.x == 0) {
+          volatile float* write_data = &((reinterpret_cast<float*>(params_pair_data))[data_offset]);
+
+          // write the data to memory region to be reflected to other GPU
+          asm volatile ("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};"
+                        :: "l"(write_data), "f"(x[0]), "r"(magic), "f"(x[2]), "r"(magic));
+          asm volatile ("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};"
+                        :: "l"(write_data+4), "f"(x[1]), "r"(magic), "f"(x[3]), "r"(magic));
+        }
+
+        // now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU
+        volatile float* read_data = &((reinterpret_cast<float*>(params_my_data))[data_offset]);
+
+        float other[4];
+        uint32_t other_flag_a, other_flag_b;
+        do {
+          asm volatile ("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
+                        : "=f"(other[0]), "=r"(other_flag_a), "=f"(other[2]), "=r"(other_flag_b)
+                        : "l"(read_data));
+        } while ((other_flag_a != magic) || (other_flag_b != magic));
+
+        do {
+          asm volatile ("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
+                        : "=f"(other[1]), "=r"(other_flag_a), "=f"(other[3]), "=r"(other_flag_b)
+                        : "l"(read_data+4));
+        } while ((other_flag_a != magic) || (other_flag_b != magic));
+
+        add(x, other);
+      }
+
+      // finally, after syncing up and accounting for partial sums from
+      // other GPUs as required, write the result
+      write_to_smem(smem, threadIdx.x, x);
+    }
+  }
+}
+
+#ifdef OLD_STUFF
 template <int THREADS_PER_CTA>
 DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
                                         void* params_my_data, void* params_pair_data, int off,
                                         const int magic, void* params_pair_data2,
                                         const unsigned int& sync_iters) {
   // The size of a warp.

@@ -495,6 +609,7 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw, voi
     }
   }
 }
+#endif  //OLD_STUFF

 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -655,12 +770,12 @@ template<>
 struct ParallelSums<16, 4> {
   template< int THREADS_PER_CTA >
   DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {
-    parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, 0, 0, 0, 0, 0, 0);
+    parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, 0, 0, 0, 0, 0);
   }
   template< int THREADS_PER_CTA >
-  DEVICE_FUNCTION void dispatchX(float *smem, float (&x)[4], int nhw, void* params_my_data, void* params_pair_data, int off, const int magic, void* params_pair_data2, const unsigned int& sync_iters) {
-    parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, params_my_data, params_pair_data, off, magic, params_pair_data2, sync_iters);
+  DEVICE_FUNCTION void dispatchX(float *smem, float (&x)[4], int nhw, void* params_my_data, void** params_pair_datas, int off, const int magic, const unsigned int& sync_iters) {
+    parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, params_my_data, params_pair_datas, off, magic, sync_iters);
   }
 };

@@ -858,8 +973,7 @@ struct NhwcBatchNormFwdParams {
   int c_blks;

   void* my_data;
-  void* pair_data;
-  void* pair_data2;
+  void* pair_datas[4];
   int magic;
   int sync_iters;
 };

@@ -1185,7 +1299,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
         if (params.sync_iters>0) {
           ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
-              smem, m1, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+3, params.magic, params.pair_data2, params.sync_iters);
+              smem, m1, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+3, params.magic, params.sync_iters);
         } else {
           ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
               smem, m1, thread_in_cta_nhw);

@@ -1242,7 +1356,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
         if (params.sync_iters>0) {
           ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
-              smem, m2, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+2, params.magic, params.pair_data2, params.sync_iters);
+              smem, m2, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+2, params.magic, params.sync_iters);
         } else {
           ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
               smem, m2, thread_in_cta_nhw);

@@ -1440,8 +1554,7 @@ struct NhwcBatchNormBwdParams {
   int c_blks;

   void* my_data;
-  void* pair_data;
-  void* pair_data2;
+  void* pair_datas[4];
   int magic;
   int sync_iters;
   float wgrad_coeff;

@@ -1778,7 +1891,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
       // dscale parallel sum
       if (params.sync_iters>0) {
         ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
-            smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+1, params.magic, params.pair_data2, params.sync_iters);
+            smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);
       } else {
         ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
             smem, dscale, thread_in_cta_nhw);

@@ -1792,7 +1905,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
       // dbias parallel sum
       if (params.sync_iters>0) {
         ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
-            smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+0, params.magic, params.pair_data2, params.sync_iters);
+            smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);
       } else {
         ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
             smem, dbias, thread_in_cta_nhw);

@@ -2172,7 +2285,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
      // dscale parallel sum
      if (params.sync_iters>0) {
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
-           smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+1, params.magic, params.pair_data2, params.sync_iters);
+           smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);
      } else {
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
            smem, dscale, thread_in_cta_nhw);

@@ -2186,7 +2299,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
      // dbias parallel sum
      if (params.sync_iters>0) {
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
-           smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+0, params.magic, params.pair_data2, params.sync_iters);
+           smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);
      } else {
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
            smem, dbias, thread_in_cta_nhw);

@@ -2595,7 +2708,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
      // dscale parallel sum
      if (params.sync_iters>0) {
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
-           smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+1, params.magic, params.pair_data2, params.sync_iters);
+           smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);
      } else {
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
            smem, dscale, thread_in_cta_nhw);

@@ -2609,7 +2722,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
      // dbias parallel sum
      if (params.sync_iters>0) {
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
-           smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+0, params.magic, params.pair_data2, params.sync_iters);
+           smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);
      } else {
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
            smem, dbias, thread_in_cta_nhw);
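In the new parallel_sums_16x2, every sync iteration gets its own remote buffer out of params_pair_datas, and each exchanged float is stored next to a magic flag word, which is where the factors of 2 in the offset arithmetic come from. A host-side sketch of that index computation, restating what the device code above does; the function itself is illustrative only:

    // Sketch: float-element offset into the sync buffer for a given sync iteration,
    // reduction slot (off) and lane (tid), as computed in parallel_sums_16x2.
    constexpr int THREADS_PER_PIXEL = 16;
    constexpr int ELEMENTS_PER_LDG  = 4;
    constexpr int REDUCE_OPS        = 4;
    constexpr int MAX_BLOCK_Y       = 256;
    constexpr int MAX_OFFSET        = REDUCE_OPS * MAX_BLOCK_Y;
    constexpr int DATA_TOTAL        = MAX_OFFSET * THREADS_PER_PIXEL * ELEMENTS_PER_LDG * 2;

    constexpr int data_offset(int sync_iter, int off, int tid) {
      return sync_iter * DATA_TOTAL                            // skip earlier sync iterations
           + off * ELEMENTS_PER_LDG * THREADS_PER_PIXEL * 2    // slot for this reduction op / block.y
           + ELEMENTS_PER_LDG * tid * 2;                       // this lane's value/flag pairs
    }

    static_assert(data_offset(0, 0, 0) == 0, "first slot starts at the buffer base");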
apex/contrib/groupbn/batch_norm.py

@@ -6,91 +6,105 @@ import bnp
 class bn_NHWC_impl(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, fuse_relu=False, is_train=True, bn_group=1, my_data=None, pair_data=None, magic=1, pair_data2=None, max_cta_per_sm=2, cta_launch_margin=12):
+    def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
         if is_train:
             ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv)
             ctx.epsilon = epsilon
             ctx.momentum = mom
+            ctx.ret_cta = ret_cta
             ctx.fuse_relu = fuse_relu
             ctx.my_data = my_data
             ctx.pair_data = pair_data
             ctx.magic = magic
             ctx.pair_data2 = pair_data2
+            ctx.pair_data3 = pair_data3
             ctx.bn_group = bn_group
-            ctx.max_cta_per_sm = max_cta_per_sm
-            ctx.cta_launch_margin = cta_launch_margin
+            ctx.bwd_occup = bwd_occup
+            ctx.bwd_grid_x = bwd_grid_x
+            ctx.multi_stream = multi_stream

-            res = bnp.bn_fwd_nhwc(x, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
+            res = bnp.bn_fwd_nhwc(x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, fwd_occup, fwd_grid_x, multi_stream)
             return res
         else:
-            return bnp.bn_fwd_eval_nhwc(x, s, b, rm, riv, bn_group, mom, epsilon, fuse_relu)
+            return bnp.bn_fwd_eval_nhwc(x, s, b, rm, riv, ret_cta, bn_group, mom, epsilon, fuse_relu)

     @staticmethod
     def backward(ctx, grad_y):
         x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables
         epsilon = ctx.epsilon
         mom = ctx.momentum
+        ret_cta = ctx.ret_cta
         fuse_relu = ctx.fuse_relu
         my_data = ctx.my_data
         pair_data = ctx.pair_data
         magic = ctx.magic
         pair_data2 = ctx.pair_data2
+        pair_data3 = ctx.pair_data3
         bn_group = ctx.bn_group
-        max_cta_per_sm = ctx.max_cta_per_sm
-        cta_launch_margin = ctx.cta_launch_margin
+        bwd_occup = ctx.bwd_occup
+        bwd_grid_x = ctx.bwd_grid_x
+        multi_stream = ctx.multi_stream

-        dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
+        dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)

-        return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+        return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


 class bn_addrelu_NHWC_impl(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, is_train=True, bn_group=1, my_data=None, pair_data=None, magic=1, pair_data2=None, max_cta_per_sm=2, cta_launch_margin=12):
+    def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom, epsilon, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
         if is_train:
-            bitmask = torch.cuda.IntTensor(x.numel()//32)
+            bitmask = torch.cuda.IntTensor(((x.numel()+31)//32)*2*grid_dim_y)
             ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)
             ctx.epsilon = epsilon
             ctx.momentum = mom
+            ctx.ret_cta = ret_cta
             ctx.my_data = my_data
             ctx.pair_data = pair_data
             ctx.magic = magic
             ctx.pair_data2 = pair_data2
+            ctx.pair_data3 = pair_data3
             ctx.bn_group = bn_group
-            ctx.max_cta_per_sm = max_cta_per_sm
-            ctx.cta_launch_margin = cta_launch_margin
+            ctx.bwd_occup = bwd_occup
+            ctx.bwd_grid_x = bwd_grid_x
+            ctx.multi_stream = multi_stream

-            res = bnp.bn_addrelu_fwd_nhwc(x, z, s, b, rm, riv, mini_m, mini_riv, bitmask, mom, epsilon, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
+            res = bnp.bn_addrelu_fwd_nhwc(x, z, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, fwd_occup, fwd_grid_x, multi_stream)
             return res
         else:
-            return bnp.bn_addrelu_fwd_eval_nhwc(x, z, s, b, rm, riv, bn_group, mom, epsilon)
+            return bnp.bn_addrelu_fwd_eval_nhwc(x, z, s, b, rm, riv, ret_cta, bn_group, mom, epsilon)

     @staticmethod
     def backward(ctx, grad_y):
         x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables
         epsilon = ctx.epsilon
         mom = ctx.momentum
+        ret_cta = ctx.ret_cta
         my_data = ctx.my_data
         pair_data = ctx.pair_data
         magic = ctx.magic
         pair_data2 = ctx.pair_data2
+        pair_data3 = ctx.pair_data3
         bn_group = ctx.bn_group
-        max_cta_per_sm = ctx.max_cta_per_sm
-        cta_launch_margin = ctx.cta_launch_margin
+        bwd_occup = ctx.bwd_occup
+        bwd_grid_x = ctx.bwd_grid_x
+        multi_stream = ctx.multi_stream

-        dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, mom, epsilon, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
+        dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)

-        return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+        return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


 class BatchNorm2d_NHWC(_BatchNorm):
-    def __init__(self, num_features, fuse_relu=False, bn_group=1, max_cta_per_sm=2, cta_launch_margin=12):
+    # if using BatchNorm2d_NHWC simultaneously with multiple streams set multi_stream to True
+    def __init__(self, num_features, fuse_relu=False, bn_group=1, max_cta_per_sm=2, cta_launch_margin=12, multi_stream=False):
         super(BatchNorm2d_NHWC, self).__init__(num_features)

         self.fuse_relu = fuse_relu
+        self.multi_stream = multi_stream

         self.minibatch_mean = torch.cuda.FloatTensor(num_features)
         self.minibatch_riv = torch.cuda.FloatTensor(num_features)

@@ -102,13 +116,29 @@ class BatchNorm2d_NHWC(_BatchNorm):
         self.my_data = None
         self.pair_data = None
         self.pair_data2 = None
+        self.pair_data3 = None
         self.local_rank = 0
         self.magic = torch.IntTensor([0])

+        #calculate cta per sm occupancies
+        assert(max_cta_per_sm>0)  # won't be able to do much with 0 CTAs :)
+        self.fwd_occupancy = min(bnp.bn_fwd_nhwc_occupancy(), max_cta_per_sm)
+        self.bwd_occupancy = min(bnp.bn_bwd_nhwc_occupancy(), max_cta_per_sm)
+        self.addrelu_fwd_occupancy = min(bnp.bn_addrelu_fwd_nhwc_occupancy(), max_cta_per_sm)
+        self.addrelu_bwd_occupancy = min(bnp.bn_addrelu_bwd_nhwc_occupancy(), max_cta_per_sm)
+
+        #calculate grid dimentions based on occupancy numbers
+        mp_count = torch.cuda.get_device_properties(None).multi_processor_count
+        self.fwd_grid_dim_x = max(mp_count*self.fwd_occupancy - cta_launch_margin, 1)
+        self.bwd_grid_dim_x = max(mp_count*self.bwd_occupancy - cta_launch_margin, 1)
+        self.addrelu_fwd_grid_dim_x = max(mp_count*self.addrelu_fwd_occupancy - cta_launch_margin, 1)
+        self.addrelu_bwd_grid_dim_x = max(mp_count*self.addrelu_bwd_occupancy - cta_launch_margin, 1)
+        self.grid_dim_y = (num_features + 63) // 64
+
+        # allocate scratch space used by implementation
+        self.ret_cta = torch.cuda.ByteTensor(8192).fill_(0)
+
         #FIXME: turn pair handles into an array
         if bn_group>1:
             local_rank = torch.distributed.get_rank()
             world_size = torch.distributed.get_world_size()

@@ -118,6 +148,8 @@ class BatchNorm2d_NHWC(_BatchNorm):
             bn_sync_steps = 1
             if (bn_group==4):
                 bn_sync_steps = 2
+            if (bn_group==8):
+                bn_sync_steps = 3

             self.ipc_buffer = torch.cuda.ByteTensor(bnp.get_buffer_size(bn_sync_steps))
             self.my_data = bnp.get_data_ptr(self.ipc_buffer)

@@ -148,6 +180,11 @@ class BatchNorm2d_NHWC(_BatchNorm):
                 pair_offset2 = offsets_l[local_rank ^ 2].cpu()
                 self.pair_data2 = bnp.get_remote_data_ptr(self.pair_handle2, pair_offset2)

+            if bn_group>4:
+                self.pair_handle3 = handles_l[local_rank ^ 3].cpu().contiguous()
+                pair_offset3 = offsets_l[local_rank ^ 3].cpu()
+                self.pair_data3 = bnp.get_remote_data_ptr(self.pair_handle3, pair_offset3)
+
             #FIXME: get magic value into C code and eliminate from here
             self.magic = torch.IntTensor([2])
             self.local_rank = local_rank

@@ -159,21 +196,27 @@ class BatchNorm2d_NHWC(_BatchNorm):
             return bn_addrelu_NHWC_impl.apply(x, z,
                                               self.weight, self.bias,
                                               self.running_mean, self.running_var,
-                                              self.minibatch_mean, self.minibatch_riv,
+                                              self.minibatch_mean, self.minibatch_riv,
+                                              self.grid_dim_y, self.ret_cta,
                                               self.momentum,
-                                              self.eps, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2,
-                                              self.max_cta_per_sm, self.cta_launch_margin)
+                                              self.eps, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
+                                              self.addrelu_fwd_occupancy, self.addrelu_fwd_grid_dim_x,
+                                              self.addrelu_bwd_occupancy, self.addrelu_bwd_grid_dim_x,
+                                              self.multi_stream)
         else:
             return bn_NHWC_impl.apply(x,
                                       self.weight, self.bias,
                                       self.running_mean, self.running_var,
-                                      self.minibatch_mean, self.minibatch_riv,
+                                      self.minibatch_mean, self.minibatch_riv, self.ret_cta,
                                       self.momentum,
-                                      self.eps, self.fuse_relu, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2,
-                                      self.max_cta_per_sm, self.cta_launch_margin)
+                                      self.eps, self.fuse_relu, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
+                                      self.fwd_occupancy, self.fwd_grid_dim_x,
+                                      self.bwd_occupancy, self.bwd_grid_dim_x,
+                                      self.multi_stream)

     def __del__(self):
         if self.bn_group > 1:
             bnp.close_remote_data(self.pair_handle)
             if self.bn_group > 2:
                 bnp.close_remote_data(self.pair_handle2)
+                if self.bn_group > 4:
+                    bnp.close_remote_data(self.pair_handle3)