Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7cc806d8
"symphony/git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "113fcda25ebab3689faeac99e78ee1db07a774f2"
Commit
7cc806d8
authored
Jul 19, 2022
by
Chao Liu
Browse files
add matrix padder
parent
0b997ce4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
372 additions
and
351 deletions
+372
-351
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
+2
-0
include/ck/tensor_operation/gpu/device/device_convnd_fwd_multiple_d_nwc_kxc_nwk_xdl_cshuffle.hpp
...device_convnd_fwd_multiple_d_nwc_kxc_nwk_xdl_cshuffle.hpp
+155
-148
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
+31
-203
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
+184
-0
No files found.
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
View file @
7cc806d8
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#include "convnd_fwd_common.hpp"
#include "convnd_fwd_common.hpp"
#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp"
#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp"
...
...
include/ck/tensor_operation/gpu/device/device_convnd_fwd_multiple_d_nwc_kxc_nwk_xdl_cshuffle.hpp
View file @
7cc806d8
...
@@ -181,51 +181,54 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -181,51 +181,54 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
GemmK1Number
=
K1Number
;
static
constexpr
auto
GemmK1Number
=
K1Number
;
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
g
emm
_n
,
ck
::
index_t
g
emm
_k
)
static
auto
GetWeightTensorDescriptor
(
index_t
G
emm
N
,
index_t
G
emm
K
)
{
{
const
ck
::
index_t
g
emm
_k
0
=
g
emm
_k
/
GemmK1Number
;
const
index_t
G
emm
K
0
=
G
emm
K
/
GemmK1Number
;
const
auto
wei_k_yx
e
_grid_desc
=
const
auto
wei_k_yx
c
_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
g
emm
_n
,
g
emm
_k
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
G
emm
N
,
G
emm
K
));
// wei_gemmk0_gemmn_gemmk1_grid_desc
// wei_gemmk0_gemmn_gemmk1_grid_desc
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
wei_k_yx
e
_grid_desc
,
wei_k_yx
c
_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
g
emm
_k
0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
G
emm
K
0
,
GemmK1Number
)),
make_pass_through_transform
(
g
emm
_n
)),
make_pass_through_transform
(
G
emm
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
static
auto
static
auto
GetOutputTensorDescriptor
(
index_t
GemmMRaw
,
index_t
GemmN
)
GetOutputTensorDescriptor
(
ck
::
index_t
gemm_m
,
ck
::
index_t
gemm_n
,
ck
::
index_t
gemm_m_pad
)
{
{
const
index_t
GemmM
=
math
::
integer_least_multiple
(
GemmMRaw
,
MPerBlock
);
const
index_t
GemmMPad
=
GemmM
-
GemmMRaw
;
const
auto
out_gemmmraw_gemmn_grid_desc
=
const
auto
out_gemmmraw_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
g
emm
_m
,
g
emm
_n
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
G
emm
M
,
G
emm
N
));
// out_gemmm_gemmn_grid_desc
// out_gemmm_gemmn_grid_desc
return
transform_tensor_descriptor
(
out_gemmmraw_gemmn_grid_desc
,
return
transform_tensor_descriptor
(
out_gemmmraw_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
g
emm
_m
,
g
emm
_m_p
ad
),
make_tuple
(
make_right_pad_transform
(
G
emm
M
,
G
emm
MP
ad
),
make_pass_through_transform
(
g
emm
_n
)),
make_pass_through_transform
(
G
emm
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
template
<
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
GetInputTensorDescriptor
(
ck
::
index_t
N
,
static
auto
GetInputTensorDescriptor
(
index_t
N
,
ck
::
index_t
C
,
index_t
C
,
ck
::
index_t
gemm_m
,
index_t
GemmMRaw
,
ck
::
index_t
gemm_k
,
index_t
GemmK
,
ck
::
index_t
gemm_m_pad
,
const
std
::
vector
<
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
index_t
>&
output_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
,
const
std
::
vector
<
index_t
>&
conv_filter_strides
,
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_strides
,
const
std
::
vector
<
index_t
>&
conv_filter_dilations
,
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_dilations
,
const
std
::
vector
<
index_t
>&
input_left_pads
,
const
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
const
std
::
vector
<
index_t
>&
input_right_pads
)
const
std
::
vector
<
ck
::
index_t
>&
input_right_pads
)
{
{
const
index_t
GemmM
=
math
::
integer_least_multiple
(
GemmMRaw
,
MPerBlock
);
const
index_t
GemmMPad
=
GemmM
-
GemmMRaw
;
const
ck
::
index_t
g
emm
_k0
=
g
emm
_k
/
GemmK1Number
;
const
index_t
G
emm
K0
=
G
emm
K
/
GemmK1Number
;
const
index_t
Wi
=
input_spatial_lengths
[
0
];
const
index_t
Wi
=
input_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
...
@@ -234,13 +237,13 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -234,13 +237,13 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
const
auto
in_gemmmraw_gemmk_grid_desc
=
const
auto
in_gemmmraw_gemmk_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
g
emm
_m
,
g
emm
_k
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
G
emm
MRaw
,
G
emm
K
));
// in_gemmk0_gemmm_gemmk1_grid_desc
// in_gemmk0_gemmm_gemmk1_grid_desc
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
in_gemmmraw_gemmk_grid_desc
,
in_gemmmraw_gemmk_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
g
emm
_k
0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
G
emm
K
0
,
GemmK1Number
)),
make_right_pad_transform
(
g
emm
_m
,
g
emm
_m_p
ad
)),
make_right_pad_transform
(
G
emm
MRaw
,
G
emm
MP
ad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
...
@@ -260,7 +263,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -260,7 +263,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_n_wo_e_grid_desc
,
in_n_wo_e_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
g
emm
_k
0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
G
emm
K
0
,
GemmK1Number
)),
make_merge_transform
(
make_tuple
(
N
,
Wo
))),
make_merge_transform
(
make_tuple
(
N
,
Wo
))),
make_tuple
(
Sequence
<
2
>
{},
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
2
>
{},
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -268,8 +271,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -268,8 +271,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
// in_gemmk0_gemmm_gemmk1_grid_desc
// in_gemmk0_gemmm_gemmk1_grid_desc
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
g
emm
_k
0
),
make_tuple
(
make_pass_through_transform
(
G
emm
K
0
),
make_right_pad_transform
(
g
emm
_m
,
g
emm
_m_p
ad
),
make_right_pad_transform
(
G
emm
M
,
G
emm
MP
ad
),
make_pass_through_transform
(
GemmK1Number
)),
make_pass_through_transform
(
GemmK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
...
@@ -310,39 +313,41 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -310,39 +313,41 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmk_gemmmraw_grid_desc
,
in_gemmk_gemmmraw_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
g
emm
_k
0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
G
emm
K
0
,
GemmK1Number
)),
make_pass_through_transform
(
g
emm
_m
)),
make_pass_through_transform
(
G
emm
M
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
// in_gemmk0_gemmm_gemmk1_grid_desc
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
g
emm
_k
0
),
make_tuple
(
make_pass_through_transform
(
G
emm
K
0
),
make_right_pad_transform
(
g
emm
_m
,
g
emm
_m_p
ad
),
make_right_pad_transform
(
G
emm
M
,
G
emm
MP
ad
),
make_pass_through_transform
(
GemmK1Number
)),
make_pass_through_transform
(
GemmK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
}
}
}
}
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
template
<
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
GetInputTensorDescriptor
(
ck
::
index_t
N
,
static
auto
GetInputTensorDescriptor
(
index_t
N
,
ck
::
index_t
C
,
index_t
C
,
ck
::
index_t
gemm_m
,
index_t
GemmMRaw
,
ck
::
index_t
gemm_k
,
index_t
GemmK
,
ck
::
index_t
gemm_m_pad
,
const
std
::
vector
<
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
index_t
>&
output_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
,
const
std
::
vector
<
index_t
>&
conv_filter_strides
,
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_strides
,
const
std
::
vector
<
index_t
>&
conv_filter_dilations
,
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_dilations
,
const
std
::
vector
<
index_t
>&
input_left_pads
,
const
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
const
std
::
vector
<
index_t
>&
input_right_pads
)
const
std
::
vector
<
ck
::
index_t
>&
input_right_pads
)
{
{
const
ck
::
index_t
gemm_k0
=
gemm_k
/
GemmK1Number
;
const
index_t
GemmM
=
math
::
integer_least_multiple
(
GemmMRaw
,
MPerBlock
);
const
index_t
Hi
=
input_spatial_lengths
[
0
];
const
index_t
GemmMPad
=
GemmM
-
GemmMRaw
;
const
index_t
Wi
=
input_spatial_lengths
[
1
];
const
index_t
GemmK0
=
GemmK
/
GemmK1Number
;
const
index_t
Hi
=
input_spatial_lengths
[
0
];
const
index_t
Wi
=
input_spatial_lengths
[
1
];
const
index_t
Ho
=
output_spatial_lengths
[
0
];
const
index_t
Ho
=
output_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
1
];
const
index_t
Wo
=
output_spatial_lengths
[
1
];
...
@@ -354,13 +359,13 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -354,13 +359,13 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
const
auto
in_gemmmraw_gemmk_grid_desc
=
const
auto
in_gemmmraw_gemmk_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
g
emm
_m
,
g
emm
_k
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
G
emm
M
,
G
emm
K
));
// in_gemmk0_gemmm_gemmk1_grid_desc
// in_gemmk0_gemmm_gemmk1_grid_desc
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
in_gemmmraw_gemmk_grid_desc
,
in_gemmmraw_gemmk_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
g
emm
_k
0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
G
emm
K
0
,
GemmK1Number
)),
make_right_pad_transform
(
g
emm
_m
,
g
emm
_m_p
ad
)),
make_right_pad_transform
(
G
emm
M
,
G
emm
MP
ad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
...
@@ -381,7 +386,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -381,7 +386,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_n_ho_wo_e_grid_desc
,
in_n_ho_wo_e_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
g
emm
_k
0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
G
emm
K
0
,
GemmK1Number
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -389,8 +394,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -389,8 +394,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
// in_gemmk0_gemmm_gemmk1_grid_desc
// in_gemmk0_gemmm_gemmk1_grid_desc
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
g
emm
_k
0
),
make_tuple
(
make_pass_through_transform
(
G
emm
K
0
),
make_right_pad_transform
(
g
emm
_m
,
g
emm
_m_p
ad
),
make_right_pad_transform
(
G
emm
M
,
G
emm
MP
ad
),
make_pass_through_transform
(
GemmK1Number
)),
make_pass_through_transform
(
GemmK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
...
@@ -440,8 +445,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -440,8 +445,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmk_gemmmraw_grid_desc
,
in_gemmk_gemmmraw_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
g
emm
_k
0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
G
emm
K
0
,
GemmK1Number
)),
make_pass_through_transform
(
g
emm
_m
)),
make_pass_through_transform
(
G
emm
M
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -449,32 +454,34 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -449,32 +454,34 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
g
emm
_k
0
),
make_tuple
(
make_pass_through_transform
(
G
emm
K
0
),
make_right_pad_transform
(
g
emm
_m
,
g
emm
_m_p
ad
),
make_right_pad_transform
(
G
emm
M
,
G
emm
MP
ad
),
make_pass_through_transform
(
GemmK1Number
)),
make_pass_through_transform
(
GemmK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
}
}
}
}
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
template
<
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
GetInputTensorDescriptor
(
ck
::
index_t
N
,
static
auto
GetInputTensorDescriptor
(
index_t
N
,
ck
::
index_t
C
,
index_t
C
,
ck
::
index_t
gemm_m
,
index_t
GemmMRaw
,
ck
::
index_t
gemm_k
,
index_t
GemmK
,
ck
::
index_t
gemm_m_pad
,
const
std
::
vector
<
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
index_t
>&
output_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
,
const
std
::
vector
<
index_t
>&
conv_filter_strides
,
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_strides
,
const
std
::
vector
<
index_t
>&
conv_filter_dilations
,
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_dilations
,
const
std
::
vector
<
index_t
>&
input_left_pads
,
const
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
const
std
::
vector
<
index_t
>&
input_right_pads
)
const
std
::
vector
<
ck
::
index_t
>&
input_right_pads
)
{
{
const
ck
::
index_t
gemm_k0
=
gemm_k
/
GemmK1Number
;
const
index_t
GemmM
=
math
::
integer_least_multiple
(
GemmMRaw
,
MPerBlock
);
const
index_t
Di
=
input_spatial_lengths
[
0
];
const
index_t
GemmMPad
=
GemmM
-
GemmMRaw
;
const
index_t
Hi
=
input_spatial_lengths
[
1
];
const
index_t
Wi
=
input_spatial_lengths
[
2
];
const
index_t
GemmK0
=
GemmK
/
GemmK1Number
;
const
index_t
Di
=
input_spatial_lengths
[
0
];
const
index_t
Hi
=
input_spatial_lengths
[
1
];
const
index_t
Wi
=
input_spatial_lengths
[
2
];
const
index_t
Do
=
output_spatial_lengths
[
0
];
const
index_t
Do
=
output_spatial_lengths
[
0
];
const
index_t
Ho
=
output_spatial_lengths
[
1
];
const
index_t
Ho
=
output_spatial_lengths
[
1
];
...
@@ -488,13 +495,13 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -488,13 +495,13 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
const
auto
in_gemmmraw_gemmk_grid_desc
=
const
auto
in_gemmmraw_gemmk_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
g
emm
_m
,
g
emm
_k
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
G
emm
M
,
G
emm
K
));
// in_gemmk0_gemmm_gemmk1_grid_desc
// in_gemmk0_gemmm_gemmk1_grid_desc
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
in_gemmmraw_gemmk_grid_desc
,
in_gemmmraw_gemmk_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
g
emm
_k
0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
G
emm
K
0
,
GemmK1Number
)),
make_right_pad_transform
(
g
emm
_m
,
g
emm
_m_p
ad
)),
make_right_pad_transform
(
G
emm
M
,
G
emm
MP
ad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
...
@@ -518,7 +525,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -518,7 +525,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_n_do_ho_wo_e_grid_desc
,
in_n_do_ho_wo_e_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
g
emm
_k
0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
G
emm
K
0
,
GemmK1Number
)),
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
))),
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
4
>
{},
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
4
>
{},
Sequence
<
0
,
1
,
2
,
3
>
{}),
...
@@ -527,8 +534,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -527,8 +534,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
// in_gemmk0_gemmm_gemmk1_grid_desc
// in_gemmk0_gemmm_gemmk1_grid_desc
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
g
emm
_k
0
),
make_tuple
(
make_pass_through_transform
(
G
emm
K
0
),
make_right_pad_transform
(
g
emm
_m
,
g
emm
_m_p
ad
),
make_right_pad_transform
(
G
emm
M
,
G
emm
MP
ad
),
make_pass_through_transform
(
GemmK1Number
)),
make_pass_through_transform
(
GemmK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
...
@@ -591,68 +598,65 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -591,68 +598,65 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmk_gemmmraw_grid_desc
,
in_gemmk_gemmmraw_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
g
emm
_k
0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
G
emm
K
0
,
GemmK1Number
)),
make_pass_through_transform
(
g
emm
_m
)),
make_pass_through_transform
(
G
emm
M
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
// in_gemmk0_gemmm_gemmk1_grid_desc
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
g
emm
_k
0
),
make_tuple
(
make_pass_through_transform
(
G
emm
K
0
),
make_right_pad_transform
(
g
emm
_m
,
g
emm
_m_p
ad
),
make_right_pad_transform
(
G
emm
M
,
G
emm
MP
ad
),
make_pass_through_transform
(
GemmK1Number
)),
make_pass_through_transform
(
GemmK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
}
}
}
}
static
index_t
GetGemmMRaw
(
ck
::
index_t
N
,
static
index_t
GetGemmMRaw
(
index_t
N
,
const
std
::
vector
<
index_t
>&
output_spatial_lengths
)
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
)
{
{
return
N
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths
),
return
N
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths
),
std
::
end
(
output_spatial_lengths
),
std
::
end
(
output_spatial_lengths
),
1
,
1
,
std
::
multiplies
<
ck
::
index_t
>
());
std
::
multiplies
<
index_t
>
());
}
}
static
index_t
GetGemmK
(
ck
::
index_t
C
,
const
std
::
vector
<
ck
::
index_t
>&
filter_spatial_lengths
)
static
index_t
GetGemmK
Raw
(
index_t
C
,
const
std
::
vector
<
index_t
>&
filter_spatial_lengths
)
{
{
return
C
*
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths
),
return
C
*
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths
),
std
::
end
(
filter_spatial_lengths
),
std
::
end
(
filter_spatial_lengths
),
1
,
1
,
std
::
multiplies
<
ck
::
index_t
>
());
std
::
multiplies
<
index_t
>
());
}
}
static
auto
static
auto
MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
index_t
N
,
ck
::
index_t
K
,
index_t
K
,
ck
::
index_t
C
,
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
std
::
vector
<
index_t
>
input_right_pads
)
{
{
using
namespace
ck
;
using
namespace
ck
;
const
index_t
GemmMRaw
=
GetGemmMRaw
(
N
,
output_spatial_lengths
);
const
index_t
GemmMRaw
=
GetGemmMRaw
(
N
,
output_spatial_lengths
);
const
index_t
GemmN
=
K
;
const
index_t
GemmN
Raw
=
K
;
const
index_t
GemmK
=
GetGemmK
(
C
,
filter_spatial_lengths
);
const
index_t
GemmK
Raw
=
GetGemmK
Raw
(
C
,
filter_spatial_lengths
);
const
auto
GemmMPad
=
math
::
integer_least_multiple
(
GemmMRaw
,
MPerBlock
)
-
GemmMRaw
;
// TODO: remove
assert
(
GemmKRaw
%
GemmK1Number
==
0
);
assert
(
GemmK
%
GemmK1Number
==
0
);
// A:
// A:
const
auto
in_gemmk0_gemmm_gemmk1_grid_desc
=
const
auto
in_gemmk0_gemmm_gemmk1_grid_desc
=
GetInputTensorDescriptor
<
NDimSpatial
>
(
N
,
GetInputTensorDescriptor
<
NDimSpatial
>
(
N
,
C
,
C
,
GemmMRaw
,
GemmMRaw
,
GemmK
,
GemmKRaw
,
GemmMPad
,
input_spatial_lengths
,
input_spatial_lengths
,
filter_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
output_spatial_lengths
,
...
@@ -660,31 +664,34 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -660,31 +664,34 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
);
input_right_pads
);
// B:
// B:
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
GetWeightTensorDescriptor
(
GemmN
,
GemmK
);
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
GetWeightTensorDescriptor
(
GemmNRaw
,
GemmKRaw
);
// E:
// E:
const
auto
out_gemmm_gemmn_grid_desc
=
GetOutputTensorDescriptor
(
GemmMRaw
,
GemmN
,
GemmMPad
);
const
auto
out_gemmm_gemmn_grid_desc
=
GetOutputTensorDescriptor
(
GemmMRaw
,
GemmN
Raw
);
return
make_tuple
(
in_gemmk0_gemmm_gemmk1_grid_desc
,
return
make_tuple
(
in_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmm_gemmn_grid_desc
);
out_gemmm_gemmn_grid_desc
);
}
}
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
template
<
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
GetABEGridDesc
()
static
auto
GetABEGridDesc
()
{
{
return
MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
return
MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
1
,
1
,
1
,
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
});
1
,
1
,
1
,
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
});
}
}
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
template
<
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
GetABEGridDesc
()
static
auto
GetABEGridDesc
()
{
{
return
MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
return
MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
1
,
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
});
1
,
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
});
}
}
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
template
<
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
GetABEGridDesc
()
static
auto
GetABEGridDesc
()
{
{
return
MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
return
MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
...
@@ -756,16 +763,16 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -756,16 +763,16 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
Argument
(
const
ADataType
*
p_in_grid
,
Argument
(
const
ADataType
*
p_in_grid
,
const
BDataType
*
p_wei_grid
,
const
BDataType
*
p_wei_grid
,
EDataType
*
p_out_grid
,
EDataType
*
p_out_grid
,
ck
::
index_t
N
,
index_t
N
,
ck
::
index_t
K
,
index_t
K
,
ck
::
index_t
C
,
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
index_t
>
input_right_pads
,
AElementwiseOperation
in_element_op
,
AElementwiseOperation
in_element_op
,
BElementwiseOperation
wei_element_op
,
BElementwiseOperation
wei_element_op
,
CDEElementwiseOperation
out_element_op
)
CDEElementwiseOperation
out_element_op
)
...
@@ -988,7 +995,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -988,7 +995,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
// check if it's 1x1, stride=1 conv
// check if it's 1x1, stride=1 conv
for
(
ck
::
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
{
if
(
!
(
arg
.
filter_spatial_lengths_
[
i
]
==
1
&&
arg
.
conv_filter_strides_
[
i
]
==
1
&&
if
(
!
(
arg
.
filter_spatial_lengths_
[
i
]
==
1
&&
arg
.
conv_filter_strides_
[
i
]
==
1
&&
arg
.
input_left_pads_
[
i
]
==
0
&&
arg
.
input_right_pads_
[
i
]
==
0
))
arg
.
input_left_pads_
[
i
]
==
0
&&
arg
.
input_right_pads_
[
i
]
==
0
))
...
@@ -1001,7 +1008,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -1001,7 +1008,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
{
// check if it's 1x1 conv
// check if it's 1x1 conv
for
(
ck
::
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
{
if
(
!
(
arg
.
filter_spatial_lengths_
[
i
]
==
1
&&
arg
.
input_left_pads_
[
i
]
==
0
&&
if
(
!
(
arg
.
filter_spatial_lengths_
[
i
]
==
1
&&
arg
.
input_left_pads_
[
i
]
==
0
&&
arg
.
input_right_pads_
[
i
]
==
0
))
arg
.
input_right_pads_
[
i
]
==
0
))
...
@@ -1040,16 +1047,16 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -1040,16 +1047,16 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
static
auto
MakeArgument
(
const
ADataType
*
p_in_grid
,
static
auto
MakeArgument
(
const
ADataType
*
p_in_grid
,
const
BDataType
*
p_wei_grid
,
const
BDataType
*
p_wei_grid
,
EDataType
*
p_out_grid
,
EDataType
*
p_out_grid
,
ck
::
index_t
N
,
index_t
N
,
ck
::
index_t
K
,
index_t
K
,
ck
::
index_t
C
,
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
index_t
>
input_right_pads
,
AElementwiseOperation
in_element_op
,
AElementwiseOperation
in_element_op
,
BElementwiseOperation
wei_element_op
,
BElementwiseOperation
wei_element_op
,
CDEElementwiseOperation
out_element_op
)
CDEElementwiseOperation
out_element_op
)
...
@@ -1078,16 +1085,16 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -1078,16 +1085,16 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
MakeArgumentPointer
(
const
void
*
p_in_grid
,
MakeArgumentPointer
(
const
void
*
p_in_grid
,
const
void
*
p_wei_grid
,
const
void
*
p_wei_grid
,
void
*
p_out_grid
,
void
*
p_out_grid
,
ck
::
index_t
N
,
index_t
N
,
ck
::
index_t
K
,
index_t
K
,
ck
::
index_t
C
,
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
index_t
>
input_right_pads
,
AElementwiseOperation
in_element_op
,
AElementwiseOperation
in_element_op
,
BElementwiseOperation
wei_element_op
,
BElementwiseOperation
wei_element_op
,
CDEElementwiseOperation
out_element_op
)
override
CDEElementwiseOperation
out_element_op
)
override
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
7cc806d8
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/device_utility/kernel_launch.hpp"
...
@@ -160,6 +161,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -160,6 +161,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
...
@@ -175,92 +179,23 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -175,92 +179,23 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}
}();
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
a_grid_desc_m_k
=
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
transform_tensor_descriptor
(
a_grid_desc_m_k
,
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
assert
(
K
%
AK1
==
0
);
}
else
{
// not pad M or K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
K
Raw
/
AK1
;
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m
raw_kraw
,
transform_tensor_descriptor
(
a_grid_desc_m
_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
Raw
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
return
a_grid_desc_ak0_m_ak1
;
}
}
}
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
...
@@ -278,97 +213,28 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -278,97 +213,28 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}
}();
}();
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
b_grid_desc_n_k
=
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both N and K
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad N, but not K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
NRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
transform_tensor_descriptor
(
b_grid_desc_n_k
,
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
assert
(
K
%
BK1
==
0
);
}
else
{
// not pad N or K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
K
Raw
/
BK1
;
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n
raw_kraw
,
transform_tensor_descriptor
(
b_grid_desc_n
_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
Raw
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
return
b_grid_desc_bk0_n_bk1
;
}
}
}
static
auto
MakeEGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
StrideE
)
static
auto
MakeEGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
StrideE
)
{
{
const
auto
c
_grid_desc_mraw_nraw
=
[
&
]()
{
const
auto
e
_grid_desc_mraw_nraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
DELayout
>::
value
)
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
DELayout
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
...
@@ -381,47 +247,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -381,47 +247,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}
}();
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
e_grid_desc_m_n
=
matrix_padder
.
PadCDescriptor_M_N
(
e_grid_desc_mraw_nraw
);
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
NPad
=
N
-
NRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
||
return
e_grid_desc_m_n
;
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad M, but not N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad N, but not M
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or N
return
c_grid_desc_mraw_nraw
;
}
}
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
...
...
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
0 → 100644
View file @
7cc806d8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// M/N/KPerTileType could be index_t or Number<>
template
<
GemmSpecialization
GemmSpec
,
typename
MPerTileType
,
typename
NPerTileType
,
typename
KPerTileType
>
struct
MatrixPadder
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
template
<
typename
ADesc_MRaw_KRaw
>
__host__
__device__
constexpr
auto
PadADescriptor_M_K
(
const
ADesc_MRaw_KRaw
&
a_desc_mraw_kraw
)
const
{
const
auto
MRaw
=
a_desc_mraw_kraw
.
GetLength
(
I0
);
const
auto
KRaw
=
a_desc_mraw_kraw
.
GetLength
(
I1
);
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerTile_
)
*
MPerTile_
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerTile_
)
*
KPerTile_
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
return
transform_tensor_descriptor
(
a_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
return
transform_tensor_descriptor
(
a_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_pass_through_transform
(
KRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
return
transform_tensor_descriptor
(
a_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or K
return
a_desc_mraw_kraw
;
}
}
template
<
typename
BDesc_NRaw_KRaw
>
__host__
__device__
constexpr
auto
PadBDescriptor_N_K
(
const
BDesc_NRaw_KRaw
&
b_desc_nraw_kraw
)
const
{
const
auto
NRaw
=
b_desc_nraw_kraw
.
GetLength
(
I0
);
const
auto
KRaw
=
b_desc_nraw_kraw
.
GetLength
(
I1
);
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerTile_
)
*
NPerTile_
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerTile_
)
*
KPerTile_
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both N and K
return
transform_tensor_descriptor
(
b_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad N, but not K
return
transform_tensor_descriptor
(
b_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_pass_through_transform
(
KRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
return
transform_tensor_descriptor
(
b_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
NRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad N or K
return
b_desc_nraw_kraw
;
}
}
template
<
typename
CDesc_MRaw_NRaw
>
__host__
__device__
constexpr
auto
PadCDescriptor_M_N
(
const
CDesc_MRaw_NRaw
&
c_desc_mraw_nraw
)
const
{
const
auto
MRaw
=
c_desc_mraw_nraw
.
GetLength
(
I0
);
const
auto
NRaw
=
c_desc_mraw_nraw
.
GetLength
(
I1
);
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerTile_
)
*
MPerTile_
;
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerTile_
)
*
NPerTile_
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
NPad
=
N
-
NRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M and N
return
transform_tensor_descriptor
(
c_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad M, but not N
return
transform_tensor_descriptor
(
c_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad N, but not M
return
transform_tensor_descriptor
(
c_desc_mraw_nraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or N
return
c_desc_mraw_nraw
;
}
}
MPerTileType
MPerTile_
;
NPerTileType
NPerTile_
;
KPerTileType
KPerTile_
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment