Unverified Commit d81ed26d authored by mcarilli, committed by GitHub

Merge pull request #143 from NVIDIA/sbn_no_affine

allowing syncBN to run with affine = False
parents 48299b0d 223a47e9
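
The change threads at::optional<at::Tensor> through the forward/backward entry points so weight and shift may be absent when affine = False: the host wrappers turn a missing optional into a NULL pointer, and the kernels fall back to a scale of 1 and a shift of 0 (and skip writing grad_weight/grad_bias). Below is a minimal, self-contained sketch of that pattern using std::optional and plain floats; normalize and normalize_host are illustrative names only, not functions from this repository.

#include <cstdio>
#include <optional>
#include <vector>

// Stand-in for the per-element kernel body: y = (x - mean) * inv_std * w + b.
// A null weight/shift pointer means "affine = False" and falls back to w = 1, b = 0.
void normalize(const std::vector<float>& x, float mean, float inv_std,
               const float* weight, const float* shift, std::vector<float>& out) {
  float w_c = weight == nullptr ? 1.0f : *weight;
  float s_c = shift == nullptr ? 0.0f : *shift;
  for (size_t i = 0; i < x.size(); ++i)
    out[i] = (x[i] - mean) * inv_std * w_c + s_c;
}

// Stand-in for the host wrapper: an empty optional becomes a null pointer.
void normalize_host(const std::vector<float>& x, float mean, float inv_std,
                    std::optional<float> weight, std::optional<float> shift,
                    std::vector<float>& out) {
  normalize(x, mean, inv_std,
            weight.has_value() ? &weight.value() : nullptr,
            shift.has_value() ? &shift.value() : nullptr,
            out);
}

int main() {
  std::vector<float> x{0.f, 1.f, 2.f}, out(3);
  normalize_host(x, /*mean=*/1.f, /*inv_std=*/2.f, std::nullopt, std::nullopt, out);
  for (float v : out) std::printf("%g ", v);  // prints: -2 0 2
  std::printf("\n");
  return 0;
}
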
@@ -21,8 +21,8 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift);
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
@@ -32,7 +32,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight);
const at::optional<at::Tensor> weight);
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
@@ -41,7 +41,7 @@ at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
@@ -57,8 +57,8 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift);
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
@@ -68,7 +68,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight);
const at::optional<at::Tensor> weight);
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
@@ -78,7 +78,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
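
The declarations above are the header-side changes; the hunks that follow are the matching CUDA implementation. For orientation, here is a hypothetical call site for the updated forward declaration, assuming a translation unit that includes this header and links against the implementation. forward_no_affine and forward_affine are illustrative helper names, not part of the PR; an empty optional (written {}) is what would be passed when affine = False, while a plain at::Tensor converts implicitly to at::optional<at::Tensor>.

#include <ATen/ATen.h>

// Declaration from the header above.
at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
                                  const at::Tensor mean,
                                  const at::Tensor inv_std,
                                  const at::optional<at::Tensor> weight,
                                  const at::optional<at::Tensor> shift);

// affine = False: no weight/shift; the kernel uses scale 1 and shift 0.
at::Tensor forward_no_affine(const at::Tensor& input,
                             const at::Tensor& mean,
                             const at::Tensor& inv_std) {
  return batchnorm_forward_CUDA(input, mean, inv_std, /*weight=*/{}, /*shift=*/{});
}

// affine = True: tensors convert implicitly to at::optional<at::Tensor>.
at::Tensor forward_affine(const at::Tensor& input,
                          const at::Tensor& mean,
                          const at::Tensor& inv_std,
                          const at::Tensor& weight,
                          const at::Tensor& bias) {
  return batchnorm_forward_CUDA(input, mean, inv_std, weight, bias);
}
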
@@ -305,8 +305,8 @@ __global__ void batchnorm_forward_kernel(
const int bs) {
auto m_c = mean[blockIdx.x];
auto inv_std_c = inv_std[blockIdx.x];
auto w_c = static_cast<accscalar_t>(weight[blockIdx.x]);
auto s_c = static_cast<accscalar_t>(shift[blockIdx.x]);
auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x]);
auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[blockIdx.x]);
for (int batch_offset = blockIdx.y*blockDim.y + threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
@@ -370,8 +370,12 @@ __global__ void reduce_bn_kernel(
sum_dy_xmu = reduce_block((accscalar_t*)s_mem, sum_dy_xmu);
if (thread_id == 0) {
if (grad_bias != NULL) {
grad_bias[blockIdx.x] = static_cast<layerscalar_t>(sum_dy);
}
if (grad_weight != NULL) {
grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
}
mean_dy[blockIdx.x] = sum_dy / total_item_num;
mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
}
@@ -393,7 +397,7 @@ __global__ void batchnorm_backward_kernel(
auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
auto factor_1_c = inv_std[blockIdx.x];
auto factor_2_c = static_cast<accscalar_t>(weight[blockIdx.x]) * factor_1_c;
auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x])) * factor_1_c;
factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
for (int batch_offset = blockIdx.y*blockDim.y+threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
@@ -603,8 +607,8 @@ __global__ void batchnorm_forward_c_last_kernel(
auto m_c = mean[c_offset];
auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
auto w_c = static_cast<accscalar_t>(weight[c_offset]);
auto s_c = static_cast<accscalar_t>(shift[c_offset]);
auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
int address_base = m_offset * stride + c_offset;
@@ -749,16 +753,24 @@ __global__ void reduce_bn_c_last_kernel(
merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
if (threadIdx.y == 0 && c_offset < stride) {
if (grad_bias != NULL) {
grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
}
if (grad_weight != NULL) {
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
}
mean_dy[c_offset] = sum_dy_th / reduction_size;
mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
}
}
} else {
if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
if (grad_bias != NULL) {
grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
}
if (grad_weight != NULL) {
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
}
mean_dy[c_offset] = sum_dy_th / reduction_size;
mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
}
@@ -793,7 +805,7 @@ __global__ void batchnorm_backward_c_last_kernel(
auto m_c = mean[c_offset];
auto m_dy_c = mean_dy[c_offset];
auto factor_1_c = inv_std[c_offset];
auto factor_2_c = static_cast<accscalar_t>(weight[c_offset]) * factor_1_c;
auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[c_offset];
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
@@ -850,8 +862,8 @@ at::Tensor batchnorm_forward_CUDA(
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift) {
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift) {
const auto batch_size = input.size(0);
const auto feature_size = input.size(1);
at::Tensor out = at::empty_like(input);
@@ -866,29 +878,34 @@ at::Tensor batchnorm_forward_CUDA(
const dim3 grid(feature_size, batch_group_size, grid_z);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
if (input.type().scalarType() == at::ScalarType::Half
&& weight.has_value() &&
weight.value().type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<accscalar_t>(),
shift.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
shift.has_value() ? shift.value().data<accscalar_t>() : NULL,
out.data<scalar_t>(),
space_size,
batch_size);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<scalar_t>(),
shift.data<scalar_t>(),
weight.has_value() ? weight.value().data<scalar_t>() : NULL,
shift.has_value() ? shift.value().data<scalar_t>() : NULL,
out.data<scalar_t>(),
space_size,
batch_size);
@@ -902,7 +919,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight)
const at::optional<at::Tensor> weight)
{
const auto batch_size = input.size(0);
const auto feature_size = input.size(1);
@@ -911,8 +928,16 @@ std::vector<at::Tensor> reduce_bn_CUDA(
at::Tensor mean_dy = at::empty({feature_size}, mean.options());
at::Tensor mean_dy_xmu = at::empty({feature_size}, mean.options());
at::Tensor grad_weight = at::empty({feature_size}, weight.options());
at::Tensor grad_bias = at::empty({feature_size}, weight.options());
at::Tensor grad_weight;
at::Tensor grad_bias;
if (weight.has_value()) {
grad_weight = at::empty({feature_size}, weight.value().options());
grad_bias = at::empty({feature_size}, weight.value().options());
} else {
grad_weight = at::empty({0}, mean.options());
grad_bias = at::empty({0}, mean.options());
}
auto space_size = get_tensor_spatial_size(input);
@@ -922,7 +947,9 @@ std::vector<at::Tensor> reduce_bn_CUDA(
const dim3 grid(feature_size);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
if (input.type().scalarType() == at::ScalarType::Half
&& weight.has_value() &&
weight.value().type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
reduce_bn_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
@@ -932,14 +959,17 @@ std::vector<at::Tensor> reduce_bn_CUDA(
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_weight.data<accscalar_t>(),
grad_bias.data<accscalar_t>(),
weight.has_value() ? grad_weight.data<accscalar_t>() : NULL,
weight.has_value() ? grad_bias.data<accscalar_t>() : NULL,
batch_size,
feature_size,
space_size);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
reduce_bn_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
@@ -949,8 +979,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_weight.data<scalar_t>(),
grad_bias.data<scalar_t>(),
weight.has_value() ? grad_weight.data<scalar_t>() : NULL,
weight.has_value() ? grad_bias.data<scalar_t>() : NULL,
batch_size,
feature_size,
space_size);
@@ -965,7 +995,7 @@ at::Tensor batchnorm_backward_CUDA(
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu) {
const auto batch_size = input.size(0);
@@ -984,7 +1014,9 @@ at::Tensor batchnorm_backward_CUDA(
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
if (input.type().scalarType() == at::ScalarType::Half
&& weight.has_value() &&
weight.value().type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
@@ -992,7 +1024,7 @@ at::Tensor batchnorm_backward_CUDA(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
@@ -1000,7 +1032,10 @@ at::Tensor batchnorm_backward_CUDA(
batch_size);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
@@ -1008,7 +1043,7 @@ at::Tensor batchnorm_backward_CUDA(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<scalar_t>(),
weight.has_value() ? weight.value().data<scalar_t>() : NULL,
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
@@ -1099,8 +1134,8 @@ at::Tensor batchnorm_forward_c_last_CUDA(
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift) {
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
@@ -1113,7 +1148,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half
&& weight.type().scalarType() == at::ScalarType::Float) {
&& weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
@@ -1121,15 +1156,17 @@ at::Tensor batchnorm_forward_c_last_CUDA(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<accscalar_t>(),
shift.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
shift.has_value() ? shift.value().data<accscalar_t>() : NULL,
out.data<scalar_t>(),
reduction_size,
stride);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(),
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
@@ -1137,8 +1174,8 @@ at::Tensor batchnorm_forward_c_last_CUDA(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<scalar_t>(),
shift.data<scalar_t>(),
weight.has_value() ? weight.value().data<scalar_t>() : NULL,
shift.has_value() ? shift.value().data<scalar_t>() : NULL,
out.data<scalar_t>(),
reduction_size,
stride);
@@ -1152,14 +1189,23 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight) {
const at::optional<at::Tensor> weight) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
at::Tensor mean_dy = at::empty({stride}, mean.options());
at::Tensor mean_dy_xmu = at::empty({stride}, mean.options());
at::Tensor grad_weight = at::empty({stride}, weight.options());
at::Tensor grad_bias = at::empty({stride}, weight.options());
at::Tensor grad_weight;
at::Tensor grad_bias;
if (weight.has_value()) {
grad_weight = at::empty({stride}, weight.value().options());
grad_bias = at::empty({stride}, weight.value().options());
} else {
// placeholders, since an uninitialized at::Tensor cannot be returned
grad_weight = at::empty({0}, mean.options());
grad_bias = at::empty({0}, mean.options());
}
dim3 block;
dim3 grid;
@@ -1173,7 +1219,9 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
}
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
if (input.type().scalarType() == at::ScalarType::Half
&& weight.has_value()
&& weight.value().type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
@@ -1186,15 +1234,18 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_weight.data<accscalar_t>(),
grad_bias.data<accscalar_t>(),
weight.has_value() ? grad_weight.data<accscalar_t>() : NULL,
weight.has_value() ? grad_bias.data<accscalar_t>() : NULL,
staging_data_ptr,
semaphores_ptr,
reduction_size,
stride);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
@@ -1207,8 +1258,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_weight.data<scalar_t>(),
grad_bias.data<scalar_t>(),
weight.has_value() ? grad_weight.data<scalar_t>() : NULL,
weight.has_value() ? grad_bias.data<scalar_t>() : NULL,
staging_data_ptr,
semaphores_ptr,
reduction_size,
@@ -1224,7 +1275,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu) {
const auto stride = input.size(input.ndimension()-1);
@@ -1239,7 +1290,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half
&& weight.type().scalarType() == at::ScalarType::Float) {
&& weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
@@ -1248,7 +1299,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
@@ -1256,8 +1307,8 @@ at::Tensor batchnorm_backward_c_last_CUDA(
stride);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(),
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
@@ -1266,7 +1319,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<scalar_t>(),
weight.has_value() ? weight.value().data<scalar_t>() : NULL,
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
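
As a closing usage note on the backward path changed above: when weight has no value, reduce_bn_CUDA still returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}, but the last two are the size-0 placeholders allocated earlier and the kernels never write to them. A hypothetical caller illustrating that contract is sketched below, assuming ATen/libtorch is available; backward_reduce_no_affine is an illustrative name, not part of the PR.

#include <ATen/ATen.h>
#include <utility>
#include <vector>

// Declaration from the header above.
std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
                                       const at::Tensor input,
                                       const at::Tensor mean,
                                       const at::Tensor inv_std,
                                       const at::optional<at::Tensor> weight);

std::pair<at::Tensor, at::Tensor> backward_reduce_no_affine(const at::Tensor& grad_output,
                                                            const at::Tensor& input,
                                                            const at::Tensor& mean,
                                                            const at::Tensor& inv_std) {
  // Returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}.
  auto outs = reduce_bn_CUDA(grad_output, input, mean, inv_std, /*weight=*/{});
  // With affine = False, outs[2] (grad_weight) and outs[3] (grad_bias) are the
  // {0}-sized placeholders allocated above and were never written by the kernel;
  // only mean_dy and mean_dy_xmu feed the subsequent batchnorm_backward_CUDA call.
  return {outs[0], outs[1]};
}
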