Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f221c68e
Commit
f221c68e
authored
Feb 27, 2024
by
Jing Zhang
Browse files
merge navi3_ref
parent
37560a6d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
98 deletions
+72
-98
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+63
-88
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+9
-10
No files found.
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
f221c68e
...
@@ -66,8 +66,8 @@ struct BlockwiseGemmWMMA
...
@@ -66,8 +66,8 @@ struct BlockwiseGemmWMMA
// When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
// When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// permutation
// permutation
static
constexpr
index_t
A_KRow
=
AEnableLds
?
1
:
2
;
static
constexpr
index_t
A_KRow
=
2
;
static
constexpr
index_t
B_KRow
=
BEnableLds
?
1
:
2
;
static
constexpr
index_t
B_KRow
=
2
;
static
constexpr
index_t
A_K1
=
ABlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
A_K1
=
ABlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
B_K1
=
BBlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
B_K1
=
BBlockDesc
{}.
GetLength
(
I5
);
...
@@ -213,19 +213,20 @@ struct BlockwiseGemmWMMA
...
@@ -213,19 +213,20 @@ struct BlockwiseGemmWMMA
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
constexpr
auto
MAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
constexpr
auto
MSubGroup
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I0
];
constexpr
auto
AccStride
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I3
];
constexpr
auto
NThreadPerSubGroup
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I1
];
return
make_naive_tensor_descriptor
(
constexpr
auto
MAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
return
make_naive_tensor_descriptor_packed
(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
// |NThreadPerSubGroup |MAccVgprs
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
NRepeat
>
{},
I1
,
I1
,
MAccVgprs
),
make_tuple
(
Number
<
MRepeat
>
{},
make_tuple
(
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
I1
,
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
MSubGroup
,
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
Number
<
NRepeat
>
{},
MAccVgprs
*
AccStride
,
I1
,
MAccVgprs
*
AccStride
,
NThreadPerSubGroup
,
MAccVgprs
*
AccStride
,
MAccVgprs
));
AccStride
));
}
}
template
<
typename
CGridDesc_M_N
>
template
<
typename
CGridDesc_M_N
>
...
@@ -324,30 +325,25 @@ struct BlockwiseGemmWMMA
...
@@ -324,30 +325,25 @@ struct BlockwiseGemmWMMA
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
b_thread_buf
);
vector_type
<
FloatA
,
KPack
>
a_thread_vec
;
vector_type
<
FloatA
,
KPack
/
A_KRow
>
a_thread_vec
;
vector_type
<
FloatB
,
KPack
>
b_thread_vec
;
vector_type
<
FloatB
,
KPack
/
B_KRow
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
KPack
/
A_KRow
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
A_K1
/
A_KRow
,
make_tuple
(
i
/
A_K1
,
m0
,
0
,
0
,
0
,
i
%
A_K1
))
>
{}];
m0
,
});
0
,
(
i
/
A_K1
)
%
A_KRow
,
static_for
<
0
,
KPack
/
B_KRow
,
1
>
{}([
&
](
auto
i
)
{
0
,
i
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
/
B_KRow
,
make_tuple
(
i
/
B_K1
,
n0
,
0
,
0
,
0
,
i
%
B_K1
))
>
{}];
n0
,
0
,
(
i
/
B_K1
)
%
B_KRow
,
0
,
i
%
B_K1
))
>
{}];
});
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
using
wmma_input_type_a
=
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
>::
type
;
typename
vector_type
<
FloatA
,
WmmaK
/
A_KRow
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
/
B_KRow
>::
type
;
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
...
@@ -383,30 +379,25 @@ struct BlockwiseGemmWMMA
...
@@ -383,30 +379,25 @@ struct BlockwiseGemmWMMA
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
a_thread_buf
);
vector_type
<
FloatA
,
KPack
>
a_thread_vec
;
vector_type
<
FloatA
,
KPack
/
A_KRow
>
a_thread_vec
;
vector_type
<
FloatB
,
KPack
>
b_thread_vec
;
vector_type
<
FloatB
,
KPack
/
B_KRow
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
KPack
/
A_KRow
,
1
>
{}([
&
](
auto
i
)
{
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
/
B_KRow
,
n0
,
0
,
(
i
/
B_K1
)
%
B_KRow
,
0
,
i
%
B_K1
))
>
{}];
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
A_K1
/
A_KRow
,
make_tuple
(
i
/
A_K1
,
m0
,
0
,
0
,
0
,
i
%
A_K1
))
>
{}];
m0
,
});
0
,
(
i
/
A_K1
)
%
A_KRow
,
static_for
<
0
,
KPack
/
B_KRow
,
1
>
{}([
&
](
auto
i
)
{
0
,
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
i
%
A_K1
))
>
{}];
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
,
n0
,
0
,
0
,
0
,
i
%
B_K1
))
>
{}];
});
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
using
wmma_input_type_a
=
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
>::
type
;
typename
vector_type
<
FloatA
,
WmmaK
/
A_KRow
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
/
B_KRow
>::
type
;
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
...
@@ -422,33 +413,23 @@ struct BlockwiseGemmWMMA
...
@@ -422,33 +413,23 @@ struct BlockwiseGemmWMMA
}
}
protected:
protected:
static
constexpr
auto
a_thread_desc_
=
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
A_K1
/
A_KRow
>
{},
make_tuple
(
Number
<
KPack
/
A_K1
/
A_KRow
>
{},
Number
<
MRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
A_K1
>
{}),
Number
<
MRepeat
>
{},
make_tuple
(
Number
<
A_K1
>
{},
I1
,
Number
<
KPack
/
A_KRow
>
{},
Number
<
A_KRow
>
{},
Number
<
A_K1
>
{},
I1
,
Number
<
A_K1
>
{},
Number
<
A_K1
>
{}),
Number
<
A_K1
>
{},
make_tuple
(
Number
<
A_K1
*
A_KRow
>
{},
Number
<
1
>
{}));
Number
<
KPack
>
{},
Number
<
A_K1
*
A_KRow
>
{},
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
Number
<
A_K1
>
{},
make_tuple
(
Number
<
KPack
/
B_K1
/
B_KRow
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
B_K1
>
{}),
Number
<
A_K1
>
{},
make_tuple
(
Number
<
B_K1
>
{},
Number
<
1
>
{}));
Number
<
KPack
/
B_KRow
>
{},
Number
<
B_K1
>
{},
static
constexpr
auto
b_thread_desc_
=
Number
<
B_K1
>
{},
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
B_K1
/
B_KRow
>
{},
Number
<
B_K1
>
{},
Number
<
NRepeat
>
{},
Number
<
1
>
{}));
I1
,
Number
<
B_KRow
>
{},
I1
,
Number
<
B_K1
>
{}),
make_tuple
(
Number
<
B_K1
*
B_KRow
>
{},
Number
<
KPack
>
{},
Number
<
B_K1
*
B_KRow
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
1
>
{}));
// C[M, N, NumRegWMMA]
// C[M, N, NumRegWMMA]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
...
@@ -465,7 +446,7 @@ struct BlockwiseGemmWMMA
...
@@ -465,7 +446,7 @@ struct BlockwiseGemmWMMA
FloatA
,
FloatA
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_thread_desc_
),
decltype
(
a_thread_desc_
),
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
A_KRow
,
1
,
A_K1
>
,
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
5
,
A_K1
,
A_K1
,
...
@@ -475,7 +456,7 @@ struct BlockwiseGemmWMMA
...
@@ -475,7 +456,7 @@ struct BlockwiseGemmWMMA
template
<
>
template
<
>
struct
AThreadCopySelector
<
false
>
struct
AThreadCopySelector
<
false
>
{
{
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic
_InterRow
<
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic
<
FloatA
,
FloatA
,
FloatA
,
FloatA
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
...
@@ -484,10 +465,7 @@ struct BlockwiseGemmWMMA
...
@@ -484,10 +465,7 @@ struct BlockwiseGemmWMMA
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
5
,
A_K1
,
A_K1
>
;
0x76543210
,
0xfedcba98
,
TransposeC
?
false
:
true
>
;
};
};
template
<
bool
EnableLds
>
template
<
bool
EnableLds
>
...
@@ -501,7 +479,7 @@ struct BlockwiseGemmWMMA
...
@@ -501,7 +479,7 @@ struct BlockwiseGemmWMMA
FloatB
,
FloatB
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
B_KRow
,
1
,
B_K1
>
,
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
5
,
B_K1
,
B_K1
,
...
@@ -511,7 +489,7 @@ struct BlockwiseGemmWMMA
...
@@ -511,7 +489,7 @@ struct BlockwiseGemmWMMA
template
<
>
template
<
>
struct
BThreadCopySelector
<
false
>
struct
BThreadCopySelector
<
false
>
{
{
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic
_InterRow
<
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic
<
FloatB
,
FloatB
,
FloatB
,
FloatB
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
...
@@ -520,10 +498,7 @@ struct BlockwiseGemmWMMA
...
@@ -520,10 +498,7 @@ struct BlockwiseGemmWMMA
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
5
,
B_K1
,
B_K1
>
;
0x76543210
,
0xfedcba98
,
TransposeC
?
true
:
false
>
;
};
};
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
f221c68e
...
@@ -141,8 +141,8 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
...
@@ -141,8 +141,8 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
// Wave mode dependent propety
// Wave mode dependent propety
static
constexpr
index_t
wave_size
=
Number
<
WaveSize
>
{};
static
constexpr
index_t
wave_size
=
Number
<
WaveSize
>
{};
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
static
constexpr
index_t
num_src_a_vgprs_per_wave
=
m
_per_wmma
*
src_a_data_size
/
4
;
static
constexpr
index_t
num_src_a_vgprs_per_wave
=
k
_per_wmma
/
2
*
src_a_data_size
/
4
;
static
constexpr
index_t
num_src_b_vgprs_per_wave
=
n
_per_wmma
*
src_b_data_size
/
4
;
static
constexpr
index_t
num_src_b_vgprs_per_wave
=
k
_per_wmma
/
2
*
src_b_data_size
/
4
;
// * num_acc_vgprs_per_wave alone M direction
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
// * num_subgroups alone M direction
static
constexpr
index_t
num_acc_vgprs_per_wave
=
static
constexpr
index_t
num_acc_vgprs_per_wave
=
...
@@ -390,7 +390,7 @@ struct WmmaSelector
...
@@ -390,7 +390,7 @@ struct WmmaSelector
static_assert
(
selected_wmma
.
k_per_wmma
==
16
,
"WRONG! WMMA_M must equal to 16"
);
static_assert
(
selected_wmma
.
k_per_wmma
==
16
,
"WRONG! WMMA_M must equal to 16"
);
static_assert
(
selected_wmma
.
wave_size
*
selected_wmma
.
num_acc_vgprs_per_wave
*
static_assert
(
selected_wmma
.
wave_size
*
selected_wmma
.
num_acc_vgprs_per_wave
*
selected_wmma
.
acc_data_size
*
selected_wmma
.
acc_pack_number
==
selected_wmma
.
acc_data_size
==
selected_wmma
.
m_per_wmma
*
selected_wmma
.
n_per_wmma
*
4
,
selected_wmma
.
m_per_wmma
*
selected_wmma
.
n_per_wmma
*
4
,
"WRONG! Invalid Number of Accumulator Register"
);
"WRONG! Invalid Number of Accumulator Register"
);
}
}
...
@@ -510,7 +510,7 @@ struct WmmaGemm
...
@@ -510,7 +510,7 @@ struct WmmaGemm
__device__
static
constexpr
index_t
GetRegSizePerWmma
()
__device__
static
constexpr
index_t
GetRegSizePerWmma
()
{
{
return
wmma_instr
.
num_acc_vgprs_per_wave
*
wmma_instr
.
acc_pack_number
;
return
wmma_instr
.
num_acc_vgprs_per_wave
;
}
}
__device__
static
constexpr
index_t
GetWaveSize
()
{
return
wmma_instr
.
wave_size
;
}
__device__
static
constexpr
index_t
GetWaveSize
()
{
return
wmma_instr
.
wave_size
;
}
...
@@ -566,12 +566,14 @@ struct WmmaGemm
...
@@ -566,12 +566,14 @@ struct WmmaGemm
__host__
__device__
static
auto
CalculateAThreadOriginDataIndex
()
__host__
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
{
return
TransposeC
?
GetLaneIdUnderSubGroup
()
:
GetSwizzledLaneIdLow
();
// return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
return
GetLaneIdUnderSubGroup
();
}
}
__host__
__device__
static
auto
CalculateBThreadOriginDataIndex
()
__host__
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
{
return
TransposeC
?
GetSwizzledLaneIdLow
()
:
GetLaneIdUnderSubGroup
();
// return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
return
GetLaneIdUnderSubGroup
();
}
}
__device__
static
CIndex
GetBeginOfThreadBlk
()
__device__
static
CIndex
GetBeginOfThreadBlk
()
...
@@ -597,10 +599,7 @@ struct WmmaGemm
...
@@ -597,10 +599,7 @@ struct WmmaGemm
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
()
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
()
{
{
return
make_tuple
(
I1
,
return
make_tuple
(
I1
,
I1
,
Number
<
wmma_instr
.
num_acc_vgprs_per_wave
>
{});
I1
,
Number
<
wmma_instr
.
num_acc_vgprs_per_wave
>
{},
Number
<
wmma_instr
.
acc_pack_number
>
{});
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment