Commit 37a213d8 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Minor stylistic fixes

parent b42fe7c3
...@@ -91,6 +91,7 @@ int main(int argc, char* argv[]) ...@@ -91,6 +91,7 @@ int main(int argc, char* argv[])
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
} }
return 1; return 1;
......
...@@ -81,6 +81,7 @@ int main(int argc, char* argv[]) ...@@ -81,6 +81,7 @@ int main(int argc, char* argv[])
switch(conv_param.num_dim_spatial_) switch(conv_param.num_dim_spatial_)
{ {
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
} }
return 1; return 1;
......
...@@ -93,6 +93,7 @@ int main(int argc, char* argv[]) ...@@ -93,6 +93,7 @@ int main(int argc, char* argv[])
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
} }
return 1; return 1;
......
...@@ -92,6 +92,7 @@ int main(int argc, char* argv[]) ...@@ -92,6 +92,7 @@ int main(int argc, char* argv[])
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
} }
return 1; return 1;
......
...@@ -98,6 +98,7 @@ int main(int argc, char* argv[]) ...@@ -98,6 +98,7 @@ int main(int argc, char* argv[])
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
} }
return 1; return 1;
......
...@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial> ...@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial>
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
const ck::utils::conv::ConvParam& conv_param) const ck::utils::conv::ConvParam& conv_param)
{ {
// Dl and WMMA ops doesn't support split_k > 1 // Dl and WMMA ops don't support split_k > 1
constexpr ck::index_t split_k = 1; constexpr ck::index_t split_k = 1;
const auto in_g_n_c_wis_desc = const auto in_g_n_c_wis_desc =
......
...@@ -63,7 +63,8 @@ template <index_t NDimSpatial, ...@@ -63,7 +63,8 @@ template <index_t NDimSpatial,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t NumGemmKPrefetchStage = 1, index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(), LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1> ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1,
typename ck::enable_if<NDimSpatial == 3, bool>::type = false>
struct DeviceGroupedConvBwdWeight_Wmma_CShuffle struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
: public DeviceGroupedConvBwdWeight<NDimSpatial, : public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout, InLayout,
...@@ -332,7 +333,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle ...@@ -332,7 +333,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// Padd // Pad
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
...@@ -362,7 +363,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle ...@@ -362,7 +363,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
wei_gemmm_gemmn_pad_grid_desc); wei_gemmm_gemmn_pad_grid_desc);
} }
} // function end }
template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc() static auto GetABCGridDesc()
...@@ -716,22 +717,15 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle ...@@ -716,22 +717,15 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
return false; return false;
} }
if constexpr(NDimSpatial == 3)
{
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC)) if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))
{ {
return false; return false;
} }
}
else
{
return false;
}
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{ {
// check if it's 1x1, stride=1 pad = 0 conv // check if it's a 1x1 convolution with stride=1 and no padding
for(int i = 0; i < NDimSpatial; i++) for(int i = 0; i < NDimSpatial; i++)
{ {
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
......
...@@ -36,7 +36,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -36,7 +36,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
{ {
// Odd K or C values are supported only by DL and WMMA // Odd K or C values are supported only by DL and WMMA
// kernels (only applies to fp16) // kernels (only applies to fp16)
// DL kernel currently supports only `split_k=1` // DL and WMMA kernels currently support only `split_k=1`
if constexpr(std::is_same_v<InDataType, ck::half_t>) if constexpr(std::is_same_v<InDataType, ck::half_t>)
{ {
if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0)) if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0))
......
...@@ -137,13 +137,13 @@ TYPED_TEST_SUITE(TestGroupedConvndBwdWeightDefault3d, KernelTypes3d); ...@@ -137,13 +137,13 @@ TYPED_TEST_SUITE(TestGroupedConvndBwdWeightDefault3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndBwdWeightFilter1x13d, SpecializationCheck) TYPED_TEST(TestGroupedConvndBwdWeightFilter1x13d, SpecializationCheck)
{ {
// Check filter 3,3 instead of 1,1 // Check filter 3x3x3 instead of 1x1x1
this->conv_param = { this->conv_param = {
3, 2, 4, 192, 192, {3, 3, 3}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; 3, 2, 4, 192, 192, {3, 3, 3}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
bool is_supported = this->template Run<1>(); bool is_supported = this->template Run<1>();
EXPECT_FALSE(is_supported); EXPECT_FALSE(is_supported);
// Check strides 2,2 instead of 1,1 // Check strides 2x2x2 instead of 1x1x1
this->conv_param = { this->conv_param = {
3, 2, 4, 192, 192, {1, 1, 1}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; 3, 2, 4, 192, 192, {1, 1, 1}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
is_supported = this->template Run<1>(); is_supported = this->template Run<1>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment