OpenDAS / apex · Commits

Commit 0273d7ad (unverified)
Authored Dec 03, 2018 by mcarilli; committed by GitHub on Dec 03, 2018.

[syncBN] (#77)

adjusted kernel config for better perf. removed divergence in welford warp reduction.

Parents: 5dad4c21, ee67e56a
Changes: 1 changed file, 56 additions and 43 deletions.

csrc/welford.cu (+56 −43)
@@ -14,18 +14,19 @@
 __device__ __forceinline__ int lastpow2(int n) {
   int out = 1 << (31 - __clz(n));
   if (n == out)
     out >>= 1;
   return out;
 }
 
 __host__ __forceinline__ int h_next_pow2 (unsigned int n) {
+  unsigned int old = n;
   n |= (n >> 1);
   n |= (n >> 2);
   n |= (n >> 4);
   n |= (n >> 8);
   n |= (n >> 16);
-  return n + 1;
+  return n == old ? n : n + 1;
 }
 
 __host__ __forceinline__ int h_last_pow2 (unsigned int n) {
@@ -71,7 +72,7 @@ __device__ __forceinline__ T reduce_block(T *x, T val)
 }
 
 #define TILE_W 32
-#define MAX_BLOCK_SIZE 256
+#define MAX_BLOCK_SIZE 1024
 
 template <typename T>
 __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
@@ -81,12 +82,11 @@ __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
     auto num_new = __shfl_down_sync(0xffffffff, num, i);
     auto mean_new = __shfl_down_sync(0xffffffff, mean, i);
     auto m2n_new = __shfl_down_sync(0xffffffff, m2n, i);
-    if (num_new != 0) {
-      auto dif_mean = mean - mean_new;
-      mean = (mean_new * num_new + mean * num) / (num + num_new);
-      m2n += m2n_new + dif_mean * dif_mean * num * num_new / (num_new + num);
-      num += num_new;
-    }
+    T factor = 1.0 / max(1, (num + num_new));
+    auto dif_mean = mean - mean_new;
+    mean = (mean_new * num_new + mean * num) * factor;
+    m2n += m2n_new + dif_mean * dif_mean * num * num_new * factor;
+    num += num_new;
   }
 }
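Note on the warp reduction change above: folding the old guard (if (num_new != 0)) into factor = 1.0 / max(1, (num + num_new)) keeps every lane on the same path, and the merge degenerates to a no-op when the shuffled-in lane carries no samples (num_new == 0 and m2n_new == 0). The host-side C++ sketch below mirrors that branch-free pairwise merge and checks it against a plain two-pass mean/variance; the struct and function names (WelfordAgg, merge) and the sample values are illustrative, not part of the commit.

// Host-side sketch of the branch-free pairwise Welford merge (illustrative names).
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

struct WelfordAgg {
  double mean = 0.0;
  double m2n  = 0.0;  // running sum of squared deviations from the mean
  int    num  = 0;    // number of samples folded in so far
};

// Merge b into a without branching on b.num == 0, mirroring the kernel's factor trick.
void merge(WelfordAgg &a, const WelfordAgg &b) {
  double factor = 1.0 / std::max(1, a.num + b.num);
  double dif_mean = a.mean - b.mean;
  a.mean = (b.mean * b.num + a.mean * a.num) * factor;
  a.m2n += b.m2n + dif_mean * dif_mean * a.num * b.num * factor;
  a.num += b.num;
}

int main() {
  const std::vector<double> x = {1.0, 2.5, -0.5, 4.0, 3.0, 0.25, 7.5};

  // Build two partial aggregates plus a deliberately empty one (an idle lane).
  WelfordAgg a, b, empty;
  for (std::size_t i = 0; i < 4; ++i) {
    a.num++; double d = x[i] - a.mean; a.mean += d / a.num; a.m2n += d * (x[i] - a.mean);
  }
  for (std::size_t i = 4; i < x.size(); ++i) {
    b.num++; double d = x[i] - b.mean; b.mean += d / b.num; b.m2n += d * (x[i] - b.mean);
  }

  merge(a, b);
  merge(a, empty);  // no-op: an empty partner leaves mean, m2n and num unchanged

  // Reference: two-pass mean and biased variance over the same samples.
  double mean = 0.0;
  for (double v : x) mean += v;
  mean /= x.size();
  double var = 0.0;
  for (double v : x) var += (v - mean) * (v - mean);
  var /= x.size();

  assert(std::fabs(a.mean - mean) < 1e-12);
  assert(std::fabs(a.m2n / a.num - var) < 1e-12);
  std::printf("mean = %f, biased var = %f\n", a.mean, a.m2n / a.num);
  return 0;
}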
@@ -159,11 +159,7 @@ __global__ void welford_kernel(
     const int bs,
     const int fs,
     const int ss) {
-  static __shared__ int s_mem[160];
   int block_size = blockDim.x * blockDim.y;
-  accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
   int count = 0;
   accscalar_t x_mean = accscalar_t(0);
   accscalar_t m_2_n = accscalar_t(0);
@@ -176,12 +172,15 @@ __global__ void welford_kernel(
     for (int offset = threadIdx.x; offset < ss; offset += blockDim.x) {
       count++;
       auto x_n = static_cast<accscalar_t>(input[offset+input_base]);
-      auto x_mean_new = x_mean + (x_n - x_mean) / count;
-      m_2_n = m_2_n + (x_n - x_mean_new) * (x_n - x_mean);
-      x_mean = x_mean_new;
+      auto d = x_n - x_mean;
+      x_mean += d / count;
+      m_2_n += d * (x_n - x_mean);
     }
   }
 
+  static __shared__ int s_mem[160];
+  accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
+
   welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id);
 
   if (thread_id == 0) {
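Note on the per-thread accumulation above: the new form (d = x_n - x_mean; x_mean += d / count; m_2_n += d * (x_n - x_mean)) is the standard single-pass Welford update and is algebraically identical to the old x_mean_new formulation, since both accumulate the product (x_n - old mean) * (x_n - updated mean). A small host-side C++ check follows; the sample values are arbitrary, not from the commit.

// Host-side check that the old and new per-sample Welford updates agree.
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  const double x[] = {0.5, -1.25, 3.0, 2.0, 4.75, -0.5, 1.0, 6.25};
  double mean_a = 0.0, m2n_a = 0.0;  // old formulation (explicit x_mean_new)
  double mean_b = 0.0, m2n_b = 0.0;  // new formulation (d = x_n - x_mean)
  int count = 0;

  for (double x_n : x) {
    count++;

    // Old: x_mean_new = x_mean + (x_n - x_mean) / count; m_2_n += (x_n - x_mean_new) * (x_n - x_mean);
    double x_mean_new = mean_a + (x_n - mean_a) / count;
    m2n_a += (x_n - x_mean_new) * (x_n - mean_a);
    mean_a = x_mean_new;

    // New: d = x_n - x_mean; x_mean += d / count; m_2_n += d * (x_n - x_mean);
    double d = x_n - mean_b;
    mean_b += d / count;
    m2n_b += d * (x_n - mean_b);
  }

  assert(std::fabs(mean_a - mean_b) < 1e-12);
  assert(std::fabs(m2n_a - m2n_b) < 1e-12);
  std::printf("mean = %f, m2n = %f\n", mean_b, m2n_b);
  return 0;
}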
@@ -201,16 +200,18 @@ __global__ void batchnorm_forward_kernel(
     const layerscalar_t* __restrict__ shift,
     scalar_t* __restrict__ out,
     const int ss,
+    const int bs,
     const float eps) {
-  int address_base = blockIdx.x*ss + blockIdx.y*gridDim.x*ss;
   auto m_c = mean[blockIdx.x];
   auto inv_std_c = static_cast<accscalar_t>(rsqrt(var[blockIdx.x] + eps));
   auto w_c = static_cast<accscalar_t>(weight[blockIdx.x]);
   auto s_c = static_cast<accscalar_t>(shift[blockIdx.x]);
 
-  for (int offset = threadIdx.x; offset < ss; offset += blockDim.x) {
-    out[address_base+offset] = static_cast<scalar_t>(w_c * (static_cast<accscalar_t>(input[address_base+offset]) - m_c) * inv_std_c + s_c);
+  for (int batch_offset = blockIdx.y*blockDim.y + threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
+    int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
+    for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss; offset += gridDim.z*blockDim.x) {
+      out[address_base+offset] = static_cast<scalar_t>(w_c * (static_cast<accscalar_t>(input[address_base+offset]) - m_c) * inv_std_c + s_c);
+    }
   }
 }
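Note on the restructured loops above: the forward kernel now strides over the batch dimension with blockIdx.y/threadIdx.y and over the spatial dimension with blockIdx.z/threadIdx.x, instead of assigning one grid row per sample. The host-side C++ replay below walks the same loop nest for one feature plane and asserts that every (batch, spatial) element is written exactly once; the sizes and launch shape are arbitrary example values, not what the extension actually picks.

// CPU replay of the new batch/spatial grid-stride loop nest for one feature plane.
#include <cassert>
#include <vector>

int main() {
  const int bs = 6, ss = 100;          // batch size and spatial size
  const int grid_y = 2, grid_z = 3;    // batch groups and spatial groups in the grid
  const int block_x = 8, block_y = 2;  // threads per block

  std::vector<int> hits(bs * ss, 0);

  // Fixed blockIdx.x (one feature); enumerate every block and thread of the launch.
  for (int by = 0; by < grid_y; ++by)
    for (int bz = 0; bz < grid_z; ++bz)
      for (int ty = 0; ty < block_y; ++ty)
        for (int tx = 0; tx < block_x; ++tx)
          for (int batch = by * block_y + ty; batch < bs; batch += grid_y * block_y)
            for (int off = tx + bz * block_x; off < ss; off += grid_z * block_x)
              hits[batch * ss + off]++;  // corresponds to address_base + offset

  for (int h : hits) assert(h == 1);     // every element covered exactly once
  return 0;
}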
@@ -267,7 +268,7 @@ __global__ void reduce_bn_kernel(
   sum_dy = reduce_block((accscalar_t*)s_mem, sum_dy);
   __syncthreads();
   sum_dy_xmu = reduce_block((accscalar_t*)s_mem, sum_dy_xmu);
 
   if (thread_id == 0) {
     grad_bias[blockIdx.x] = static_cast<layerscalar_t>(sum_dy);
     grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
@@ -288,17 +289,19 @@ __global__ void batchnorm_backward_kernel(
     const accscalar_t* __restrict__ mean_dy_xmu,
     scalar_t* __restrict__ grad_input,
     const int ss,
+    const int bs,
     const float eps) {
-  int address_base = blockIdx.x*ss + blockIdx.y*gridDim.x*ss;
   auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
   auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
   auto factor_1_c = static_cast<accscalar_t>(var[blockIdx.x]) + eps;
   auto factor_2_c = static_cast<accscalar_t>(weight[blockIdx.x]) / sqrt(factor_1_c);
   factor_1_c /= static_cast<accscalar_t>(mean_dy_xmu[blockIdx.x]);
 
-  for (int offset = threadIdx.x; offset < ss; offset += blockDim.x) {
-    grad_input[address_base+offset] = (static_cast<accscalar_t>(grad_output[address_base+offset]) - m_dy_c - (static_cast<accscalar_t>(input[address_base+offset]) - m_c) / factor_1_c) * factor_2_c;
+  for (int batch_offset = blockIdx.y*blockDim.y + threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
+    int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
+    for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss; offset += gridDim.z*blockDim.x) {
+      grad_input[address_base+offset] = (static_cast<accscalar_t>(grad_output[address_base+offset]) - m_dy_c - (static_cast<accscalar_t>(input[address_base+offset]) - m_c) / factor_1_c) * factor_2_c;
+    }
   }
 }
@@ -322,7 +325,7 @@ __global__ void welford_kernel_parallel(
   int input_base = blockIdx.x*ns + threadIdx.x;
   int thread_id = threadIdx.x;
 
   // load data;
   auto x_mean = static_cast<accscalar_t>(mean[input_base]);
   auto m_2_n = static_cast<accscalar_t>(var_biased[input_base]) * numel;
   auto count = numel;
@@ -337,7 +340,7 @@ __global__ void welford_kernel_parallel(
     out_var_biased[blockIdx.x] = static_cast<scalar_t>(m_2_n/count);
   }
 }
 
 std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
   const auto batch_size = input.size(0);
@@ -350,8 +353,8 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
   at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));
   at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));
 
-  int block_x = TILE_W;
-  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/block_x));
+  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/32));
+  int block_x = max(1, min(MAX_BLOCK_SIZE/block_y, h_last_pow2(space_size)));
   const dim3 block(block_x, block_y);
   const dim3 grid(feature_size);
@@ -386,9 +389,12 @@ at::Tensor batchnorm_forward_CUDA(
   auto space_size = get_tensor_spatial_size(input);
 
-  int block = min(MAX_BLOCK_SIZE, h_next_pow2(space_size)/4);
-  // TODO(jie): should I do 1 block per feature?
-  const dim3 grid(feature_size, batch_size);
+  int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
+  int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
+  const dim3 block(block_x, block_y);
+  int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
+  int batch_group_size = max(1, min(65535, h_last_pow2(batch_size)/block_y));
+  const dim3 grid(feature_size, batch_group_size, grid_z);
   auto stream = at::cuda::getCurrentCUDAStream();
 
   if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
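Note on the launch configuration above: with MAX_BLOCK_SIZE raised to 1024, the forward path now derives a 2D block and a 3D grid (feature, batch group, spatial group) from the tensor shape instead of a 1D block and a (feature, batch) grid. The host-side C++ sketch below reproduces that arithmetic for one example shape (N=64, C=256, H=W=28, chosen arbitrarily); the h_last_pow2 helper here is assumed to return the largest power of two not exceeding its argument, matching the name, since its body sits outside the shown hunks. For this example shape it prints block = (128, 8) and grid = (256, 8, 1), i.e. full 1024-thread blocks.

// Reproduce the new launch-config arithmetic on the host for one example shape.
#include <algorithm>
#include <cstdio>

constexpr int MAX_BLOCK_SIZE = 1024;

// Assumed behaviour: largest power of two <= n (the real body is not in this diff).
int h_last_pow2(unsigned int n) {
  n |= (n >> 1); n |= (n >> 2); n |= (n >> 4); n |= (n >> 8); n |= (n >> 16);
  return static_cast<int>(n - (n >> 1));
}

int main() {
  const int batch_size = 64, feature_size = 256, space_size = 28 * 28;  // example N, C, H*W

  int block_x = std::max(32, std::min(MAX_BLOCK_SIZE, h_last_pow2(space_size) / 4));
  int block_y = std::max(1, std::min(MAX_BLOCK_SIZE / block_x, h_last_pow2(batch_size) / 4));
  int grid_z = std::max(1, std::min(65535, h_last_pow2(space_size) / 4 / block_x));
  int batch_group_size = std::max(1, std::min(65535, h_last_pow2(batch_size) / block_y));

  std::printf("block = (%d, %d), grid = (%d, %d, %d)\n",
              block_x, block_y, feature_size, batch_group_size, grid_z);
  return 0;
}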
@@ -402,10 +408,11 @@ at::Tensor batchnorm_forward_CUDA(
         shift.data<accscalar_t>(),
         out.data<scalar_t>(),
         space_size,
+        batch_size,
         eps);
     }));
   } else {
     AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
@@ -416,6 +423,7 @@ at::Tensor batchnorm_forward_CUDA(
         shift.data<scalar_t>(),
         out.data<scalar_t>(),
         space_size,
+        batch_size,
         eps);
     }));
   }
@@ -442,11 +450,10 @@ std::vector<at::Tensor> reduce_bn_CUDA(
   auto space_size = get_tensor_spatial_size(input);
 
-  int block_x = TILE_W;
-  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/block_x));
+  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/32));
+  int block_x = max(1, min(MAX_BLOCK_SIZE/block_y, h_last_pow2(space_size)));
   const dim3 block(block_x, block_y);
   const dim3 grid(feature_size);
-  // shared memory used for reduce on sum_dy, sum_dy_xmu;
   auto stream = at::cuda::getCurrentCUDAStream();
 
   if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
@@ -467,7 +474,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
         eps);
     }));
   } else {
     AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       reduce_bn_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
@@ -485,7 +492,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
         eps);
     }));
   }
 
   return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
 }
@@ -505,9 +512,13 @@ at::Tensor batchnorm_backward_CUDA(
   auto space_size = get_tensor_spatial_size(input);
 
-  int block = min(MAX_BLOCK_SIZE, h_next_pow2(space_size)/4);
-  // TODO(jie): should I do 1 block per feature?
-  const dim3 grid(feature_size, batch_size);
+  int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
+  int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
+  const dim3 block(block_x, block_y);
+  int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
+  int batch_group_size = max(1, min(65535, h_last_pow2(batch_size)/block_y));
+  const dim3 grid(feature_size, batch_group_size, grid_z);
   auto stream = at::cuda::getCurrentCUDAStream();
 
   if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
@@ -523,10 +534,11 @@ at::Tensor batchnorm_backward_CUDA(
         mean_dy_xmu.data<accscalar_t>(),
         grad_input.data<scalar_t>(),
         space_size,
+        batch_size,
         eps);
     }));
   } else {
     AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_backward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
@@ -539,10 +551,11 @@ at::Tensor batchnorm_backward_CUDA(
         mean_dy_xmu.data<accscalar_t>(),
         grad_input.data<scalar_t>(),
         space_size,
+        batch_size,
         eps);
     }));
   }
 
   return grad_input;
 }