Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
04d762dc
Commit
04d762dc
authored
Aug 31, 2023
by
Bartlomiej Kocot
Browse files
Fix K padding calculation for grouped conv data
parent
38ada109
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
16 deletions
+22
-16
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
...vice_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+1
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+2
-0
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
...operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
+19
-16
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
View file @
04d762dc
...
...
@@ -280,6 +280,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
BK1
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
DoPadGemmM
,
DoPadGemmN
>
{};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
04d762dc
...
...
@@ -263,6 +263,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
static_assert
(
KPerBlock
%
AK1Value
==
0
&&
KPerBlock
%
BK1Value
==
0
,
"KPerBlock must be divisible by AK1Value and BK1Value!"
);
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
...
...
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
View file @
04d762dc
...
...
@@ -164,6 +164,7 @@ template <
index_t
BK1
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
bool
DoPadGemmM
,
bool
DoPadGemmN
>
struct
TransformConvBwdDataToGemm_v1
...
...
@@ -308,9 +309,6 @@ struct TransformConvBwdDataToGemm_v1
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
const
index_t
AK0
=
math
::
integer_divide_ceil
(
ZDotSlice
*
YDotSlice
*
XDotSlice
*
K
,
AK1
);
if
constexpr
(
NDimSpatial
==
2
)
{
// A: output tensor
...
...
@@ -367,9 +365,11 @@ struct TransformConvBwdDataToGemm_v1
const
auto
out_gemmk_gemmm_padded_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
out_gemmk_gemmmraw_grid_desc
,
make_tuple
(
AK1
,
GemmMPerBlock
),
make_tuple
(
GemmKPerBlock
,
GemmMPerBlock
),
Sequence
<
true
,
DoPadGemmM
>
{});
const
index_t
AK0
=
out_gemmk_gemmm_padded_grid_desc
.
GetLength
(
I0
)
/
AK1
;
const
auto
out_gemmak0_gemmm_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmm_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
...
...
@@ -460,9 +460,11 @@ struct TransformConvBwdDataToGemm_v1
const
auto
out_gemmk_gemmm_padded_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
out_gemmk_gemmmraw_grid_desc
,
make_tuple
(
AK1
,
GemmMPerBlock
),
make_tuple
(
GemmKPerBlock
,
GemmMPerBlock
),
Sequence
<
true
,
DoPadGemmM
>
{});
const
index_t
AK0
=
out_gemmk_gemmm_padded_grid_desc
.
GetLength
(
I0
)
/
AK1
;
const
auto
out_gemmak0_gemmm_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmm_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
...
...
@@ -544,7 +546,7 @@ struct TransformConvBwdDataToGemm_v1
const
auto
wei_gemmbk0_gemmn_gemmbk1_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
,
make_tuple
(
BK0
,
GemmNPerBlock
,
BK1
),
make_tuple
(
GemmKPerBlock
,
GemmNPerBlock
,
BK1
),
Sequence
<
false
,
DoPadGemmN
,
false
>
{});
return
wei_gemmbk0_gemmn_gemmbk1_grid_desc
;
...
...
@@ -568,9 +570,6 @@ struct TransformConvBwdDataToGemm_v1
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
const
index_t
BK0
=
math
::
integer_divide_ceil
(
ZDotSlice
*
YDotSlice
*
XDotSlice
*
K
,
BK1
);
// B weight tensor
if
constexpr
(
NDimSpatial
==
2
)
{
...
...
@@ -617,9 +616,11 @@ struct TransformConvBwdDataToGemm_v1
const
auto
wei_gemmk_gemmn_padded_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
wei_gemmk_gemmnraw_grid_desc
,
make_tuple
(
BK1
,
GemmNPerBlock
),
make_tuple
(
GemmKPerBlock
,
GemmNPerBlock
),
Sequence
<
true
,
DoPadGemmN
>
{});
const
index_t
BK0
=
wei_gemmk_gemmn_padded_grid_desc
.
GetLength
(
I0
)
/
BK1
;
const
auto
wei_gemmbk0_gemmn_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemmn_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
...
...
@@ -690,17 +691,19 @@ struct TransformConvBwdDataToGemm_v1
make_tuple
(
Sequence
<
1
,
2
,
3
,
0
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemmk_gemm_padded_grid_desc
=
const
auto
wei_gemmk_gemm
n
_padded_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
wei_gemmk_gemmnraw_grid_desc
,
make_tuple
(
BK1
,
GemmNPerBlock
),
make_tuple
(
GemmKPerBlock
,
GemmNPerBlock
),
Sequence
<
true
,
DoPadGemmN
>
{});
const
index_t
BK0
=
wei_gemmk_gemmn_padded_grid_desc
.
GetLength
(
I0
)
/
BK1
;
const
auto
wei_gemmbk0_gemm_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemm_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
wei_gemmk_gemm_padded_grid_desc
.
GetLength
(
I1
))),
wei_gemmk_gemm
n
_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
wei_gemmk_gemm
n
_padded_grid_desc
.
GetLength
(
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment