gaoqiong / composable_kernel / Commits / 0530fd66

Commit 0530fd66, authored Jul 20, 2022 by Chao Liu

    update gemm multi-d

parent 05c484e2
Showing 3 changed files with 138 additions and 201 deletions (+138, -201)
+1   -1    include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d.hpp
+59  -160  include/ck/tensor_operation/gpu/device/device_convnd_fwd_multiple_d_nwc_kxc_nwk_xdl_cshuffle.hpp
+78  -40   include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d.hpp  (view file @ 0530fd66)
...
...
@@ -11,7 +11,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-// GEMM:
+// Convolution Forward:
 //   input : input image A[N, Hi, Wi, C],
 //   input : weight B[K, Y, X, C],
 //   input : D0[N, Ho, Wo, K], D1[N, Ho, Wo, K], ...
...
...
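A note for context on the comment above: in the implicit-GEMM view of forward convolution, each output pixel maps to a GEMM row, each output channel to a GEMM column, and the reduction runs over the filter window. Below is a minimal standalone sketch of the resulting problem sizes; all concrete numbers are made up for illustration and are not taken from this commit.

#include <cstdio>

int main()
{
    // hypothetical 2D convolution parameters, for illustration only
    const int N = 4, Hi = 28, Wi = 28, C = 64; // input image A[N, Hi, Wi, C]
    const int K = 128, Y = 3, X = 3;           // weight B[K, Y, X, C]
    const int stride = 1, pad = 1, dilation = 1;

    const int Ho = (Hi + 2 * pad - dilation * (Y - 1) - 1) / stride + 1;
    const int Wo = (Wi + 2 * pad - dilation * (X - 1) - 1) / stride + 1;

    // implicit-GEMM problem sizes for output E[N, Ho, Wo, K]
    const int GemmM = N * Ho * Wo; // one GEMM row per output pixel
    const int GemmN = K;           // one GEMM column per output channel
    const int GemmK = Y * X * C;   // reduction over the filter window

    std::printf("GemmM=%d GemmN=%d GemmK=%d\n", GemmM, GemmN, GemmK);
    return 0;
}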
include/ck/tensor_operation/gpu/device/device_convnd_fwd_multiple_d_nwc_kxc_nwk_xdl_cshuffle.hpp  (view file @ 0530fd66)
...
...
@@ -197,18 +197,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     const auto wei_gemmn_gemmk_grid_desc =
         matrix_padder.PadBDescriptor_N_K(wei_k_yxc_grid_desc);
 
-    const auto GemmN = wei_gemmn_gemmk_grid_desc.GetLength(I0);
-    const auto GemmK = wei_gemmn_gemmk_grid_desc.GetLength(I1);
-
-    const index_t GemmK0 = GemmK / GemmK1Number;
-
-    // wei_gemmk0_gemmn_gemmk1_grid_desc
-    return transform_tensor_descriptor(
-        wei_gemmn_gemmk_grid_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                   make_pass_through_transform(GemmN)),
-        make_tuple(Sequence<1>{}, Sequence<0>{}),
-        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    return wei_gemmn_gemmk_grid_desc;
 }
 
 static auto GetOutputTensorDescriptor(index_t GemmMRaw, index_t GemmN)
...
...
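The block removed above (and the analogous blocks in the hunks below) expressed the split of the GEMM K dimension into (K0, K1) via make_unmerge_transform; this commit moves that split into the gridwise GEMM. The underlying index arithmetic, as a standalone sketch with made-up sizes:

#include <cassert>
#include <cstdio>

int main()
{
    const int K1Number = 8;      // e.g. GemmK1Number, the vectorized K access width
    const int K = 64;            // illustrative; must be divisible by K1Number
    assert(K % K1Number == 0);
    const int K0 = K / K1Number; // e.g. GemmK0

    // element k of the K axis corresponds to (k0, k1) of the unmerged (K0, K1) axes
    const int k  = 27;
    const int k0 = k / K1Number;
    const int k1 = k % K1Number;

    std::printf("K=%d -> K0=%d, K1=%d; k=%d -> (k0=%d, k1=%d)\n", K, K0, K1Number, k, k0, k1);
    return 0;
}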
@@ -250,18 +239,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     const auto in_gemmm_gemmk_grid_desc =
         matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
 
-    const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-    const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-    const auto GemmK0 = GemmK / GemmK1Number;
-
-    // in_gemmk0_gemmm_gemmk1_grid_desc
-    return transform_tensor_descriptor(
-        in_gemmm_gemmk_grid_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                   make_pass_through_transform(GemmM)),
-        make_tuple(Sequence<1>{}, Sequence<0>{}),
-        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    return in_gemmm_gemmk_grid_desc;
 }
 else if constexpr(ConvForwardSpecialization ==
                   ConvolutionForwardSpecialization::Filter1x1Pad0)
...
...
@@ -286,19 +264,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     const auto in_gemmm_gemmk_grid_desc =
         matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
 
-    const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-    const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-    const auto GemmK0 = GemmK / GemmK1Number;
-
-    const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-        in_gemmm_gemmk_grid_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                   make_pass_through_transform(GemmM)),
-        make_tuple(Sequence<1>{}, Sequence<0>{}),
-        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-    return in_gemmk0_gemmm_gemmk1_grid_desc;
+    return in_gemmm_gemmk_grid_desc;
 }
 else
 {
...
...
@@ -337,19 +303,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     const auto in_gemmm_gemmk_grid_desc =
         matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
 
-    const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-    const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-    const auto GemmK0 = GemmK / GemmK1Number;
-
-    const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-        in_gemmm_gemmk_grid_desc,
-        make_tuple(make_pass_through_transform(GemmM),
-                   make_unmerge_transform(make_tuple(GemmK0, GemmK1Number))),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
-
-    return in_gemmk0_gemmm_gemmk1_grid_desc;
+    return in_gemmm_gemmk_grid_desc;
 }
 }
...
...
@@ -384,18 +338,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     const auto in_gemmm_gemmk_grid_desc =
         matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
 
-    const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-    const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-    const auto GemmK0 = GemmK / GemmK1Number;
-
-    // in_gemmk0_gemmm_gemmk1_grid_desc
-    return transform_tensor_descriptor(
-        in_gemmm_gemmk_grid_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                   make_pass_through_transform(GemmM)),
-        make_tuple(Sequence<1>{}, Sequence<0>{}),
-        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    return in_gemmm_gemmk_grid_desc;
 }
 else if constexpr(ConvForwardSpecialization ==
                   ConvolutionForwardSpecialization::Filter1x1Pad0)
...
...
@@ -422,19 +365,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     const auto in_gemmm_gemmk_grid_desc =
         matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
 
-    const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-    const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-    const auto GemmK0 = GemmK / GemmK1Number;
-
-    const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-        in_gemmm_gemmk_grid_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                   make_pass_through_transform(GemmM)),
-        make_tuple(Sequence<1>{}, Sequence<0>{}),
-        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-    return in_gemmk0_gemmm_gemmk1_grid_desc;
+    return in_gemmm_gemmk_grid_desc;
 }
 else
 {
...
...
@@ -482,19 +413,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     const auto in_gemmm_gemmk_grid_desc =
         matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
 
-    const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-    const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-    const auto GemmK0 = GemmK / GemmK1Number;
-
-    const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-        in_gemmm_gemmk_grid_desc,
-        make_tuple(make_pass_through_transform(GemmM),
-                   make_unmerge_transform(make_tuple(GemmK0, GemmK1Number))),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
-
-    return in_gemmk0_gemmm_gemmk1_grid_desc;
+    return in_gemmm_gemmk_grid_desc;
 }
 }
...
...
@@ -532,18 +451,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     const auto in_gemmm_gemmk_grid_desc =
         matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
 
-    const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-    const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-    const auto GemmK0 = GemmK / GemmK1Number;
-
-    // in_gemmk0_gemmm_gemmk1_grid_desc
-    return transform_tensor_descriptor(
-        in_gemmm_gemmk_grid_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                   make_pass_through_transform(GemmM)),
-        make_tuple(Sequence<1>{}, Sequence<0>{}),
-        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    return in_gemmm_gemmk_grid_desc;
 }
 else if constexpr(ConvForwardSpecialization ==
                   ConvolutionForwardSpecialization::Filter1x1Pad0)
...
...
@@ -573,19 +481,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     const auto in_gemmm_gemmk_grid_desc =
         matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
 
-    const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-    const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-    const auto GemmK0 = GemmK / GemmK1Number;
-
-    const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-        in_gemmm_gemmk_grid_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                   make_pass_through_transform(GemmM)),
-        make_tuple(Sequence<1>{}, Sequence<0>{}),
-        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-    return in_gemmk0_gemmm_gemmk1_grid_desc;
+    return in_gemmm_gemmk_grid_desc;
 }
 else
 {
...
...
@@ -646,19 +542,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     const auto in_gemmm_gemmk_grid_desc =
         matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
 
-    const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-    const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-    const auto GemmK0 = GemmK / GemmK1Number;
-
-    const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-        in_gemmm_gemmk_grid_desc,
-        make_tuple(make_pass_through_transform(GemmM),
-                   make_unmerge_transform(make_tuple(GemmK0, GemmK1Number))),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
-
-    return in_gemmk0_gemmm_gemmk1_grid_desc;
+    return in_gemmm_gemmk_grid_desc;
 }
 }
...
...
@@ -696,11 +580,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     const index_t GemmNRaw = K;
     const index_t GemmKRaw = GetGemmKRaw(C, filter_spatial_lengths);
 
-    // TODO: remove
-    assert(GemmKRaw % GemmK1Number == 0);
-
     // A:
-    const auto in_gemmk0_gemmm_gemmk1_grid_desc =
+    const auto in_gemmm_gemmk_grid_desc =
         GetInputTensorDescriptor<NDimSpatial>(N,
                                               C,
                                               GemmMRaw,
...
...
@@ -714,15 +595,13 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                                               input_right_pads);
 
     // B:
-    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = GetWeightTensorDescriptor(GemmNRaw, GemmKRaw);
+    const auto wei_gemmn_gemmk_grid_desc = GetWeightTensorDescriptor(GemmNRaw, GemmKRaw);
 
     // E:
     const auto out_gemmm_gemmn_grid_desc = GetOutputTensorDescriptor(GemmMRaw, GemmNRaw);
 
-    return make_tuple(
-        in_gemmk0_gemmm_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc);
+    return make_tuple(in_gemmm_gemmk_grid_desc, wei_gemmn_gemmk_grid_desc, out_gemmm_gemmn_grid_desc);
 }
 
 template <index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
...
...
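The hunks above capture the core of this commit on the device side: the device op now hands the gridwise GEMM plain (M, K) / (N, K) / (M, N) descriptors (see the make_tuple return above) and leaves the blocked (AK0, M, AK1) / (BK0, N, BK1) views to the GEMM itself. A toy, self-contained illustration of that division of responsibility; the types and numbers are hypothetical stand-ins, not CK's actual API:

#include <cstdio>

// hypothetical stand-ins for CK tensor descriptors, for illustration only
struct DescMK    { int M, K; };
struct DescK0MK1 { int K0, M, K1; };

struct ToyGridwiseGemm
{
    static constexpr int AK1 = 8; // the K vector width is now owned by the gridwise GEMM

    // name borrowed from the function added in gridwise_gemm_multiple_d_xdl_cshuffle.hpp
    static DescK0MK1 MakeDefaultAGridDescriptor_AK0_M_AK1(DescMK a)
    {
        return {a.K / AK1, a.M, AK1};
    }
};

int main()
{
    DescMK a{256, 64}; // produced by the device-side conv-to-GEMM mapping
    DescK0MK1 a_blocked = ToyGridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a);
    std::printf("A: [M=%d, K=%d] -> [AK0=%d, M=%d, AK1=%d]\n",
                a.M, a.K, a_blocked.K0, a_blocked.M, a_blocked.K1);
    return 0;
}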
@@ -748,8 +627,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     using ABEGridDescs = decltype(GetABEGridDesc<NDimSpatial>());
 
-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(ABEGridDescs{}[I0])>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(ABEGridDescs{}[I1])>;
+    using AGridDesc_M_K = remove_cvref_t<decltype(ABEGridDescs{}[I0])>;
+    using BGridDesc_N_K = remove_cvref_t<decltype(ABEGridDescs{}[I1])>;
     using EGridDesc_M_N = remove_cvref_t<decltype(ABEGridDescs{}[I2])>;
 
     // GridwiseGemm
...
...
@@ -763,8 +642,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         BElementwiseOperation,
         CDEElementwiseOperation,
         InMemoryDataOperationEnum::Set,
-        AGridDesc_AK0_M_AK1,
-        BGridDesc_BK0_N_BK1,
+        AGridDesc_M_K,
+        BGridDesc_N_K,
         EGridDesc_M_N,
         NumGemmKPrefetchStage,
         BlockSize,
...
...
@@ -799,11 +678,12 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         CDEBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;
 
 #if 0
     using Block2ETileMap = BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, EGridDesc_M_N>;
 #else
+    using AGridDesc_AK0_M_AK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
+    using BGridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
+
     using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
 #endif
 
     // Argument
     struct Argument : public BaseArgument
...
...
@@ -856,10 +736,15 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                                               input_left_pads,
                                               input_right_pads);
 
-            a_grid_desc_ak0_m_ak1_ = descs[I0];
-            b_grid_desc_bk0_n_bk1_ = descs[I1];
+            const auto a_grid_desc_m_k = descs[I0];
+            const auto b_grid_desc_n_k = descs[I1];
             e_grid_desc_m_n_ = descs[I2];
 
+            a_grid_desc_ak0_m_ak1_ =
+                GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
+            b_grid_desc_bk0_n_bk1_ =
+                GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
+
             block_2_etile_map_ = Block2ETileMap{e_grid_desc_m_n_};
 
             if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
...
...
@@ -917,7 +802,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
     {
-#if 0
+#if 1
         {
             std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
                       << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
...
...
@@ -1010,6 +895,20 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
+#if 1
+        {
+            std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
+                      << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
+                      << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
+                      << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
+
+            std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
+                      << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
+                      << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
+                      << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
+
+            std::cout << "arg.e_grid_desc_m_n_{ "
+                      << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
+                      << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+        }
+#endif
+
         if(ck::get_device_name() == "gfx908")
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp  (view file @ 0530fd66)
...
...
@@ -18,8 +18,8 @@
 namespace ck {
 
 // GEMM:
-// input : A[AK0, M, AK1]
-// input : B[AK0, N, AK1]
+// input : A[AK0PerBlock, M, AK1]
+// input : B[AK0PerBlock, N, AK1]
 // input : D0[M, N], D1[M, N], ...
 // output : E[M, N]
 // C = a_op(A) * b_op(B)
...
...
@@ -35,8 +35,8 @@ template <typename FloatAB,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation,
           InMemoryDataOperationEnum EGlobalMemoryDataOperation,
-          typename AGridDesc_AK0_M_AK1,
-          typename BGridDesc_BK0_N_BK1,
+          typename AGridDesc_M_K,
+          typename BGridDesc_N_K,
           typename EGridDesc_M_N,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
...
...
@@ -84,10 +84,10 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
     static constexpr auto I7 = Number<7>{};
 
     // K1 should be Number<...>
-    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
-    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};
 
+    static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
+    static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
 
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
...
...
@@ -97,7 +97,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
     {
         // A matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
-            make_tuple(AK0, Number<MPerBlock>{}, AK1),
+            make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
             make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
     }
...
...
@@ -105,7 +105,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
     {
         // B matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
-            make_tuple(BK0, Number<NPerBlock>{}, BK1),
+            make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
             make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
     }
...
...
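The AK0PerBlock/BK0PerBlock renames above only change the names in the per-block LDS descriptors, not the layout: lengths (AK0PerBlock, MPerBlock, AK1) with strides ((MPerBlock + ABlockLdsExtraM) * AK1, AK1, 1). A standalone sanity check of the A-tile LDS footprint that layout implies; the tile sizes below are illustrative, not taken from the commit:

#include <cstdio>

int main()
{
    // illustrative tile sizes
    const int KPerBlock = 32, MPerBlock = 128;
    const int AK1 = 8;                       // vectorized K access width
    const int ABlockLdsExtraM = 1;           // LDS padding along M
    const int AK0PerBlock = KPerBlock / AK1; // = 4

    // largest linear offset + 1 for lengths (AK0PerBlock, MPerBlock, AK1)
    // with strides ((MPerBlock + ABlockLdsExtraM) * AK1, AK1, 1)
    const int lds_elems = (AK0PerBlock - 1) * (MPerBlock + ABlockLdsExtraM) * AK1 +
                          (MPerBlock - 1) * AK1 + (AK1 - 1) + 1;

    std::printf("AK0PerBlock=%d, A-tile LDS elements=%d\n", AK0PerBlock, lds_elems);
    return 0;
}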
@@ -164,8 +164,65 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
                                     c_block_size * sizeof(FloatCShuffle));
     }
 
+    __host__ __device__ static constexpr auto
+    MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
+    {
+        const auto M = a_grid_desc_m_k.GetLength(I0);
+        const auto K = a_grid_desc_m_k.GetLength(I1);
+
+        const auto AK0 = K / AK1;
+
+        return transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
+                       make_pass_through_transform(M)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    }
+    __host__ __device__ static constexpr auto
+    MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
+    {
+        const auto N = b_grid_desc_n_k.GetLength(I0);
+        const auto K = b_grid_desc_n_k.GetLength(I1);
+
+        const auto BK0 = K / BK1;
+
+        return transform_tensor_descriptor(
+            b_grid_desc_n_k,
+            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                       make_pass_through_transform(N)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    }
+    __host__ __device__ static constexpr auto
+    MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
+    {
+        const auto M = e_grid_desc_m_n.GetLength(I0);
+        const auto N = e_grid_desc_m_n.GetLength(I1);
+
+        const auto MBlock = M / MPerBlock;
+        const auto NBlock = N / NPerBlock;
+
+        const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
+            e_grid_desc_m_n,
+            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
+                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
+
+        return e_grid_desc_mblock_mperblock_nblock_nperblock;
+    }
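MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock (moved up from further down in this file) views E[M, N] as a grid of MPerBlock x NPerBlock tiles. The index arithmetic it encodes, as a standalone sketch with made-up sizes, assuming M and N are already padded to tile multiples:

#include <cstdio>

int main()
{
    const int MPerBlock = 128, NPerBlock = 64;     // illustrative tile sizes
    const int M = 512, N = 256;                    // assumed padded to tile multiples
    const int MBlock = M / MPerBlock, NBlock = N / NPerBlock;

    // element (m, n) of E maps to block (mb, nb) and intra-tile offset (mp, np)
    const int m = 300, n = 70;
    const int mb = m / MPerBlock, mp = m % MPerBlock;
    const int nb = n / NPerBlock, np = n % NPerBlock;

    std::printf("grid is %d x %d tiles; E(%d,%d) -> (MBlock=%d, mp=%d, NBlock=%d, np=%d)\n",
                MBlock, NBlock, m, n, mb, mp, nb, np);
    return 0;
}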
+    // return block_id to E matrix tile idx (m0, n0) mapping
+    __host__ __device__ static constexpr auto
+    MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
+    {
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
+            e_grid_desc_m_n);
+    }
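MakeDefaultBlock2ETileMap returns a BlockToCTileMap_M00_N0_M01Adapt, which assigns each workgroup an (m0, n0) output tile in an ordering chosen for data reuse. For intuition only, a naive row-major block-id to tile mapping; this is not the M01-adapted ordering CK actually uses:

#include <cstdio>

int main()
{
    const int MPerBlock = 128, NPerBlock = 64;  // illustrative tile sizes
    const int M = 512, N = 256;
    const int NBlock = N / NPerBlock;

    for(int block_id = 0; block_id < (M / MPerBlock) * NBlock; ++block_id)
    {
        const int m0 = block_id / NBlock; // tile row
        const int n0 = block_id % NBlock; // tile column
        std::printf("block %2d -> tile (m0=%d, n0=%d)\n", block_id, m0, n0);
    }
    return 0;
}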
     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
-    template <typename Block2ETileMap>
+    template <typename AGridDesc_AK0_M_AK1, typename BGridDesc_BK0_N_BK1, typename Block2ETileMap>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                   const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
...
...
@@ -210,32 +267,10 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }
 
-    __host__ __device__ static constexpr auto
-    MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
-    {
-        const auto M = e_grid_desc_m_n.GetLength(I0);
-        const auto N = e_grid_desc_m_n.GetLength(I1);
-
-        const auto MBlock = M / MPerBlock;
-        const auto NBlock = N / NPerBlock;
-
-        const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
-            e_grid_desc_m_n,
-            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
-                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
-
-        return e_grid_desc_mblock_mperblock_nblock_nperblock;
-    }
-
-    // return block_id to E matrix tile idx (m0, n0) mapping
-    __host__ __device__ static constexpr auto
-    MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
-    {
-        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
-            e_grid_desc_m_n);
-    }
+    using DefaultAGridDesc_AK0_M_AK1 =
+        remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
+    using DefaultBGridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
 
     using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
         remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
...
...
@@ -245,7 +280,10 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
     using DsGridPointer = decltype(MakeDsGridPointer());
 
-    template <bool HasMainKBlockLoop, typename Block2ETileMap>
+    template <bool HasMainKBlockLoop,
+              typename AGridDesc_AK0_M_AK1,
+              typename BGridDesc_BK0_N_BK1,
+              typename Block2ETileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
...
...
@@ -316,7 +354,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
         AElementwiseOperation,
         ck::tensor_operation::element_wise::PassThrough,
         InMemoryDataOperationEnum::Set,
-        Sequence<AK0, MPerBlock, AK1>,
+        Sequence<AK0PerBlock, MPerBlock, AK1>,
         ABlockTransferThreadClusterLengths_AK0_M_AK1,
         ABlockTransferThreadClusterArrangeOrder,
         FloatAB,
...
...
@@ -347,7 +385,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
         BElementwiseOperation,
         ck::tensor_operation::element_wise::PassThrough,
         InMemoryDataOperationEnum::Set,
-        Sequence<BK0, NPerBlock, BK1>,
+        Sequence<BK0PerBlock, NPerBlock, BK1>,
         BBlockTransferThreadClusterLengths_BK0_N_BK1,
         BBlockTransferThreadClusterArrangeOrder,
         FloatAB,
...
...