Commit 5d89c6f3 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Add comments in device_column_to_image_impl

parent 262d6757
......@@ -72,9 +72,11 @@ struct DeviceColumnToImageImpl
const index_t x_eff = (filter_len - 1) * filter_dilation + 1;
const index_t next_filter_padded =
math::integer_divide_ceil(x_eff, filter_stride) * filter_stride;
// If filter_stride >= x_eff then each filter is independent
const index_t independent_filter_stride =
filter_stride >= x_eff ? filter_stride : next_filter_padded;
const index_t w_eff = input_spatial_len - image_offset + left_pad + right_pad - x_eff;
// There are no independent filters
if(w_eff < 0)
return 0;
const index_t independent_kernels_num = w_eff / independent_filter_stride + 1;
......@@ -99,7 +101,8 @@ struct DeviceColumnToImageImpl
filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t NStride = DoHoWo * gemm_m_k_strides[I0] * gemm_m_k_strides[I1];
// Strides to filters for each dimensions
// Calculate the appropriate stride for each set of independent filters
// in each dimension
const index_t WStride =
math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) * gemm_m_k_strides[I0];
const index_t HStride = math::integer_divide_ceil(effs[YIdx], conv_filter_strides[YIdx]) *
......@@ -107,6 +110,8 @@ struct DeviceColumnToImageImpl
const index_t DStride = math::integer_divide_ceil(effs[ZIdx], conv_filter_strides[ZIdx]) *
output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx] *
gemm_m_k_strides[I0];
// Create descriptor for independent filters in each dimension and
// then merge them into column form
if constexpr(NDimSpatial == 1)
{
const auto desc_gemm_form =
......@@ -218,6 +223,7 @@ struct DeviceColumnToImageImpl
: independent_filter_stride;
}
// Calculate image form descriptor for the modified convolution problem
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>(
a_g_n_c_wis_lengths,
......@@ -294,7 +300,7 @@ struct DeviceColumnToImageImpl
? I1
: (filter_spatial_lengths[ZIdx] - 1) * conv_filter_dilations[ZIdx] + 1;
// Iterate over sets of independent kernels
// Iterate over sets of independent filters
for(int z_img_offset = 0; z_img_offset < z_eff;
z_img_offset += conv_filter_strides[ZIdx])
{
......@@ -307,6 +313,8 @@ struct DeviceColumnToImageImpl
std::array<index_t, NDimSpatial> image_offsets;
std::array<index_t, NDimSpatial> effs;
// Calculate the starting offset for a given set of
// independent filters
if constexpr(NDimSpatial == 1)
{
image_offsets = {x_img_offset};
......@@ -376,12 +384,14 @@ struct DeviceColumnToImageImpl
const index_t z_offset_with_pad =
math::max(0, z_img_offset - input_left_pads[ZIdx]);
// Memory offsets to next set of independent kernels
// Memory offsets to next set of independent filters,
// move to independent filters in each dimension
const index_t in_offset =
x_idx * gemm_m_k_strides[0] +
y_idx * gemm_m_k_strides[0] * output_spatial_lengths[XIdx] +
z_idx * gemm_m_k_strides[0] * output_spatial_lengths[YIdx] *
output_spatial_lengths[XIdx];
// Move to independent filters in appropriate dimensions
const index_t out_offset =
x_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + XIdx] +
y_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + YIdx] +
......@@ -443,6 +453,7 @@ struct DeviceColumnToImageImpl
Block2ETileMap,
GridwiseTensorRearrangeKernel>;
// Execute each set of independent filters
for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++)
{
const auto block_2_tile_map =
......@@ -479,10 +490,6 @@ struct DeviceColumnToImageImpl
{
return false;
}
if constexpr(!(NDimSpatial >= 1 && NDimSpatial <= 3))
{
return false;
}
const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1];
const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1];
......
......@@ -256,10 +256,6 @@ struct DeviceImageToColumnImpl
{
return false;
}
if constexpr(!(NDimSpatial >= 1 && NDimSpatial <= 3))
{
return false;
}
const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1];
const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment