apex · Commit d81ed26d (unverified)
Authored Feb 04, 2019 by mcarilli; committed by GitHub, Feb 04, 2019
Parents: 48299b0d, 223a47e9

Merge pull request #143 from NVIDIA/sbn_no_affine

allowing syncBN to run with affine = False
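This merge makes the affine parameters (weight/gamma and shift/beta) optional in the syncBN CUDA ops, so the extension can be driven by a layer created with affine=False. The core idea: when no affine parameters exist, the kernels fall back to gamma = 1 and beta = 0, so the op reduces to plain normalization. A minimal standalone sketch of that behavior (illustrative helper, not part of the patch):

```cpp
#include <cstddef>
#include <optional>
#include <vector>

// Illustrative only: batch-norm transform with optional affine parameters.
// When gamma/beta are absent the output is just the normalized value,
// which is what affine = False means for syncBN.
std::vector<float> batchnorm_affine_optional(const std::vector<float>& x,
                                             float mean, float inv_std,
                                             std::optional<float> gamma,
                                             std::optional<float> beta) {
  const float g = gamma.value_or(1.0f);  // same defaults the patched kernels use
  const float b = beta.value_or(0.0f);
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = (x[i] - mean) * inv_std * g + b;
  }
  return y;
}
```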
Changes: 2 changed files with 119 additions and 66 deletions (+119 −66)

  csrc/syncbn.cpp   +8    −8
  csrc/welford.cu   +111  −58

csrc/syncbn.cpp
@@ -21,8 +21,8 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
 at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
                                   const at::Tensor mean,
                                   const at::Tensor inv_std,
-                                  const at::Tensor weight,
-                                  const at::Tensor shift);
+                                  const at::optional<at::Tensor> weight,
+                                  const at::optional<at::Tensor> shift);

 // backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
 // grad_output/input should have identical data type;

@@ -32,7 +32,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
                                        const at::Tensor input,
                                        const at::Tensor mean,
                                        const at::Tensor inv_std,
-                                       const at::Tensor weight);
+                                       const at::optional<at::Tensor> weight);

 // elementwise backward BN operation, returns grad_input
 // grad_output/input/weight precision could be fp16/fp32;

@@ -41,7 +41,7 @@ at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
                                    const at::Tensor input,
                                    const at::Tensor mean,
                                    const at::Tensor inv_std,
-                                   const at::Tensor weight,
+                                   const at::optional<at::Tensor> weight,
                                    const at::Tensor mean_dy,
                                    const at::Tensor mean_dy_xmu);

@@ -57,8 +57,8 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
 at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
                                          const at::Tensor mean,
                                          const at::Tensor inv_std,
-                                         const at::Tensor weight,
-                                         const at::Tensor shift);
+                                         const at::optional<at::Tensor> weight,
+                                         const at::optional<at::Tensor> shift);

 // backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
 // grad_output/input should have identical data type;

@@ -68,7 +68,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
                                               const at::Tensor input,
                                               const at::Tensor mean,
                                               const at::Tensor inv_std,
-                                              const at::Tensor weight);
+                                              const at::optional<at::Tensor> weight);

 // elementwise backward BN operation, returns grad_input
 // grad_output/input/weight precision could be fp16/fp32;

@@ -78,7 +78,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
                                           const at::Tensor input,
                                           const at::Tensor mean,
                                           const at::Tensor inv_std,
-                                          const at::Tensor weight,
+                                          const at::optional<at::Tensor> weight,
                                           const at::Tensor mean_dy,
                                           const at::Tensor mean_dy_xmu);
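The header change swaps `const at::Tensor` for `const at::optional<at::Tensor>` on the affine arguments, so callers may omit them. A minimal sketch of how such an optional tensor parameter behaves, assuming the usual c10/ATen optional API (`has_value()`, `value()`, `at::nullopt`); the function and call sites below are illustrative only, not part of the patch:

```cpp
#include <ATen/ATen.h>

// Illustrative only: a free function with the same optional-parameter shape
// as the declarations above.
at::Tensor scale_if_present(const at::Tensor input,
                            const at::optional<at::Tensor> weight) {
  if (weight.has_value()) {
    // affine case: apply the provided per-channel weight
    return input * weight.value();
  }
  // affine = False case: behave as if weight were all ones
  return input;
}

// Call sites can pass a real tensor or nothing:
//   scale_if_present(x, gamma);        // affine = True
//   scale_if_present(x, at::nullopt);  // affine = False
```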
csrc/welford.cu
@@ -305,8 +305,8 @@ __global__ void batchnorm_forward_kernel(
     const int bs) {

   auto m_c = mean[blockIdx.x];
   auto inv_std_c = inv_std[blockIdx.x];
-  auto w_c = static_cast<accscalar_t>(weight[blockIdx.x]);
-  auto s_c = static_cast<accscalar_t>(shift[blockIdx.x]);
+  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x]);
+  auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[blockIdx.x]);

   for (int batch_offset = blockIdx.y*blockDim.y + threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
     int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
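With affine=False the forward kernel now receives null weight/shift pointers and substitutes the identity affine transform per channel. A compact sketch of that defaulting pattern in isolation (standalone C++, illustrative only; names are not the kernel's):

```cpp
// Illustrative only: default to gamma = 1 / beta = 0 when the affine
// parameter pointer is absent, mirroring the patched forward kernels.
template <typename accscalar_t, typename layerscalar_t>
accscalar_t affine_or_identity(const layerscalar_t* param, int channel,
                               accscalar_t identity_value) {
  return param == nullptr ? identity_value
                          : static_cast<accscalar_t>(param[channel]);
}

// Usage inside a (hypothetical) kernel body:
//   auto w_c = affine_or_identity(weight, c, accscalar_t(1.0));
//   auto s_c = affine_or_identity(shift,  c, accscalar_t(0.0));
```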
@@ -370,8 +370,12 @@ __global__ void reduce_bn_kernel(
   sum_dy_xmu = reduce_block((accscalar_t*)s_mem, sum_dy_xmu);

   if (thread_id == 0) {
+    if (grad_bias != NULL) {
       grad_bias[blockIdx.x] = static_cast<layerscalar_t>(sum_dy);
+    }
+    if (grad_weight != NULL) {
       grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
+    }
     mean_dy[blockIdx.x] = sum_dy / total_item_num;
     mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
   }
 }
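In the reduction kernel the gradient buffers for the affine parameters can now be null, so their writes are guarded, while mean_dy and mean_dy_xmu are always produced. The same guard pattern, written as a small host-side sketch (illustrative function, not from the patch):

```cpp
#include <cstddef>

// Illustrative only: write optional outputs through possibly-null pointers,
// mirroring the grad_bias/grad_weight guards added in reduce_bn_kernel.
void write_reduction_outputs(float sum_dy, float sum_dy_xmu, float factor,
                             float total_item_num,
                             float* grad_bias,    // may be nullptr (affine = False)
                             float* grad_weight,  // may be nullptr (affine = False)
                             float* mean_dy,      // always written
                             float* mean_dy_xmu) {
  if (grad_bias != nullptr) {
    *grad_bias = sum_dy;
  }
  if (grad_weight != nullptr) {
    *grad_weight = sum_dy_xmu * factor;
  }
  *mean_dy = sum_dy / total_item_num;
  *mean_dy_xmu = sum_dy_xmu / total_item_num;
}
```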
@@ -393,7 +397,7 @@ __global__ void batchnorm_backward_kernel(
   auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
   auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
   auto factor_1_c = inv_std[blockIdx.x];
-  auto factor_2_c = static_cast<accscalar_t>(weight[blockIdx.x]) * factor_1_c;
+  auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x])) * factor_1_c;
   factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];

   for (int batch_offset = blockIdx.y*blockDim.y+threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {

@@ -603,8 +607,8 @@ __global__ void batchnorm_forward_c_last_kernel(
   auto m_c = mean[c_offset];
   auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
-  auto w_c = static_cast<accscalar_t>(weight[c_offset]);
-  auto s_c = static_cast<accscalar_t>(shift[c_offset]);
+  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
+  auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);

   int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
   int address_base = m_offset * stride + c_offset;

@@ -749,16 +753,24 @@ __global__ void reduce_bn_c_last_kernel(
       merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);

     if (threadIdx.y == 0 && c_offset < stride) {
+      if (grad_bias != NULL) {
         grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
+      }
+      if (grad_weight != NULL) {
         grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
+      }
       mean_dy[c_offset] = sum_dy_th / reduction_size;
       mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
     }
   }
 } else {
   if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
+    if (grad_bias != NULL) {
       grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
+    }
+    if (grad_weight != NULL) {
       grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
+    }
     mean_dy[c_offset] = sum_dy_th / reduction_size;
     mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
   }
 }

@@ -793,7 +805,7 @@ __global__ void batchnorm_backward_c_last_kernel(
   auto m_c = mean[c_offset];
   auto m_dy_c = mean_dy[c_offset];
   auto factor_1_c = inv_std[c_offset];
-  auto factor_2_c = static_cast<accscalar_t>(weight[c_offset]) * factor_1_c;
+  auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
   factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[c_offset];

   int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
@@ -850,8 +862,8 @@ at::Tensor batchnorm_forward_CUDA(
     const at::Tensor input,
     const at::Tensor mean,
     const at::Tensor inv_std,
-    const at::Tensor weight,
-    const at::Tensor shift) {
+    const at::optional<at::Tensor> weight,
+    const at::optional<at::Tensor> shift) {
   const auto batch_size = input.size(0);
   const auto feature_size = input.size(1);
   at::Tensor out = at::empty_like(input);

@@ -866,29 +878,34 @@ at::Tensor batchnorm_forward_CUDA(
   const dim3 grid(feature_size, batch_group_size, grid_z);
   auto stream = at::cuda::getCurrentCUDAStream();

-  if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
+  if (input.type().scalarType() == at::ScalarType::Half
+      && weight.has_value()
+      && weight.value().type().scalarType() == at::ScalarType::Float) {
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_forward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
           input.data<scalar_t>(),
           mean.data<accscalar_t>(),
           inv_std.data<accscalar_t>(),
-          weight.data<accscalar_t>(),
-          shift.data<accscalar_t>(),
+          weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
+          shift.has_value() ? shift.value().data<accscalar_t>() : NULL,
          out.data<scalar_t>(),
          space_size,
          batch_size);
     }));
   } else {
-    AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
+    if (weight.has_value()) {
+      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
+        "input.type().scalarType() is not supported with weight.type().scalarType()");
+    }
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
           input.data<scalar_t>(),
           mean.data<accscalar_t>(),
           inv_std.data<accscalar_t>(),
-          weight.data<scalar_t>(),
-          shift.data<scalar_t>(),
+          weight.has_value() ? weight.value().data<scalar_t>() : NULL,
+          shift.has_value() ? shift.value().data<scalar_t>() : NULL,
          out.data<scalar_t>(),
          space_size,
          batch_size);
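On the host side, each `at::optional<at::Tensor>` is lowered to a raw pointer before the kernel launch: a present tensor contributes its data pointer, an absent one becomes NULL, and the type checks run only when the tensor exists. A minimal standalone sketch of that lowering, assuming the older ATen `data<T>()` accessor used throughout this file (illustrative helper, not the extension's API):

```cpp
#include <ATen/ATen.h>

// Illustrative only: turn an optional tensor argument into the raw pointer a
// kernel expects, using NULL to signal "no affine parameter".
template <typename T>
T* optional_data_or_null(const at::optional<at::Tensor>& t) {
  return t.has_value() ? t.value().data<T>() : NULL;
}

// e.g. inside a dispatch block:
//   float* weight_ptr = optional_data_or_null<float>(weight);
//   float* shift_ptr  = optional_data_or_null<float>(shift);
```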
@@ -902,7 +919,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
     const at::Tensor input,
     const at::Tensor mean,
     const at::Tensor inv_std,
-    const at::Tensor weight)
+    const at::optional<at::Tensor> weight)
 {
   const auto batch_size = input.size(0);
   const auto feature_size = input.size(1);

@@ -911,8 +928,16 @@ std::vector<at::Tensor> reduce_bn_CUDA(
   at::Tensor mean_dy = at::empty({feature_size}, mean.options());
   at::Tensor mean_dy_xmu = at::empty({feature_size}, mean.options());
-  at::Tensor grad_weight = at::empty({feature_size}, weight.options());
-  at::Tensor grad_bias = at::empty({feature_size}, weight.options());
+  at::Tensor grad_weight;
+  at::Tensor grad_bias;
+  if (weight.has_value()) {
+    grad_weight = at::empty({feature_size}, weight.value().options());
+    grad_bias = at::empty({feature_size}, weight.value().options());
+  } else {
+    grad_weight = at::empty({0}, mean.options());
+    grad_bias = at::empty({0}, mean.options());
+  }

   auto space_size = get_tensor_spatial_size(input);
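When affine=False there are no gamma/beta parameters to differentiate, so reduce_bn_CUDA only allocates real grad_weight/grad_bias buffers when a weight is present; otherwise it hands back zero-element tensors, since the function still has to return four tensors. A short sketch of that allocation choice (illustrative, using the same ATen factory calls as the diff):

```cpp
#include <ATen/ATen.h>
#include <vector>

// Illustrative only: allocate parameter gradients only when the affine
// parameters exist; otherwise return empty placeholders so the caller still
// gets a fixed-size vector of tensors.
std::vector<at::Tensor> make_affine_grads(const at::optional<at::Tensor>& weight,
                                          const at::Tensor& mean,
                                          int64_t feature_size) {
  at::Tensor grad_weight;
  at::Tensor grad_bias;
  if (weight.has_value()) {
    grad_weight = at::empty({feature_size}, weight.value().options());
    grad_bias = at::empty({feature_size}, weight.value().options());
  } else {
    // zero-element tensors: "no gradient", but still returnable
    grad_weight = at::empty({0}, mean.options());
    grad_bias = at::empty({0}, mean.options());
  }
  return {grad_weight, grad_bias};
}
```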
@@ -922,7 +947,9 @@ std::vector<at::Tensor> reduce_bn_CUDA(
   const dim3 grid(feature_size);
   auto stream = at::cuda::getCurrentCUDAStream();

-  if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
+  if (input.type().scalarType() == at::ScalarType::Half
+      && weight.has_value()
+      && weight.value().type().scalarType() == at::ScalarType::Float) {
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       reduce_bn_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(

@@ -932,14 +959,17 @@ std::vector<at::Tensor> reduce_bn_CUDA(
           inv_std.data<accscalar_t>(),
           mean_dy.data<accscalar_t>(),
           mean_dy_xmu.data<accscalar_t>(),
-          grad_weight.data<accscalar_t>(),
-          grad_bias.data<accscalar_t>(),
+          weight.has_value() ? grad_weight.data<accscalar_t>() : NULL,
+          weight.has_value() ? grad_bias.data<accscalar_t>() : NULL,
          batch_size,
          feature_size,
          space_size);
     }));
   } else {
-    AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
+    if (weight.has_value()) {
+      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
+        "input.type().scalarType() is not supported with weight.type().scalarType()");
+    }
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       reduce_bn_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(

@@ -949,8 +979,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
           inv_std.data<accscalar_t>(),
           mean_dy.data<accscalar_t>(),
           mean_dy_xmu.data<accscalar_t>(),
-          grad_weight.data<scalar_t>(),
-          grad_bias.data<scalar_t>(),
+          weight.has_value() ? grad_weight.data<scalar_t>() : NULL,
+          weight.has_value() ? grad_bias.data<scalar_t>() : NULL,
          batch_size,
          feature_size,
          space_size);
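The dtype compatibility check between input and weight now runs only when a weight actually exists; with affine=False there is nothing to compare against. That guard in isolation, as a sketch (illustrative, assuming the AT_CHECK macro this version of ATen provides):

```cpp
#include <ATen/ATen.h>

// Illustrative only: validate the optional affine parameter's dtype against
// the input, but skip the check entirely when affine = False.
void check_weight_dtype(const at::Tensor& input,
                        const at::optional<at::Tensor>& weight) {
  if (weight.has_value()) {
    AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
             "input.type().scalarType() is not supported with weight.type().scalarType()");
  }
}
```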
@@ -965,7 +995,7 @@ at::Tensor batchnorm_backward_CUDA(
     const at::Tensor input,
     const at::Tensor mean,
     const at::Tensor inv_std,
-    const at::Tensor weight,
+    const at::optional<at::Tensor> weight,
     const at::Tensor mean_dy,
     const at::Tensor mean_dy_xmu) {
   const auto batch_size = input.size(0);

@@ -984,7 +1014,9 @@ at::Tensor batchnorm_backward_CUDA(
   auto stream = at::cuda::getCurrentCUDAStream();

-  if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
+  if (input.type().scalarType() == at::ScalarType::Half
+      && weight.has_value()
+      && weight.value().type().scalarType() == at::ScalarType::Float) {
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_backward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(

@@ -992,7 +1024,7 @@ at::Tensor batchnorm_backward_CUDA(
           input.data<scalar_t>(),
           mean.data<accscalar_t>(),
           inv_std.data<accscalar_t>(),
-          weight.data<accscalar_t>(),
+          weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
          mean_dy.data<accscalar_t>(),
          mean_dy_xmu.data<accscalar_t>(),
          grad_input.data<scalar_t>(),

@@ -1000,7 +1032,10 @@ at::Tensor batchnorm_backward_CUDA(
          batch_size);
     }));
   } else {
-    AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
+    if (weight.has_value()) {
+      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
+        "input.type().scalarType() is not supported with weight.type().scalarType()");
+    }
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_backward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(

@@ -1008,7 +1043,7 @@ at::Tensor batchnorm_backward_CUDA(
           input.data<scalar_t>(),
           mean.data<accscalar_t>(),
           inv_std.data<accscalar_t>(),
-          weight.data<scalar_t>(),
+          weight.has_value() ? weight.value().data<scalar_t>() : NULL,
          mean_dy.data<accscalar_t>(),
          mean_dy_xmu.data<accscalar_t>(),
          grad_input.data<scalar_t>(),
@@ -1099,8 +1134,8 @@ at::Tensor batchnorm_forward_c_last_CUDA(
     const at::Tensor input,
     const at::Tensor mean,
     const at::Tensor inv_std,
-    const at::Tensor weight,
-    const at::Tensor shift) {
+    const at::optional<at::Tensor> weight,
+    const at::optional<at::Tensor> shift) {
   const auto stride = input.size(input.ndimension()-1);
   const auto reduction_size = input.numel() / stride;

@@ -1113,7 +1148,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(
   auto stream = at::cuda::getCurrentCUDAStream();

   if (input.type().scalarType() == at::ScalarType::Half
-      && weight.type().scalarType() == at::ScalarType::Float) {
+      && weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>

@@ -1121,15 +1156,17 @@ at::Tensor batchnorm_forward_c_last_CUDA(
         input.data<scalar_t>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
-        weight.data<accscalar_t>(),
-        shift.data<accscalar_t>(),
+        weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
+        shift.has_value() ? shift.value().data<accscalar_t>() : NULL,
        out.data<scalar_t>(),
        reduction_size,
        stride);
     }));
   } else {
-    AT_CHECK(input.type().scalarType() == weight.type().scalarType(),
-      "input.type().scalarType() is not supported with weight.type().scalarType()");
+    if (weight.has_value()) {
+      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
+        "input.type().scalarType() is not supported with weight.type().scalarType()");
+    }
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>

@@ -1137,8 +1174,8 @@ at::Tensor batchnorm_forward_c_last_CUDA(
         input.data<scalar_t>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
-        weight.data<scalar_t>(),
-        shift.data<scalar_t>(),
+        weight.has_value() ? weight.value().data<scalar_t>() : NULL,
+        shift.has_value() ? shift.value().data<scalar_t>() : NULL,
        out.data<scalar_t>(),
        reduction_size,
        stride);
@@ -1152,14 +1189,23 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
     const at::Tensor input,
     const at::Tensor mean,
     const at::Tensor inv_std,
-    const at::Tensor weight) {
+    const at::optional<at::Tensor> weight) {
   const auto stride = input.size(input.ndimension()-1);
   const auto reduction_size = input.numel() / stride;

   at::Tensor mean_dy = at::empty({stride}, mean.options());
   at::Tensor mean_dy_xmu = at::empty({stride}, mean.options());
-  at::Tensor grad_weight = at::empty({stride}, weight.options());
-  at::Tensor grad_bias = at::empty({stride}, weight.options());
+  at::Tensor grad_weight;
+  at::Tensor grad_bias;
+  if (weight.has_value()) {
+    grad_weight = at::empty({stride}, weight.value().options());
+    grad_bias = at::empty({stride}, weight.value().options());
+  } else {
+    // because I cannot return an uninitialized at::Tensor
+    grad_weight = at::empty({0}, mean.options());
+    grad_bias = at::empty({0}, mean.options());
+  }

   dim3 block;
   dim3 grid;

@@ -1173,7 +1219,9 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
   }
   auto stream = at::cuda::getCurrentCUDAStream();

-  if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
+  if (input.type().scalarType() == at::ScalarType::Half
+      && weight.has_value()
+      && weight.value().type().scalarType() == at::ScalarType::Float) {
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;

@@ -1186,15 +1234,18 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
         inv_std.data<accscalar_t>(),
         mean_dy.data<accscalar_t>(),
         mean_dy_xmu.data<accscalar_t>(),
-        grad_weight.data<accscalar_t>(),
-        grad_bias.data<accscalar_t>(),
+        weight.has_value() ? grad_weight.data<accscalar_t>() : NULL,
+        weight.has_value() ? grad_bias.data<accscalar_t>() : NULL,
        staging_data_ptr,
        semaphores_ptr,
        reduction_size,
        stride);
     }));
   } else {
-    AT_CHECK(input.type().scalarType() == weight.type().scalarType(),
-      "input.type().scalarType() is not supported with weight.type().scalarType()");
+    if (weight.has_value()) {
+      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
+        "input.type().scalarType() is not supported with weight.type().scalarType()");
+    }
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;

@@ -1207,8 +1258,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
         inv_std.data<accscalar_t>(),
         mean_dy.data<accscalar_t>(),
         mean_dy_xmu.data<accscalar_t>(),
-        grad_weight.data<scalar_t>(),
-        grad_bias.data<scalar_t>(),
+        weight.has_value() ? grad_weight.data<scalar_t>() : NULL,
+        weight.has_value() ? grad_bias.data<scalar_t>() : NULL,
        staging_data_ptr,
        semaphores_ptr,
        reduction_size,
@@ -1224,7 +1275,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
     const at::Tensor input,
     const at::Tensor mean,
     const at::Tensor inv_std,
-    const at::Tensor weight,
+    const at::optional<at::Tensor> weight,
     const at::Tensor mean_dy,
     const at::Tensor mean_dy_xmu) {
   const auto stride = input.size(input.ndimension()-1);

@@ -1239,7 +1290,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
   auto stream = at::cuda::getCurrentCUDAStream();

   if (input.type().scalarType() == at::ScalarType::Half
-      && weight.type().scalarType() == at::ScalarType::Float) {
+      && weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>

@@ -1248,7 +1299,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
         input.data<scalar_t>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
-        weight.data<accscalar_t>(),
+        weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
        mean_dy.data<accscalar_t>(),
        mean_dy_xmu.data<accscalar_t>(),
        grad_input.data<scalar_t>(),

@@ -1256,8 +1307,10 @@ at::Tensor batchnorm_backward_c_last_CUDA(
        stride);
     }));
   } else {
-    AT_CHECK(input.type().scalarType() == weight.type().scalarType(),
-      "input.type().scalarType() is not supported with weight.type().scalarType()");
+    if (weight.has_value()) {
+      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
+        "input.type().scalarType() is not supported with weight.type().scalarType()");
+    }
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>

@@ -1266,7 +1319,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
         input.data<scalar_t>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
-        weight.data<scalar_t>(),
+        weight.has_value() ? weight.value().data<scalar_t>() : NULL,
        mean_dy.data<accscalar_t>(),
        mean_dy_xmu.data<accscalar_t>(),
        grad_input.data<scalar_t>(),