Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a5011336
Commit
a5011336
authored
Jul 21, 2022
by
Chao Liu
Browse files
refactor
parent
33975236
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
121 additions
and
125 deletions
+121
-125
include/ck/device_utility/io.hpp
include/ck/device_utility/io.hpp
+26
-0
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
...on/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
+86
-114
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+9
-11
No files found.
include/ck/device_utility/io.hpp
0 → 100644
View file @
a5011336
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <vector>
#include "ck/tensor_description/tensor_descriptor.hpp"
template
<
typename
...
Ts
>
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ck
::
TensorDescriptor
<
Ts
...
>&
desc
)
{
constexpr
ck
::
index_t
nDim
=
ck
::
remove_cvref_t
<
decltype
(
desc
)
>::
GetNumOfDimension
();
os
<<
"{"
;
ck
::
static_for
<
0
,
nDim
-
1
,
1
>
{}([
&
](
auto
i
)
{
os
<<
desc
.
GetLength
(
i
)
<<
", "
;
});
os
<<
desc
.
GetLength
(
ck
::
Number
<
nDim
-
1
>
{});
os
<<
"}"
;
return
os
;
}
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
View file @
a5011336
...
...
@@ -20,6 +20,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/device_utility/io.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -84,12 +85,6 @@ __global__ void
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
// input : input image A[N, C, Hi, Wi],
// input : weight B[K, C, Y, X],
// input : D0[N, K, Ho, Wo], D1[N, K, Ho, Wo], ...
// output : output image E[N, K, Ho, Wo]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock
;
...
...
@@ -172,8 +167,6 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
BElementwiseOperation
,
CDEElementwiseOperation
>
{
namespace
ctc
=
ck
::
tensor_layout
::
convolution
;
using
DeviceOp
=
DeviceConvFwdMultipleD_Xdl_CShuffle
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
...
...
@@ -189,7 +182,9 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
,
typename
std
::
enable_if
<
is_same_v
<
ALay
,
ctc
::
NWC
>,
bool
>::
type
=
false
>
template
<
typename
ALay
,
typename
std
::
enable_if
<
is_same_v
<
ALay
,
tensor_layout
::
convolution
::
NWC
>,
bool
>::
type
=
false
>
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
...
...
@@ -299,7 +294,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
template
<
typename
ALay
,
typename
std
::
enable_if
<
is_same_v
<
ALay
,
ctc
::
NHWC
>,
bool
>::
type
=
false
>
typename
std
::
enable_if
<
is_same_v
<
ALay
,
tensor_layout
::
convolution
::
NHWC
>,
bool
>::
type
=
false
>
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
...
...
@@ -423,7 +419,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
template
<
typename
ALay
,
typename
std
::
enable_if
<
is_same_v
<
ALay
,
ctc
::
NDHWC
>,
bool
>::
type
=
false
>
typename
std
::
enable_if
<
is_same_v
<
ALay
,
tensor_layout
::
convolution
::
NDHWC
>,
bool
>::
type
=
false
>
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
...
...
@@ -570,11 +567,24 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// KYXC, K_YXC
// KZYXC, K_ZYXC
template
<
typename
BLay
,
typename
std
::
enable_if
<
is_same_v
<
BLay
,
ctc
::
KXC
>
||
is_same_v
<
BLay
,
ctc
::
KYXC
>
||
is_same_v
<
BLay
,
ctc
::
KZYXC
>
,
typename
std
::
enable_if
<
is_same_v
<
BLay
,
tensor_layout
::
convolution
::
KXC
>
||
is_same_v
<
BLay
,
tensor_layout
::
convolution
::
KYXC
>
||
is_same_v
<
BLay
,
tensor_layout
::
convolution
::
KZYXC
>
,
bool
>::
type
=
false
>
static
auto
MakeBGridDescriptor_N_K
(
index_t
GemmNRaw
,
index_t
GemmKRaw
)
static
auto
MakeBGridDescriptor_N_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_strides
)
{
const
index_t
K
=
b_k_c_xs_lengths
[
0
];
const
index_t
C
=
b_k_c_xs_lengths
[
1
];
const
index_t
GemmNRaw
=
K
;
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_k_c_xs_lengths
.
begin
()
+
2
,
b_k_c_xs_lengths
.
begin
()
+
2
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
auto
wei_k_yxc_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmNRaw
,
GemmKRaw
));
...
...
@@ -585,37 +595,16 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
template
<
typename
ELay
,
typename
std
::
enable_if
<
is_same_v
<
ELay
,
ctc
::
NWK
>
||
is_same_v
<
ELay
,
ctc
::
NHWK
>
||
is_same_v
<
ELay
,
ctc
::
NDHWK
>
,
typename
std
::
enable_if
<
is_same_v
<
ELay
,
tensor_layout
::
convolution
::
NWK
>
||
is_same_v
<
ELay
,
tensor_layout
::
convolution
::
NHWK
>
||
is_same_v
<
ELay
,
tensor_layout
::
convolution
::
NDHWK
>
,
bool
>::
type
=
false
>
static
auto
MakeEGridDescriptor_M_N
(
index_t
GemmMRaw
,
index_t
GemmN
)
{
const
index_t
GemmM
=
math
::
integer_least_multiple
(
GemmMRaw
,
MPerBlock
);
const
auto
out_gemmmraw_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmM
,
GemmN
));
const
auto
out_gemmm_gemmn_grid_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmn_grid_desc
);
return
out_gemmm_gemmn_grid_desc
;
}
static
auto
MakeABEGridDescriptors
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
MakeEGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
)
{
const
index_t
N
=
a_n_c_wis_lengths
[
0
];
const
index_t
K
=
b_k_c_xs_lengths
[
0
];
const
index_t
C
=
a_n_c_wis_lengths
[
1
];
const
index_t
N
=
e_n_k_wos_lengths
[
0
];
const
index_t
K
=
e_n_k_wos_lengths
[
1
];
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_n_k_wos_lengths
.
begin
()
+
2
,
e_n_k_wos_lengths
.
begin
()
+
2
+
NDimSpatial
,
...
...
@@ -624,42 +613,22 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
const
index_t
GemmNRaw
=
K
;
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_k_c_xs_lengths
.
begin
()
+
2
,
b_k_c_xs_lengths
.
begin
()
+
2
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
auto
out_gemmmraw_gemmnraw_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmMRaw
,
GemmNRaw
));
// A:
const
auto
in_gemmm_gemmk_grid_desc
=
MakeAGridDescriptor_M_K
<
ALayout
>
(
a_n_c_wis_lengths
,
a_n_c_wis_strides
,
b_k_c_xs_lengths
,
b_k_c_xs_strides
,
e_n_k_wos_lengths
,
e_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
// B:
const
auto
wei_gemmn_gemmk_grid_desc
=
MakeBGridDescriptor_N_K
<
BLayout
>
(
GemmNRaw
,
GemmKRaw
);
// E:
const
auto
out_gemmm_gemmn_grid_desc
=
MakeEGridDescriptor_M_N
<
ELayout
>
(
GemmMRaw
,
GemmNRaw
);
return
make_tuple
(
in_gemmm_gemmk_grid_desc
,
wei_gemmn_gemmk_grid_desc
,
out_gemmm_gemmn_grid_desc
);
}
const
auto
out_gemmm_gemmn_grid_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_grid_desc
);
using
ABEGridDescs
=
decltype
(
MakeABEGridDescriptors
({},
{},
{},
{},
{},
{},
{},
{},
{},
{}));
return
out_gemmm_gemmn_grid_desc
;
}
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
ABEGridDescs
{}[
I0
])
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
ABEGridDescs
{}[
I1
])
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABEGridDescs
{}[
I2
])
>
;
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
({},
{},
{},
{},
{},
{},
{},
{},
{},
{}))
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_N_K
<
BLayout
>
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{}))
>
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_
k0mk1_k0nk1_mn_
xdl_cshuffle
<
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CShuffleDataType
,
...
...
@@ -739,9 +708,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b
)},
p_ds_grid_
{},
// FIXME
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
a_grid_desc_m_k_
{},
b_grid_desc_n_k_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{},
a_grid_desc_ak0_m_ak1_
{},
b_grid_desc_bk0_n_bk1_
{},
e
_grid_desc_m
_n
_
{},
ds
_grid_desc_m
block_mperblock_nblock_nperblock
_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{},
a_element_op_
{
a_element_op
},
...
...
@@ -760,32 +733,33 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
const
auto
descs
=
DeviceOp
::
MakeABEGridDescriptors
(
a_n_c_wis_lengths
,
a_n_c_wis_strides
,
b_k_c_xs_lengths
,
b_k_c_xs_strides
,
e_n_k_wos_lengths
,
e_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
const
auto
a_grid_desc_m_k
=
descs
[
I0
];
const
auto
b_grid_desc_n_k
=
descs
[
I1
];
e_grid_desc_m_n_
=
descs
[
I2
];
a_grid_desc_m_k_
=
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
a_n_c_wis_lengths
,
a_n_c_wis_strides
,
b_k_c_xs_lengths
,
b_k_c_xs_strides
,
e_n_k_wos_lengths
,
e_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
b_grid_desc_n_k_
=
DeviceOp
::
MakeBGridDescriptor_N_K
<
BLayout
>
(
b_k_c_xs_lengths
,
b_k_c_xs_strides
);
e_grid_desc_m_n_
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
e_n_k_wos_lengths
,
e_n_k_wos_strides
);
a_grid_desc_ak0_m_ak1_
=
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k
);
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
);
b_grid_desc_bk0_n_bk1_
=
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k
);
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k
_
);
block_2_etile_map_
=
Block2ETileMap
{
e_grid_desc_m_n_
};
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k_
,
b_grid_desc_n_k_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
...
...
@@ -801,14 +775,19 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
EDataType
*
p_e_grid_
;
// tensor descriptors
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
// FIXME: don't assume D and E desc are the same type
StaticallyIndexedArray
<
EGridDesc_M_N
,
NumDTensor
>
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
// FIXME: don't assume D and E desc are the same type
StaticallyIndexedArray
<
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
// FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N
e_grid_desc_m_n_
;
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
...
...
@@ -844,27 +823,18 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
{
#if 1
{
std
::
cout
<<
"arg.a_grid_desc_ak0_m_ak1_{"
<<
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_bk0_n_bk1_{"
<<
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.e_grid_desc_m_n_{ "
<<
arg
.
e_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
e_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"A[M, K]: "
<<
arg
.
a_grid_desc_m_k_
<<
std
::
endl
;
std
::
cout
<<
"B[N, K]: "
<<
arg
.
b_grid_desc_n_k_
<<
std
::
endl
;
std
::
cout
<<
"E[M, N]: "
<<
arg
.
e_grid_desc_m_n_
<<
std
::
endl
;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
ak0_m_ak1
_
,
arg
.
b_grid_desc_
bk0_n_bk1
_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
m_k
_
,
arg
.
b_grid_desc_
n_k
_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmMultipleD_
k0mk1_k0nk1_mn_
xdl_cshuffle has invalid setting"
);
"wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"
);
}
const
index_t
grid_size
=
...
...
@@ -931,6 +901,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
namespace
ctc
=
tensor_layout
::
convolution
;
// check device
if
(
get_device_name
()
==
"gfx908"
)
{
...
...
@@ -1049,8 +1021,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
// Gridwise GEMM size
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
ak0_m_ak1
_
,
arg
.
b_grid_desc_
bk0_n_bk1
_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
m_k
_
,
arg
.
b_grid_desc_
n_k
_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
a5011336
...
...
@@ -70,7 +70,7 @@ template <typename FloatAB,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
>
struct
GridwiseGemmMultipleD_
k0mk1_k0nk1_mn_
xdl_cshuffle
struct
GridwiseGemmMultipleD_xdl_cshuffle
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
...
...
@@ -222,20 +222,19 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
Block2ETileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
Block2ETileMap
&
block_2_etile_map
)
template
<
typename
Block2ETileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
,
const
BGridDesc_N_K
&
b_grid_desc_n_k
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
Block2ETileMap
&
block_2_etile_map
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_
ak0_m_ak1
.
GetLength
(
I
1
);
const
auto
N
=
b_grid_desc_
bk0_n_bk1
.
GetLength
(
I
1
);
const
auto
K
=
a_grid_desc_
ak0_m_ak1
.
GetLength
(
I
0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
const
auto
M
=
a_grid_desc_
m_k
.
GetLength
(
I
0
);
const
auto
N
=
b_grid_desc_
n_k
.
GetLength
(
I
0
);
const
auto
K
=
a_grid_desc_
m_k
.
GetLength
(
I
1
);
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)))
return
false
;
...
...
@@ -271,7 +270,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
remove_cvref_t
<
decltype
(
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
DefaultBGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment