Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5765ba51
Commit
5765ba51
authored
Dec 30, 2024
by
coderfeli
Browse files
Auto-calculate hard-coded params
parent
3f9dbcac
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
15 deletions
+14
-15
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
...id/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
+14
-15
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
View file @
5765ba51
...
@@ -124,21 +124,24 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -124,21 +124,24 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static
constexpr
auto
CShuffleBlockTransferScalarPerVector_NPerBlock
=
static
constexpr
auto
CShuffleBlockTransferScalarPerVector_NPerBlock
=
CDEShuffleBlockTransferScalarPerVectors
{}[
I0
];
CDEShuffleBlockTransferScalarPerVectors
{}[
I0
];
// K1 should be Number<...>
// K1 should be Number<...>
static
constexpr
auto
AK0Number
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
AK0Number
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0Number
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
BK0Number
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1Number
=
Number
<
AK1Value
>
{};
static
constexpr
auto
AK1Number
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1Number
=
Number
<
BK1Value
>
{};
static
constexpr
auto
BK1Number
=
Number
<
BK1Value
>
{};
static
constexpr
auto
BlockSizeNumber
=
Number
<
BlockSize
>
{};
static
constexpr
auto
BlockSizeNumber
=
Number
<
BlockSize
>
{};
static
constexpr
index_t
NLane
=
32
;
static
constexpr
index_t
NWave
=
4
;
static
constexpr
index_t
KLane
=
2
;
static
constexpr
index_t
KRepeat
=
8
;
static_assert
(
NLane
*
NWave
*
KLane
==
BlockSize
);
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
using
mfma_selector
=
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
,
ComputeTypeB
>
;
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
mfma_selector
::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
KLane
=
mfma_selector
::
GetKPerXdlops
()
/
mfma_selector
::
GetK1PerXdlops
();
static
constexpr
index_t
KRepeat
=
KPerBlock
/
KLane
/
KPack
;
static
constexpr
index_t
NLane
=
NPerXdl
;
static
constexpr
index_t
NWave
=
NPerBlock
/
NPerXdl
/
NXdlPerWave
;
static_assert
(
NLane
*
NWave
*
KLane
==
BlockSize
);
static_assert
(
NXdlPerWave
==
1
,
"only 1 validated now, tbd next week"
);
static
constexpr
auto
MakeDsGridPointer
()
static
constexpr
auto
MakeDsGridPointer
()
{
{
return
generate_tuple
(
return
generate_tuple
(
...
@@ -152,10 +155,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -152,10 +155,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
using
DsGridPointer
=
decltype
(
MakeDsGridPointer
());
using
DsGridPointer
=
decltype
(
MakeDsGridPointer
());
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
,
ComputeTypeB
>::
selected_mfma
.
k_per_blk
);
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
__host__
static
auto
CalculateGridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
)
__host__
static
auto
CalculateGridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
)
...
@@ -321,11 +320,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -321,11 +320,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__
__device__
static
auto
MakeBGridDescriptor_Preshuffled
(
index_t
N0
,
index_t
K0
)
__host__
__device__
static
auto
MakeBGridDescriptor_Preshuffled
(
index_t
N0
,
index_t
K0
)
{
{
constexpr
index_t
NKSWIZZLE_V
=
BlockSize
*
KPack
;
constexpr
index_t
NkSwizzle
=
BlockSize
*
KPack
;
constexpr
index_t
NKSWIZZLE_N
=
Number
<
NKSWIZZLE_V
>
{};
constexpr
index_t
NkSwizzleNumber
=
Number
<
NkSwizzle
>
{};
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
N0
,
K0
,
NKSWIZZLE_N
),
make_tuple
(
N0
,
K0
,
NkSwizzleNumber
),
make_tuple
(
K0
*
NKSWIZZLE_V
,
NKSWIZZLE_N
,
I1
)
make_tuple
(
K0
*
NkSwizzle
,
NkSwizzleNumber
,
I1
)
);
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment