"git@developer.sourcefind.cn:change/sglang.git" did not exist on "a9ca297d769b52251a8fca7073c1a41700825fa4"
Commit 37a213d8 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Minor stylistic fixes

parent b42fe7c3
......@@ -91,6 +91,7 @@ int main(int argc, char* argv[])
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
......
......@@ -81,6 +81,7 @@ int main(int argc, char* argv[])
switch(conv_param.num_dim_spatial_)
{
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
......
......@@ -93,6 +93,7 @@ int main(int argc, char* argv[])
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
......
......@@ -92,6 +92,7 @@ int main(int argc, char* argv[])
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
......
......@@ -98,6 +98,7 @@ int main(int argc, char* argv[])
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
......
......@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial>
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
const ck::utils::conv::ConvParam& conv_param)
{
// Dl and WMMA ops doesn't support split_k > 1
// Dl and WMMA ops don't support split_k > 1
constexpr ck::index_t split_k = 1;
const auto in_g_n_c_wis_desc =
......
......@@ -61,9 +61,10 @@ template <index_t NDimSpatial,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1,
typename ck::enable_if<NDimSpatial == 3, bool>::type = false>
struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
: public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
......@@ -332,7 +333,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// Padd
// Pad
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
......@@ -362,7 +363,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
}
} // function end
}
template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
......@@ -716,14 +717,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
return false;
}
if constexpr(NDimSpatial == 3)
{
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))
{
return false;
}
}
else
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))
{
return false;
}
......@@ -731,7 +725,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 pad = 0 conv
// check if it's a 1x1 convolution with stride=1 and no padding
for(int i = 0; i < NDimSpatial; i++)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
......
......@@ -36,7 +36,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
{
// Odd K or C values are supported only by DL and WMMA
// kernels (only applies to fp16)
// DL kernel currently supports only `split_k=1`
// DL and WMMA kernels currently support only `split_k=1`
if constexpr(std::is_same_v<InDataType, ck::half_t>)
{
if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0))
......
......@@ -137,13 +137,13 @@ TYPED_TEST_SUITE(TestGroupedConvndBwdWeightDefault3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndBwdWeightFilter1x13d, SpecializationCheck)
{
// Check filter 3,3 instead of 1,1
// Check filter 3x3x3 instead of 1x1x1
this->conv_param = {
3, 2, 4, 192, 192, {3, 3, 3}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
bool is_supported = this->template Run<1>();
EXPECT_FALSE(is_supported);
// Check strides 2,2 instead of 1,1
// Check strides 2x2x2 instead of 1x1x1
this->conv_param = {
3, 2, 4, 192, 192, {1, 1, 1}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
is_supported = this->template Run<1>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment