Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
fc62babb
Commit
fc62babb
authored
Mar 09, 2024
by
Jing Zhang
Browse files
seperate gfx12 blockwise_gemm
parent
f3111877
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
354 additions
and
49 deletions
+354
-49
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+347
-45
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+7
-4
No files found.
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
fc62babb
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
namespace
ck
{
namespace
ck
{
#ifdef __gfx12__
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatA
,
typename
FloatB
,
typename
FloatB
,
...
@@ -66,13 +67,8 @@ struct BlockwiseGemmWMMA
...
@@ -66,13 +67,8 @@ struct BlockwiseGemmWMMA
// When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
// When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// permutation
// permutation
#ifdef __gfx12__
static
constexpr
index_t
A_KRow
=
2
;
static
constexpr
index_t
A_KRow
=
1
;
static
constexpr
index_t
B_KRow
=
2
;
static
constexpr
index_t
B_KRow
=
1
;
#else
static
constexpr
index_t
A_KRow
=
AEnableLds
?
1
:
2
;
static
constexpr
index_t
B_KRow
=
BEnableLds
?
1
:
2
;
#endif
static
constexpr
index_t
A_K1
=
ABlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
A_K1
=
ABlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
B_K1
=
BBlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
B_K1
=
BBlockDesc
{}.
GetLength
(
I5
);
...
@@ -114,11 +110,7 @@ struct BlockwiseGemmWMMA
...
@@ -114,11 +110,7 @@ struct BlockwiseGemmWMMA
const
auto
WMMA_a_idx
=
wmma_gemm
.
CalculateAThreadOriginDataIndex
();
const
auto
WMMA_a_idx
=
wmma_gemm
.
CalculateAThreadOriginDataIndex
();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
#ifdef __gfx12__
return
make_tuple
(
0
,
0
,
waveId_m
,
wmma_gemm
.
GetSubGroupId
(),
WMMA_a_idx
,
0
);
return
make_tuple
(
0
,
0
,
waveId_m
,
wmma_gemm
.
GetSubGroupId
(),
WMMA_a_idx
,
0
);
#else
return
make_tuple
(
0
,
0
,
waveId_m
,
0
,
WMMA_a_idx
,
0
);
#endif
}
}
else
else
{
{
...
@@ -135,11 +127,7 @@ struct BlockwiseGemmWMMA
...
@@ -135,11 +127,7 @@ struct BlockwiseGemmWMMA
const
auto
WMMA_b_idx
=
wmma_gemm
.
CalculateBThreadOriginDataIndex
();
const
auto
WMMA_b_idx
=
wmma_gemm
.
CalculateBThreadOriginDataIndex
();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
#ifdef __gfx12__
return
make_tuple
(
0
,
0
,
waveId_n
,
wmma_gemm
.
GetSubGroupId
(),
WMMA_b_idx
,
0
);
return
make_tuple
(
0
,
0
,
waveId_n
,
wmma_gemm
.
GetSubGroupId
(),
WMMA_b_idx
,
0
);
#else
return
make_tuple
(
0
,
0
,
waveId_n
,
0
,
WMMA_b_idx
,
0
);
#endif
}
}
else
else
{
{
...
@@ -203,6 +191,9 @@ struct BlockwiseGemmWMMA
...
@@ -203,6 +191,9 @@ struct BlockwiseGemmWMMA
static_assert
(
MPerBlock
%
(
MPerWMMA
*
MRepeat
)
==
0
&&
static_assert
(
MPerBlock
%
(
MPerWMMA
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerWMMA
*
NRepeat
)
==
0
,
NPerBlock
%
(
NPerWMMA
*
NRepeat
)
==
0
,
"wrong!"
);
"wrong!"
);
static_assert
(
AEnableLds
==
true
,
"only support EnableLds"
);
static_assert
(
BEnableLds
==
true
,
"only support EnableLds"
);
}
}
// transposed WMMA output C' = B' * A'
// transposed WMMA output C' = B' * A'
...
@@ -303,7 +294,6 @@ struct BlockwiseGemmWMMA
...
@@ -303,7 +294,6 @@ struct BlockwiseGemmWMMA
static
constexpr
ABlockDesc
a_block_desc_k0_m0_m1_m2_k1
;
static
constexpr
ABlockDesc
a_block_desc_k0_m0_m1_m2_k1
;
static
constexpr
BBlockDesc
b_block_desc_k0_n0_n1_n2_k1
;
static
constexpr
BBlockDesc
b_block_desc_k0_n0_n1_n2_k1
;
#ifdef __gfx12__
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
...
@@ -428,7 +418,347 @@ struct BlockwiseGemmWMMA
...
@@ -428,7 +418,347 @@ struct BlockwiseGemmWMMA
});
});
}
}
}
}
protected:
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
A_K1
/
A_KRow
>
{},
Number
<
MRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
A_K1
>
{}),
make_tuple
(
Number
<
A_K1
>
{},
Number
<
KPack
/
A_KRow
>
{},
Number
<
A_K1
>
{},
Number
<
A_K1
>
{},
Number
<
A_K1
>
{},
Number
<
1
>
{}));
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
B_K1
/
B_KRow
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
B_K1
>
{}),
make_tuple
(
Number
<
B_K1
>
{},
Number
<
KPack
/
B_KRow
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
1
>
{}));
// C[M, N, NumRegWMMA]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
wmma_gemm
.
GetRegSizePerWmma
()));
template
<
bool
EnableLds
>
struct
AThreadCopySelector
;
template
<
>
struct
AThreadCopySelector
<
true
>
{
using
type
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_thread_desc_
),
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
A_K1
,
A_K1
>
;
};
template
<
bool
EnableLds
>
struct
BThreadCopySelector
;
template
<
>
struct
BThreadCopySelector
<
true
>
{
using
type
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
B_K1
,
B_K1
>
;
};
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
BThreadCopySelector
<
BEnableLds
>::
type
b_thread_copy_
;
};
#else
#else
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatB
,
typename
FloatAcc
,
typename
ABlockDesc
,
typename
BBlockDesc
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerWMMA
,
index_t
NPerWMMA
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
,
bool
AEnableLds
=
true
,
bool
BEnableLds
=
true
,
bool
TransposeC
=
false
>
/* Option: Read from LDS, big buffer hold all threads required data
* Source
* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* Destination
* C, non-transpose
* thread level: MRepeat x NRepeat x MAccVgprs
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*
* Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
* Source:
* A(if skip LDS): MRepeat x KPack
* B(if skip LDS): NRepeat x KPack
* Destination
* C, non-transpose
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
*/
struct
BlockwiseGemmWMMA
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
WmmaK
=
Number
<
16
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
// Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
static
constexpr
index_t
WaveSize
=
32
;
// When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// permutation
static
constexpr
index_t
A_KRow
=
AEnableLds
?
1
:
2
;
static
constexpr
index_t
B_KRow
=
BEnableLds
?
1
:
2
;
static
constexpr
index_t
A_K1
=
ABlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
B_K1
=
BBlockDesc
{}.
GetLength
(
I5
);
static
constexpr
auto
wmma_gemm
=
WmmaGemm
<
FloatA
,
FloatB
,
FloatAcc
,
MPerWMMA
,
NPerWMMA
,
KPack
,
TransposeC
>
{};
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWMMA
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWMMA
);
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
,
MRepeat
*
NRepeat
,
wmma_gemm
.
GetRegSizePerWmma
(),
true
>
c_thread_buf_
;
__host__
__device__
constexpr
auto
&
GetCThreadBuffer
()
{
return
c_thread_buf_
;
}
__device__
static
auto
GetWaveIdx
()
{
const
index_t
thread_id
=
ThisThreadBlock
::
GetThreadId
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
MWaves
,
NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
// Default, Block buffer in LDS, thread level offset enabled
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
if
constexpr
(
AEnableLds
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
WMMA_a_idx
=
wmma_gemm
.
CalculateAThreadOriginDataIndex
();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return
make_tuple
(
0
,
0
,
waveId_m
,
0
,
WMMA_a_idx
,
0
);
}
else
{
return
make_tuple
(
0
,
0
,
0
,
0
,
0
,
0
);
}
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
if
constexpr
(
BEnableLds
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
WMMA_b_idx
=
wmma_gemm
.
CalculateBThreadOriginDataIndex
();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return
make_tuple
(
0
,
0
,
waveId_n
,
0
,
WMMA_b_idx
,
0
);
}
else
{
return
make_tuple
(
0
,
0
,
0
,
0
,
0
,
0
);
}
}
template
<
index_t
m0
,
index_t
n0
>
__device__
static
auto
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
wmma_gemm
.
GetBeginOfThreadBlk
();
constexpr
auto
mrepeat_mwave_mperWMMA_to_m_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
constexpr
auto
nrepeat_nwave_nperWMMA_to_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
const
index_t
c_thread_m
=
mrepeat_mwave_mperWMMA_to_m_adaptor
.
CalculateBottomIndex
(
make_tuple
(
m0
,
waveId_m
,
blk_idx
[
I0
]))[
I0
];
const
index_t
c_thread_n
=
nrepeat_nwave_nperWMMA_to_n_adaptor
.
CalculateBottomIndex
(
make_tuple
(
n0
,
waveId_n
,
blk_idx
[
I1
]))[
I0
];
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
template
<
index_t
m0
,
index_t
n0
>
__device__
static
auto
CalculateCThreadOriginDataIndex7D
(
Number
<
m0
>
,
Number
<
n0
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
wmma_gemm
.
GetBeginOfThreadBlk3D
();
return
make_tuple
(
Number
<
m0
>
{},
waveId_m
,
blk_idx
[
I0
],
Number
<
n0
>
{},
waveId_n
,
blk_idx
[
I1
],
blk_idx
[
I2
]);
}
using
Tuple6
=
decltype
(
CalculateAThreadOriginDataIndex
());
__host__
__device__
BlockwiseGemmWMMA
(
Tuple6
a_origin
=
CalculateAThreadOriginDataIndex
(),
Tuple6
b_origin
=
CalculateBThreadOriginDataIndex
())
:
a_thread_copy_
(
a_origin
),
b_thread_copy_
(
b_origin
)
{
static_assert
(
ABlockDesc
::
IsKnownAtCompileTime
()
&&
BBlockDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
ThisThreadBlock
::
GetNumOfThread
()
==
MWaves
*
NWaves
*
WaveSize
,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
MPerBlock
%
(
MPerWMMA
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerWMMA
*
NRepeat
)
==
0
,
"wrong!"
);
}
// transposed WMMA output C' = B' * A'
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs
()
{
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
constexpr
auto
NAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
return
make_naive_tensor_descriptor_packed
(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
NRepeat
>
{},
I1
,
I1
,
NAccVgprs
));
}
// Thread level, register decriptor. Vector-write
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
{
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
constexpr
auto
MAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
constexpr
auto
AccStride
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor
(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
NRepeat
>
{},
I1
,
I1
,
MAccVgprs
),
make_tuple
(
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
AccStride
));
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerWMMA
),
MWaves
,
MPerWMMA
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerWMMA
),
NWaves
,
NPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
);
}
// transposed WMMA output C' = B' * A'
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs
()
{
constexpr
auto
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{},
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs
(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
);
}
// Provide dimension size
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
{
constexpr
auto
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{},
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
);
}
// Describe how data allocated in thread copy src buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static
constexpr
ABlockDesc
a_block_desc_k0_m0_m1_m2_k1
;
static
constexpr
BBlockDesc
b_block_desc_k0_n0_n1_n2_k1
;
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
...
@@ -560,28 +890,8 @@ struct BlockwiseGemmWMMA
...
@@ -560,28 +890,8 @@ struct BlockwiseGemmWMMA
});
});
}
}
}
}
#endif
protected:
protected:
#ifdef __gfx12__
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
A_K1
/
A_KRow
>
{},
Number
<
MRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
A_K1
>
{}),
make_tuple
(
Number
<
A_K1
>
{},
Number
<
KPack
/
A_KRow
>
{},
Number
<
A_K1
>
{},
Number
<
A_K1
>
{},
Number
<
A_K1
>
{},
Number
<
1
>
{}));
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
B_K1
/
B_KRow
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
B_K1
>
{}),
make_tuple
(
Number
<
B_K1
>
{},
Number
<
KPack
/
B_KRow
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
1
>
{}));
#else
static
constexpr
auto
a_thread_desc_
=
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
A_K1
/
A_KRow
>
{},
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
A_K1
/
A_KRow
>
{},
Number
<
MRepeat
>
{},
Number
<
MRepeat
>
{},
...
@@ -609,7 +919,6 @@ struct BlockwiseGemmWMMA
...
@@ -609,7 +919,6 @@ struct BlockwiseGemmWMMA
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
1
>
{}));
Number
<
1
>
{}));
#endif
// C[M, N, NumRegWMMA]
// C[M, N, NumRegWMMA]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
...
@@ -626,11 +935,7 @@ struct BlockwiseGemmWMMA
...
@@ -626,11 +935,7 @@ struct BlockwiseGemmWMMA
FloatA
,
FloatA
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_thread_desc_
),
decltype
(
a_thread_desc_
),
#ifdef __gfx12__
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
1
,
1
,
A_K1
>
,
#else
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
A_KRow
,
1
,
A_K1
>
,
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
A_KRow
,
1
,
A_K1
>
,
#endif
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
5
,
A_K1
,
A_K1
,
...
@@ -666,11 +971,7 @@ struct BlockwiseGemmWMMA
...
@@ -666,11 +971,7 @@ struct BlockwiseGemmWMMA
FloatB
,
FloatB
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
#ifdef __gfx12__
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
#else
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
B_KRow
,
1
,
B_K1
>
,
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
B_KRow
,
1
,
B_K1
>
,
#endif
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
5
,
B_K1
,
B_K1
,
...
@@ -698,5 +999,6 @@ struct BlockwiseGemmWMMA
...
@@ -698,5 +999,6 @@ struct BlockwiseGemmWMMA
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
BThreadCopySelector
<
BEnableLds
>::
type
b_thread_copy_
;
typename
BThreadCopySelector
<
BEnableLds
>::
type
b_thread_copy_
;
};
};
#endif
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
fc62babb
...
@@ -94,13 +94,16 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -94,13 +94,16 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
(
MWaves
==
1
&&
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
?
false
:
true
;
(
MWaves
==
1
&&
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
?
false
:
true
;
// If true, LDS is used unconditionally
// If true, LDS is used unconditionally
#ifdef __gfx12__
static
constexpr
auto
AEnableLds_manu
=
true
;
static
constexpr
auto
BEnableLds_manu
=
true
;
#else
static
constexpr
auto
AEnableLds_manu
=
false
;
static
constexpr
auto
AEnableLds_manu
=
false
;
static
constexpr
auto
BEnableLds_manu
=
false
;
static
constexpr
auto
BEnableLds_manu
=
false
;
#endif
static
constexpr
auto
AEnableLds
=
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
false
;
// AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
BEnableLds
=
true
;
// BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
static
constexpr
auto
matrix_padder
=
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment