"...internal/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "2412adf42b8380748ac79476e273f5b337c3b977"
Unverified Commit c981f6d0 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Fix K padding calculation for grouped conv data (#876)

* Fix K padding calculation for grouped conv data

* Restore previous padd for 1x1 specialization
parent bd8024b8
...@@ -280,6 +280,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -280,6 +280,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
BK1, BK1,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock,
DoPadGemmM, DoPadGemmM,
DoPadGemmN>{}; DoPadGemmN>{};
......
...@@ -268,6 +268,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -268,6 +268,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
"KPerBlock must be divisible by AK1Value and BK1Value!");
const auto M = a_grid_desc_m_k.GetLength(I0); const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0); const auto N = b_grid_desc_n_k.GetLength(I0);
......
...@@ -164,6 +164,7 @@ template < ...@@ -164,6 +164,7 @@ template <
index_t BK1, index_t BK1,
index_t GemmMPerBlock, index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmKPerBlock,
bool DoPadGemmM, bool DoPadGemmM,
bool DoPadGemmN> bool DoPadGemmN>
struct TransformConvBwdDataToGemm_v1 struct TransformConvBwdDataToGemm_v1
...@@ -308,9 +309,6 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -308,9 +309,6 @@ struct TransformConvBwdDataToGemm_v1
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const index_t AK0 =
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, AK1);
if constexpr(NDimSpatial == 2) if constexpr(NDimSpatial == 2)
{ {
// A: output tensor // A: output tensor
...@@ -367,9 +365,11 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -367,9 +365,11 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmm_padded_grid_desc = const auto out_gemmk_gemmm_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor( ck::tensor_operation::device::PadTensorDescriptor(
out_gemmk_gemmmraw_grid_desc, out_gemmk_gemmmraw_grid_desc,
make_tuple(AK1, GemmMPerBlock), make_tuple(GemmKPerBlock, GemmMPerBlock),
Sequence<true, DoPadGemmM>{}); Sequence<true, DoPadGemmM>{});
const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc, out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
...@@ -460,9 +460,11 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -460,9 +460,11 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmm_padded_grid_desc = const auto out_gemmk_gemmm_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor( ck::tensor_operation::device::PadTensorDescriptor(
out_gemmk_gemmmraw_grid_desc, out_gemmk_gemmmraw_grid_desc,
make_tuple(AK1, GemmMPerBlock), make_tuple(GemmKPerBlock, GemmMPerBlock),
Sequence<true, DoPadGemmM>{}); Sequence<true, DoPadGemmM>{});
const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc, out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
...@@ -568,9 +570,6 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -568,9 +570,6 @@ struct TransformConvBwdDataToGemm_v1
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const index_t BK0 =
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, BK1);
// B weight tensor // B weight tensor
if constexpr(NDimSpatial == 2) if constexpr(NDimSpatial == 2)
{ {
...@@ -617,9 +616,11 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -617,9 +616,11 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_gemmn_padded_grid_desc = const auto wei_gemmk_gemmn_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor( ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmk_gemmnraw_grid_desc, wei_gemmk_gemmnraw_grid_desc,
make_tuple(BK1, GemmNPerBlock), make_tuple(GemmKPerBlock, GemmNPerBlock),
Sequence<true, DoPadGemmN>{}); Sequence<true, DoPadGemmN>{});
const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_padded_grid_desc, wei_gemmk_gemmn_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
...@@ -690,17 +691,19 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -690,17 +691,19 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}), make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto wei_gemmk_gemm_padded_grid_desc = const auto wei_gemmk_gemmn_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor( ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmk_gemmnraw_grid_desc, wei_gemmk_gemmnraw_grid_desc,
make_tuple(BK1, GemmNPerBlock), make_tuple(GemmKPerBlock, GemmNPerBlock),
Sequence<true, DoPadGemmN>{}); Sequence<true, DoPadGemmN>{});
const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemm_padded_grid_desc, wei_gemmk_gemmn_padded_grid_desc,
make_tuple( make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_unmerge_transform(make_tuple(BK0, BK1)), make_pass_through_transform(
make_pass_through_transform(wei_gemmk_gemm_padded_grid_desc.GetLength(I1))), wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment