Commit 223a47e9 authored by jiej

allowing syncBN to run with affine = False

parent aed3086a
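
Note: this commit makes the affine parameters optional throughout the syncBN CUDA extension so that SyncBatchNorm(affine=False) can run without weight/bias tensors. A minimal caller sketch (not part of the commit; assumes at::optional is ATen's c10::optional alias and batchnorm_forward_CUDA is the function declared in the header diff below):

    #include <ATen/ATen.h>

    // Hypothetical caller: with affine = False there are no learnable
    // parameters, so weight and shift are passed as empty optionals.
    at::Tensor forward_no_affine(const at::Tensor& input,
                                 const at::Tensor& mean,
                                 const at::Tensor& inv_std) {
      return batchnorm_forward_CUDA(input, mean, inv_std,
                                    at::nullopt, at::nullopt);
    }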
@@ -21,8 +21,8 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
 at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
                                   const at::Tensor mean,
                                   const at::Tensor inv_std,
-                                  const at::Tensor weight,
-                                  const at::Tensor shift);
+                                  const at::optional<at::Tensor> weight,
+                                  const at::optional<at::Tensor> shift);
 // backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
 // grad_output/input should have identical data type;
@@ -32,7 +32,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
                                        const at::Tensor input,
                                        const at::Tensor mean,
                                        const at::Tensor inv_std,
-                                       const at::Tensor weight);
+                                       const at::optional<at::Tensor> weight);
 // elementwise backward BN operation, returns grad_input
 // grad_output/input/weight precision could be fp16/fp32;
@@ -41,7 +41,7 @@ at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
                                    const at::Tensor input,
                                    const at::Tensor mean,
                                    const at::Tensor inv_std,
-                                   const at::Tensor weight,
+                                   const at::optional<at::Tensor> weight,
                                    const at::Tensor mean_dy,
                                    const at::Tensor mean_dy_xmu);
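
Note: a sketch of the backward flow with affine = False (hypothetical caller, using only the three declarations above). Per the comment in this header, reduce_bn_CUDA returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}; without affine only the first two are meaningful:

    auto reduced = reduce_bn_CUDA(grad_output, input, mean, inv_std, at::nullopt);
    at::Tensor mean_dy = reduced[0];
    at::Tensor mean_dy_xmu = reduced[1];
    // reduced[2] / reduced[3] are zero-element placeholders (see the .cu diff below)
    at::Tensor grad_input = batchnorm_backward_CUDA(
        grad_output, input, mean, inv_std, at::nullopt, mean_dy, mean_dy_xmu);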
@@ -57,8 +57,8 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
 at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
                                          const at::Tensor mean,
                                          const at::Tensor inv_std,
-                                         const at::Tensor weight,
-                                         const at::Tensor shift);
+                                         const at::optional<at::Tensor> weight,
+                                         const at::optional<at::Tensor> shift);
 // backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
 // grad_output/input should have identical data type;
@@ -68,7 +68,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
                                               const at::Tensor input,
                                               const at::Tensor mean,
                                               const at::Tensor inv_std,
-                                              const at::Tensor weight);
+                                              const at::optional<at::Tensor> weight);
 // elementwise backward BN operation, returns grad_input
 // grad_output/input/weight precision could be fp16/fp32;
@@ -78,7 +78,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
                                           const at::Tensor input,
                                           const at::Tensor mean,
                                           const at::Tensor inv_std,
-                                          const at::Tensor weight,
+                                          const at::optional<at::Tensor> weight,
                                           const at::Tensor mean_dy,
                                           const at::Tensor mean_dy_xmu);
...
@@ -305,8 +305,8 @@ __global__ void batchnorm_forward_kernel(
     const int bs) {
   auto m_c = mean[blockIdx.x];
   auto inv_std_c = inv_std[blockIdx.x];
-  auto w_c = static_cast<accscalar_t>(weight[blockIdx.x]);
-  auto s_c = static_cast<accscalar_t>(shift[blockIdx.x]);
+  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x]);
+  auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[blockIdx.x]);
   for (int batch_offset = blockIdx.y*blockDim.y + threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
     int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
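
Note: when weight or shift arrives as a null pointer, the kernel above falls back to the identity affine transform (gamma = 1, beta = 0). An illustrative device helper (not in the commit) showing the same arithmetic in isolation:

    // Sketch: a missing weight behaves as gamma = 1 and a missing shift as
    // beta = 0, so the output reduces to (x - mean) * inv_std.
    template <typename accscalar_t>
    __device__ accscalar_t bn_transform(accscalar_t x, accscalar_t m,
                                        accscalar_t inv_std,
                                        const accscalar_t* weight,  // may be NULL
                                        const accscalar_t* shift,   // may be NULL
                                        int c) {
      accscalar_t w = weight == NULL ? accscalar_t(1.0) : weight[c];
      accscalar_t s = shift == NULL ? accscalar_t(0.0) : shift[c];
      return (x - m) * inv_std * w + s;
    }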
@@ -370,8 +370,12 @@ __global__ void reduce_bn_kernel(
   sum_dy_xmu = reduce_block((accscalar_t*)s_mem, sum_dy_xmu);
   if (thread_id == 0) {
+    if (grad_bias != NULL) {
     grad_bias[blockIdx.x] = static_cast<layerscalar_t>(sum_dy);
+    }
+    if (grad_weight != NULL) {
     grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
+    }
     mean_dy[blockIdx.x] = sum_dy / total_item_num;
     mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
   }
@@ -393,7 +397,7 @@ __global__ void batchnorm_backward_kernel(
   auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
   auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
   auto factor_1_c = inv_std[blockIdx.x];
-  auto factor_2_c = static_cast<accscalar_t>(weight[blockIdx.x]) * factor_1_c;
+  auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x])) * factor_1_c;
   factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
   for (int batch_offset = blockIdx.y*blockDim.y+threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
@@ -603,8 +607,8 @@ __global__ void batchnorm_forward_c_last_kernel(
   auto m_c = mean[c_offset];
   auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
-  auto w_c = static_cast<accscalar_t>(weight[c_offset]);
-  auto s_c = static_cast<accscalar_t>(shift[c_offset]);
+  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
+  auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
   int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
   int address_base = m_offset * stride + c_offset;
@@ -749,16 +753,24 @@ __global__ void reduce_bn_c_last_kernel(
       merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
       if (threadIdx.y == 0 && c_offset < stride) {
+        if (grad_bias != NULL) {
         grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
+        }
+        if (grad_weight != NULL) {
         grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
+        }
         mean_dy[c_offset] = sum_dy_th / reduction_size;
         mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
       }
     }
   } else {
     if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
+      if (grad_bias != NULL) {
       grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
+      }
+      if (grad_weight != NULL) {
      grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
+      }
      mean_dy[c_offset] = sum_dy_th / reduction_size;
      mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
     }
@@ -793,7 +805,7 @@ __global__ void batchnorm_backward_c_last_kernel(
   auto m_c = mean[c_offset];
   auto m_dy_c = mean_dy[c_offset];
   auto factor_1_c = inv_std[c_offset];
-  auto factor_2_c = static_cast<accscalar_t>(weight[c_offset]) * factor_1_c;
+  auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
   factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[c_offset];
   int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
@@ -850,8 +862,8 @@ at::Tensor batchnorm_forward_CUDA(
     const at::Tensor input,
     const at::Tensor mean,
     const at::Tensor inv_std,
-    const at::Tensor weight,
-    const at::Tensor shift) {
+    const at::optional<at::Tensor> weight,
+    const at::optional<at::Tensor> shift) {
   const auto batch_size = input.size(0);
   const auto feature_size = input.size(1);
   at::Tensor out = at::empty_like(input);
@@ -866,29 +878,34 @@ at::Tensor batchnorm_forward_CUDA(
   const dim3 grid(feature_size, batch_group_size, grid_z);
   auto stream = at::cuda::getCurrentCUDAStream();
-  if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
+  if (input.type().scalarType() == at::ScalarType::Half
+      && weight.has_value() &&
+      weight.value().type().scalarType() == at::ScalarType::Float) {
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_forward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
           input.data<scalar_t>(),
           mean.data<accscalar_t>(),
          inv_std.data<accscalar_t>(),
-          weight.data<accscalar_t>(),
-          shift.data<accscalar_t>(),
+          weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
+          shift.has_value() ? shift.value().data<accscalar_t>() : NULL,
          out.data<scalar_t>(),
          space_size,
          batch_size);
     }));
   } else {
-    AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
+    if (weight.has_value()) {
+      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
+               "input.type().scalarType() is not supported with weight.type().scalarType()");
+    }
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
          input.data<scalar_t>(),
          mean.data<accscalar_t>(),
          inv_std.data<accscalar_t>(),
-          weight.data<scalar_t>(),
-          shift.data<scalar_t>(),
+          weight.has_value() ? weight.value().data<scalar_t>() : NULL,
+          shift.has_value() ? shift.value().data<scalar_t>() : NULL,
          out.data<scalar_t>(),
          space_size,
          batch_size);
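
Note: every host function in this file repeats the same mixed-precision test. A hedged sketch of that predicate as a standalone helper (hypothetical, not in the commit; written against the old tensor.type().scalarType() API this file uses):

    // True only when fp16 activations are paired with fp32 affine parameters;
    // this selects the <scalar_t, accscalar_t, accscalar_t> kernel instantiation.
    static bool mixed_fp16_fp32(const at::Tensor& input,
                                const at::optional<at::Tensor>& weight) {
      return input.type().scalarType() == at::ScalarType::Half
          && weight.has_value()
          && weight.value().type().scalarType() == at::ScalarType::Float;
    }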
@@ -902,7 +919,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
     const at::Tensor input,
     const at::Tensor mean,
     const at::Tensor inv_std,
-    const at::Tensor weight)
+    const at::optional<at::Tensor> weight)
 {
   const auto batch_size = input.size(0);
   const auto feature_size = input.size(1);
@@ -911,8 +928,16 @@ std::vector<at::Tensor> reduce_bn_CUDA(
   at::Tensor mean_dy = at::empty({feature_size}, mean.options());
   at::Tensor mean_dy_xmu = at::empty({feature_size}, mean.options());
-  at::Tensor grad_weight = at::empty({feature_size}, weight.options());
-  at::Tensor grad_bias = at::empty({feature_size}, weight.options());
+  at::Tensor grad_weight;
+  at::Tensor grad_bias;
+  if (weight.has_value()) {
+    grad_weight = at::empty({feature_size}, weight.value().options());
+    grad_bias = at::empty({feature_size}, weight.value().options());
+  } else {
+    grad_weight = at::empty({0}, mean.options());
+    grad_bias = at::empty({0}, mean.options());
+  }
 
   auto space_size = get_tensor_spatial_size(input);
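
Note: since a std::vector<at::Tensor> cannot carry an absent tensor, the affine = False path returns zero-element placeholders for grad_weight and grad_bias. A caller-side sketch (hypothetical) of detecting them:

    // A real per-channel gradient has numel() == feature_size, so a
    // zero-element tensor marks "no affine parameters were used".
    if (grad_weight.numel() == 0) {
      // skip the optimizer update for weight and bias
    }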
@@ -922,7 +947,9 @@ std::vector<at::Tensor> reduce_bn_CUDA(
   const dim3 grid(feature_size);
   auto stream = at::cuda::getCurrentCUDAStream();
-  if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
+  if (input.type().scalarType() == at::ScalarType::Half
+      && weight.has_value() &&
+      weight.value().type().scalarType() == at::ScalarType::Float) {
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       reduce_bn_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
@@ -932,14 +959,17 @@ std::vector<at::Tensor> reduce_bn_CUDA(
          inv_std.data<accscalar_t>(),
          mean_dy.data<accscalar_t>(),
          mean_dy_xmu.data<accscalar_t>(),
-          grad_weight.data<accscalar_t>(),
-          grad_bias.data<accscalar_t>(),
+          weight.has_value() ? grad_weight.data<accscalar_t>() : NULL,
+          weight.has_value() ? grad_bias.data<accscalar_t>() : NULL,
          batch_size,
          feature_size,
          space_size);
     }));
   } else {
-    AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
+    if (weight.has_value()) {
+      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
+               "input.type().scalarType() is not supported with weight.type().scalarType()");
+    }
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       reduce_bn_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
@@ -949,8 +979,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
          inv_std.data<accscalar_t>(),
          mean_dy.data<accscalar_t>(),
          mean_dy_xmu.data<accscalar_t>(),
-          grad_weight.data<scalar_t>(),
-          grad_bias.data<scalar_t>(),
+          weight.has_value() ? grad_weight.data<scalar_t>() : NULL,
+          weight.has_value() ? grad_bias.data<scalar_t>() : NULL,
          batch_size,
          feature_size,
          space_size);
@@ -965,7 +995,7 @@ at::Tensor batchnorm_backward_CUDA(
     const at::Tensor input,
     const at::Tensor mean,
     const at::Tensor inv_std,
-    const at::Tensor weight,
+    const at::optional<at::Tensor> weight,
     const at::Tensor mean_dy,
     const at::Tensor mean_dy_xmu) {
   const auto batch_size = input.size(0);
@@ -984,7 +1014,9 @@ at::Tensor batchnorm_backward_CUDA(
   auto stream = at::cuda::getCurrentCUDAStream();
-  if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
+  if (input.type().scalarType() == at::ScalarType::Half
+      && weight.has_value() &&
+      weight.value().type().scalarType() == at::ScalarType::Float) {
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_backward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
@@ -992,7 +1024,7 @@ at::Tensor batchnorm_backward_CUDA(
          input.data<scalar_t>(),
          mean.data<accscalar_t>(),
          inv_std.data<accscalar_t>(),
-          weight.data<accscalar_t>(),
+          weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
          mean_dy.data<accscalar_t>(),
          mean_dy_xmu.data<accscalar_t>(),
          grad_input.data<scalar_t>(),
@@ -1000,7 +1032,10 @@ at::Tensor batchnorm_backward_CUDA(
          batch_size);
     }));
   } else {
-    AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
+    if (weight.has_value()) {
+      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
+               "input.type().scalarType() is not supported with weight.type().scalarType()");
+    }
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_backward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
@@ -1008,7 +1043,7 @@ at::Tensor batchnorm_backward_CUDA(
          input.data<scalar_t>(),
          mean.data<accscalar_t>(),
          inv_std.data<accscalar_t>(),
-          weight.data<scalar_t>(),
+          weight.has_value() ? weight.value().data<scalar_t>() : NULL,
          mean_dy.data<accscalar_t>(),
          mean_dy_xmu.data<accscalar_t>(),
          grad_input.data<scalar_t>(),
@@ -1099,8 +1134,8 @@ at::Tensor batchnorm_forward_c_last_CUDA(
     const at::Tensor input,
     const at::Tensor mean,
     const at::Tensor inv_std,
-    const at::Tensor weight,
-    const at::Tensor shift) {
+    const at::optional<at::Tensor> weight,
+    const at::optional<at::Tensor> shift) {
   const auto stride = input.size(input.ndimension()-1);
   const auto reduction_size = input.numel() / stride;
@@ -1113,7 +1148,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(
   auto stream = at::cuda::getCurrentCUDAStream();
   if (input.type().scalarType() == at::ScalarType::Half
-      && weight.type().scalarType() == at::ScalarType::Float) {
+      && weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
@@ -1121,15 +1156,17 @@ at::Tensor batchnorm_forward_c_last_CUDA(
          input.data<scalar_t>(),
          mean.data<accscalar_t>(),
          inv_std.data<accscalar_t>(),
-          weight.data<accscalar_t>(),
-          shift.data<accscalar_t>(),
+          weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
+          shift.has_value() ? shift.value().data<accscalar_t>() : NULL,
          out.data<scalar_t>(),
          reduction_size,
          stride);
     }));
   } else {
-    AT_CHECK(input.type().scalarType() == weight.type().scalarType(),
-             "input.type().scalarType() is not supported with weight.type().scalarType()");
+    if (weight.has_value()) {
+      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
+               "input.type().scalarType() is not supported with weight.type().scalarType()");
+    }
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
@@ -1137,8 +1174,8 @@ at::Tensor batchnorm_forward_c_last_CUDA(
          input.data<scalar_t>(),
          mean.data<accscalar_t>(),
          inv_std.data<accscalar_t>(),
-          weight.data<scalar_t>(),
-          shift.data<scalar_t>(),
+          weight.has_value() ? weight.value().data<scalar_t>() : NULL,
+          shift.has_value() ? shift.value().data<scalar_t>() : NULL,
          out.data<scalar_t>(),
          reduction_size,
          stride);
@@ -1152,14 +1189,23 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
     const at::Tensor input,
     const at::Tensor mean,
     const at::Tensor inv_std,
-    const at::Tensor weight) {
+    const at::optional<at::Tensor> weight) {
   const auto stride = input.size(input.ndimension()-1);
   const auto reduction_size = input.numel() / stride;
 
   at::Tensor mean_dy = at::empty({stride}, mean.options());
   at::Tensor mean_dy_xmu = at::empty({stride}, mean.options());
-  at::Tensor grad_weight = at::empty({stride}, weight.options());
-  at::Tensor grad_bias = at::empty({stride}, weight.options());
+  at::Tensor grad_weight;
+  at::Tensor grad_bias;
+  if (weight.has_value()) {
+    grad_weight = at::empty({stride}, weight.value().options());
+    grad_bias = at::empty({stride}, weight.value().options());
+  } else {
+    // because I cannot return an uninitialized at::Tensor
+    grad_weight = at::empty({0}, mean.options());
+    grad_bias = at::empty({0}, mean.options());
+  }
 
   dim3 block;
   dim3 grid;
@@ -1173,7 +1219,9 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
   }
   auto stream = at::cuda::getCurrentCUDAStream();
-  if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
+  if (input.type().scalarType() == at::ScalarType::Half
+      && weight.has_value()
+      && weight.value().type().scalarType() == at::ScalarType::Float) {
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
@@ -1186,15 +1234,18 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
          inv_std.data<accscalar_t>(),
          mean_dy.data<accscalar_t>(),
          mean_dy_xmu.data<accscalar_t>(),
-          grad_weight.data<accscalar_t>(),
-          grad_bias.data<accscalar_t>(),
+          weight.has_value() ? grad_weight.data<accscalar_t>() : NULL,
+          weight.has_value() ? grad_bias.data<accscalar_t>() : NULL,
          staging_data_ptr,
          semaphores_ptr,
          reduction_size,
          stride);
     }));
   } else {
-    AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
+    if (weight.has_value()) {
+      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
+               "input.type().scalarType() is not supported with weight.type().scalarType()");
+    }
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
@@ -1207,8 +1258,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
          inv_std.data<accscalar_t>(),
          mean_dy.data<accscalar_t>(),
          mean_dy_xmu.data<accscalar_t>(),
-          grad_weight.data<scalar_t>(),
-          grad_bias.data<scalar_t>(),
+          weight.has_value() ? grad_weight.data<scalar_t>() : NULL,
+          weight.has_value() ? grad_bias.data<scalar_t>() : NULL,
          staging_data_ptr,
          semaphores_ptr,
          reduction_size,
@@ -1224,7 +1275,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
     const at::Tensor input,
     const at::Tensor mean,
     const at::Tensor inv_std,
-    const at::Tensor weight,
+    const at::optional<at::Tensor> weight,
     const at::Tensor mean_dy,
     const at::Tensor mean_dy_xmu) {
   const auto stride = input.size(input.ndimension()-1);
@@ -1239,7 +1290,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
   auto stream = at::cuda::getCurrentCUDAStream();
   if (input.type().scalarType() == at::ScalarType::Half
-      && weight.type().scalarType() == at::ScalarType::Float) {
+      && weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
@@ -1248,7 +1299,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
          input.data<scalar_t>(),
          mean.data<accscalar_t>(),
          inv_std.data<accscalar_t>(),
-          weight.data<accscalar_t>(),
+          weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
          mean_dy.data<accscalar_t>(),
          mean_dy_xmu.data<accscalar_t>(),
          grad_input.data<scalar_t>(),
@@ -1256,8 +1307,10 @@ at::Tensor batchnorm_backward_c_last_CUDA(
          stride);
     }));
   } else {
-    AT_CHECK(input.type().scalarType() == weight.type().scalarType(),
-             "input.type().scalarType() is not supported with weight.type().scalarType()");
+    if (weight.has_value()) {
+      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
+               "input.type().scalarType() is not supported with weight.type().scalarType()");
+    }
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
@@ -1266,7 +1319,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
          input.data<scalar_t>(),
          mean.data<accscalar_t>(),
          inv_std.data<accscalar_t>(),
-          weight.data<scalar_t>(),
+          weight.has_value() ? weight.value().data<scalar_t>() : NULL,
          mean_dy.data<accscalar_t>(),
          mean_dy_xmu.data<accscalar_t>(),
          grad_input.data<scalar_t>(),
...