gaoqiong / composable_kernel · Commits · c713d224
"src/libtorchaudio/forced_align/compute.h" did not exist on "e7935cff440360e29116fcd28da4c0ddc770b866"
Commit c713d224, authored May 19, 2023 by aska-0096

Update low-level abstraction of blockwise GEMM WMMA

parent 2ec3f4c3
Showing 7 changed files with 307 additions and 207 deletions
example/01_gemm/gemm_wmma_fp16.cpp  (+2 −2)
include/ck/host_utility/kernel_launch.hpp  (+1 −1)
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp  (+196 −107)
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp  (+18 −11)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp  (+2 −10)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp  (+83 −72)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  (+5 −4)
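The common thread through these changes is that the per-WMMA slice of the K dimension is no longer a flat K0 x K1 block: when an operand bypasses LDS, K is re-split into K0PerWmma x KRow x K1, where KRow counts the 16-lane rows that share one WMMA K slice. The following sketch is not part of the commit; it only illustrates, with WmmaK = 16 from the source and assumed values K1 = 8 and KRow = 2, the coordinate arithmetic that this extra axis introduces throughout the diff.

    // Standalone illustration (not from the commit): how a flat index i in
    // [0, WmmaK) maps onto the (K0PerWmma, KRow, K1) coordinates used by the
    // updated descriptors.  WmmaK = 16 is from the source; K1 = 8 and KRow = 2
    // are assumed example values (KRow = 2 is the non-LDS case).
    #include <cassert>
    #include <cstdio>

    int main()
    {
        constexpr int WmmaK     = 16;
        constexpr int K1        = 8; // assumed vector width
        constexpr int KRow      = 2; // two 16-lane rows share one WMMA K slice
        constexpr int K0PerWmma = WmmaK / KRow / K1; // = 1 here

        for(int i = 0; i < WmmaK; ++i)
        {
            const int k0   = i / K1 / KRow;   // slowest axis
            const int kRow = (i / K1) % KRow; // which row owns this element
            const int k1   = i % K1;          // fastest (vector) axis

            // The three coordinates must reassemble the flat index.
            assert(k0 < K0PerWmma && (k0 * KRow + kRow) * K1 + k1 == i);
            std::printf("i=%2d -> (k0=%d, kRow=%d, k1=%d)\n", i, k0, kRow, k1);
        }
        return 0;
    }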
example/01_gemm/gemm_wmma_fp16.cpp
include/ck/host_utility/kernel_launch.hpp
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
@@ -55,6 +55,7 @@ struct BlockwiseGemmWMMA
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};
     static constexpr auto I4 = Number<4>{};
+    static constexpr auto I5 = Number<5>{};
     static constexpr auto WmmaK = Number<16>{};

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -62,8 +63,13 @@ struct BlockwiseGemmWMMA
     // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
     static constexpr index_t WaveSize = 32;

-    static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
-    static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);
+    // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
+    // When not use LDS, each Row read half of whole data from source buffer, exchange the data via
+    // permutation
+    static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
+    static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
+    static constexpr index_t A_K1   = ABlockDesc{}.GetLength(I5);
+    static constexpr index_t B_K1   = BBlockDesc{}.GetLength(I5);

     static constexpr auto wmma_gemm =
         WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
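For reference, the new A_KRow/B_KRow constants record how many 16-lane rows one wave's K slice spans: 1 when the operand is staged through LDS (each row reads the whole slice), 2 when it is read straight into VGPRs (each row reads half and the halves are exchanged by permutation). A minimal sketch of that selection, written around a hypothetical helper function rather than the struct's template parameters:

    // Sketch only: mirrors the AEnableLds/BEnableLds -> KRow mapping introduced
    // above, outside of the BlockwiseGemmWMMA template.  kRowFor() is a
    // hypothetical helper, not part of composable_kernel.
    #include <cstdio>

    constexpr int kRowFor(bool enable_lds)
    {
        // Through LDS: one row holds the full K slice.  Direct-to-VGPR: the slice
        // is split across two 16-lane rows and exchanged via permutation.
        return enable_lds ? 1 : 2;
    }

    static_assert(kRowFor(true) == 1, "LDS path keeps a single K row");
    static_assert(kRowFor(false) == 2, "VGPR path splits K across two rows");

    int main()
    {
        std::printf("A_KRow (LDS)  = %d\n", kRowFor(true));
        std::printf("A_KRow (VGPR) = %d\n", kRowFor(false));
        return 0;
    }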
@@ -71,10 +77,6 @@ struct BlockwiseGemmWMMA
     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);

-    // Read from Lds, duplicate Twice, Read from VGPR, no duplication.
-    static constexpr index_t A_Data_Duplicated_Rate = AEnableLds ? 2 : 1;
-    static constexpr index_t B_Data_Duplicated_Rate = BEnableLds ? 2 : 1;
-
     StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                               FloatAcc,
                               MRepeat * NRepeat,
...
@@ -105,12 +107,12 @@ struct BlockwiseGemmWMMA
...
@@ -105,12 +107,12 @@ struct BlockwiseGemmWMMA
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
WMMA_a_idx
=
wmma_gemm
.
CalculateAThreadOriginDataIndex
();
const
auto
WMMA_a_idx
=
wmma_gemm
.
CalculateAThreadOriginDataIndex
();
// |KRepeat |MRepeat|MWave |MLane
|KPack
// |KRepeat |MRepeat|MWave
|KRow
|MLane |KPack
return
make_tuple
(
0
,
0
,
waveId_m
,
WMMA_a_idx
,
0
);
return
make_tuple
(
0
,
0
,
waveId_m
,
0
,
WMMA_a_idx
,
0
);
}
}
else
else
{
{
return
make_tuple
(
0
,
0
,
0
,
0
,
0
);
return
make_tuple
(
0
,
0
,
0
,
0
,
0
,
0
);
}
}
}
}
@@ -122,12 +124,12 @@ struct BlockwiseGemmWMMA
             const auto waveId_n = wave_idx[I1];

             const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
-            //  |KRepeat |NRepeat|Nwave |NLane |KPack
-            return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
+            //  |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
+            return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0);
         }
         else
         {
-            return make_tuple(0, 0, 0, 0, 0);
+            return make_tuple(0, 0, 0, 0, 0, 0);
         }
     }
@@ -173,9 +175,9 @@ struct BlockwiseGemmWMMA
            Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
     }

-    using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
+    using Tuple6 = decltype(CalculateAThreadOriginDataIndex());

-    __host__ __device__ BlockwiseGemmWMMA(Tuple5 a_origin = CalculateAThreadOriginDataIndex(),
-                                          Tuple5 b_origin = CalculateBThreadOriginDataIndex())
+    __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
+                                          Tuple6 b_origin = CalculateBThreadOriginDataIndex())
         : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
     {
         static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
@@ -224,18 +226,6 @@ struct BlockwiseGemmWMMA
                                   MAccVgprs * AccStride,
                                   MAccVgprs * AccStride,
                                   AccStride));
-#if 0
-        return make_naive_tensor_descriptor_packed(
-            //        |MRepeat           |MWave |MSubGroup |NRepeat           |NWave
-            //        |NThreadPerSubGroup |MAccVgprs
-            make_tuple(Number<MRepeat>{},
-                       I1,
-                       MSubGroup,
-                       Number<NRepeat>{},
-                       I1,
-                       NThreadPerSubGroup,
-                       MAccVgprs));
-#endif
     }

     template <typename CGridDesc_M_N>
@@ -312,36 +302,26 @@ struct BlockwiseGemmWMMA
         // basic intrinsic to determine loopover direction
         if constexpr(MRepeat < NRepeat)
         {
-            static_for<0, MRepeat, 1>{}([&](auto m0) {
+            static_for<0, NRepeat, 1>{}([&](auto n0) {
                 static_for<0, KPerBlock / WmmaK, 1>{}(
                     [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
+                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                         // read A
                         a_thread_copy_.Run(
                             a_block_desc_k0_m0_m1_m2_k1,
-                            make_tuple(Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{}, m0, I0, I0, I0),
+                            make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                             a_block_buf,
                             a_thread_desc_,
-                            make_tuple(I0, m0, I0, I0, I0),
+                            make_tuple(I0, m0, I0, I0, I0, I0),
                             a_thread_buf);
-                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                         // read B
                         b_thread_copy_.Run(
                             b_block_desc_k0_n0_n1_n2_k1,
-                            make_tuple(Number<k * WmmaK / B_K1 * B_Data_Duplicated_Rate / 2>{}, n0, I0, I0, I0),
+                            make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                             b_block_buf,
                             b_thread_desc_,
-                            make_tuple(I0, n0, I0, I0, I0),
+                            make_tuple(I0, n0, I0, I0, I0, I0),
                             b_thread_buf);

                         vector_type<FloatA, WmmaK> a_thread_vec;
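The new source K0 offset is arithmetically identical to the old one: with A_Data_Duplicated_Rate = AEnableLds ? 2 : 1 and A_KRow = AEnableLds ? 1 : 2, the expressions k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2 and k * WmmaK / A_K1 / A_KRow agree on both paths; the rewrite simply names the quantity after the new KRow axis. A small standalone check (WmmaK = 16 from the source, A_K1 = 8 assumed):

    // Sketch only: verifies the old and new K0-offset expressions from this hunk
    // agree for both the LDS and the direct-VGPR path.  WmmaK = 16 comes from
    // the source; A_K1 = 8 and the k range are assumed example values.
    #include <cassert>
    #include <cstdio>

    int main()
    {
        constexpr int WmmaK = 16;
        constexpr int A_K1  = 8;

        for(int lds = 0; lds <= 1; ++lds)
        {
            const bool AEnableLds             = (lds == 1);
            const int  A_Data_Duplicated_Rate = AEnableLds ? 2 : 1; // old constant
            const int  A_KRow                 = AEnableLds ? 1 : 2; // new constant

            for(int k = 0; k < 4; ++k)
            {
                const int old_k0 = k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2;
                const int new_k0 = k * WmmaK / A_K1 / A_KRow;
                assert(old_k0 == new_k0);
                std::printf("lds=%d k=%d -> K0 start %d\n", lds, k, new_k0);
            }
        }
        return 0;
    }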
@@ -350,12 +330,100 @@ struct BlockwiseGemmWMMA
                         static_for<0, WmmaK, 1>{}([&](auto i) {
-                            a_thread_vec.template AsType<FloatA>()(i) =
-                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                                    make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
-                            b_thread_vec.template AsType<FloatB>()(i) =
-                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                                    make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
+                            a_thread_vec.template AsType<FloatA>()(i) =
+                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
+                                    i / A_K1 / A_KRow, m0, 0, (i / A_K1) % A_KRow, 0, i % A_K1))>{}];
+                            b_thread_vec.template AsType<FloatB>()(i) =
+                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
+                                    i / B_K1 / B_KRow, n0, 0, (i / B_K1) % B_KRow, 0, i % B_K1))>{}];
                         });
#if 0
if (get_thread_local_1d_id() == 0){
printf("repeat: m,n,k:(%02d, %02d, %02d) a_thread_buf: %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x\n",
m0.value, n0.value, k.value,
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(0 / A_K1, m0, 0, 0, 0 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(1 / A_K1, m0, 0, 0, 1 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(2 / A_K1, m0, 0, 0, 2 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(3 / A_K1, m0, 0, 0, 3 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(4 / A_K1, m0, 0, 0, 4 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(5 / A_K1, m0, 0, 0, 5% A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(6 / A_K1, m0, 0, 0, 6 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(7 / A_K1, m0, 0, 0, 7 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(8 / A_K1, m0, 0, 0, 8 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(9 / A_K1, m0, 0, 0, 9% A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(10 / A_K1, m0, 0, 0, 10 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(11 / A_K1, m0, 0, 0, 11 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(12 / A_K1, m0, 0, 0, 12 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(13 / A_K1, m0, 0, 0, 13 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(14 / A_K1, m0, 0, 0, 14 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(15 / A_K1, m0, 0, 0, 15% A_K1))>{}])))
);
}
// if (get_thread_local_1d_id() == 0){
// printf("repeat: m,n,k:(%02d, %02d, %02d) b_thread_buf: %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x\n",
// m0.value, n0.value, k.value,
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(0 / B_K1, n0, 0, 0, 0 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(1 / B_K1, n0, 0, 0, 1 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(2 / B_K1, n0, 0, 0, 2 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(3 / B_K1, n0, 0, 0, 3 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(4 / B_K1, n0, 0, 0, 4 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(5 / B_K1, n0, 0, 0, 5% B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(6 / B_K1, n0, 0, 0, 6 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(7 / B_K1, n0, 0, 0, 7 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(8 / B_K1, n0, 0, 0, 8 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(9 / B_K1, n0, 0, 0, 9% B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(10 / B_K1, n0, 0, 0, 10 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(11 / B_K1, n0, 0, 0, 11 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(12 / B_K1, n0, 0, 0, 12 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(13 / B_K1, n0, 0, 0, 13 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(14 / B_K1, n0, 0, 0, 14 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(15 / B_K1, n0, 0, 0, 15% B_K1))>{}])))
// );
// }
#endif
                         using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
                         using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
@@ -374,34 +442,24 @@ struct BlockwiseGemmWMMA
         {
             static_for<0, NRepeat, 1>{}([&](auto n0) {
                 static_for<0, MRepeat, 1>{}([&](auto m0) {
                     static_for<0, KPerBlock / WmmaK, 1>{}(
                         [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
                             // read B
                             b_thread_copy_.Run(
                                 b_block_desc_k0_n0_n1_n2_k1,
-                                make_tuple(Number<k * WmmaK / B_K1 * B_Data_Duplicated_Rate / 2>{}, n0, I0, I0, I0),
+                                make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                                 b_block_buf,
                                 b_thread_desc_,
-                                make_tuple(I0, n0, I0, I0, I0),
+                                make_tuple(I0, n0, I0, I0, I0, I0),
                                 b_thread_buf);
                             // read A
                             a_thread_copy_.Run(
                                 a_block_desc_k0_m0_m1_m2_k1,
-                                make_tuple(Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{}, m0, I0, I0, I0),
+                                make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                                 a_block_buf,
                                 a_thread_desc_,
-                                make_tuple(I0, m0, I0, I0, I0),
+                                make_tuple(I0, m0, I0, I0, I0, I0),
                                 a_thread_buf);

                             vector_type<FloatA, WmmaK> a_thread_vec;
@@ -410,10 +468,20 @@ struct BlockwiseGemmWMMA
                             static_for<0, WmmaK, 1>{}([&](auto i) {
-                                b_thread_vec.template AsType<FloatB>()(i) =
-                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                                        make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
-                                a_thread_vec.template AsType<FloatA>()(i) =
-                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                                        make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
+                                b_thread_vec.template AsType<FloatB>()(i) =
+                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
+                                        i / B_K1 / B_KRow, n0, 0, (i / B_K1) % B_KRow, 0, i % B_K1))>{}];
+                                a_thread_vec.template AsType<FloatA>()(i) =
+                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
+                                        i / A_K1 / A_KRow, m0, 0, (i / A_K1) % A_KRow, 0, i % A_K1))>{}];
                             });

                             using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
@@ -433,14 +501,33 @@ struct BlockwiseGemmWMMA
     }

     protected:
-    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
-        make_tuple(Number<WmmaK / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}),
-        make_tuple(Number<A_K1>{}, Number<WmmaK>{}, Number<A_K1>{}, Number<A_K1>{}, Number<1>{}));
-
-    // B[K0, N0, N1, N2, K1]
-    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
-        make_tuple(Number<WmmaK / B_K1>{}, Number<NRepeat>{}, I1, I1, Number<B_K1>{}),
-        make_tuple(Number<B_K1>{}, Number<WmmaK>{}, Number<B_K1>{}, Number<B_K1>{}, Number<1>{}));
+    static constexpr auto a_thread_desc_ =
+        make_naive_tensor_descriptor(make_tuple(Number<WmmaK / A_K1 / A_KRow>{},
+                                                Number<MRepeat>{},
+                                                I1,
+                                                Number<A_KRow>{},
+                                                I1,
+                                                Number<A_K1>{}),
+                                     make_tuple(Number<A_K1 * A_KRow>{},
+                                                Number<WmmaK>{},
+                                                Number<A_K1 * A_KRow>{},
+                                                Number<A_K1>{},
+                                                Number<A_K1>{},
+                                                Number<1>{}));
+
+    static constexpr auto b_thread_desc_ =
+        make_naive_tensor_descriptor(make_tuple(Number<WmmaK / B_K1 / B_KRow>{},
+                                                Number<NRepeat>{},
+                                                I1,
+                                                Number<B_KRow>{},
+                                                I1,
+                                                Number<B_K1>{}),
+                                     make_tuple(Number<B_K1 * B_KRow>{},
+                                                Number<WmmaK>{},
+                                                Number<B_K1 * B_KRow>{},
+                                                Number<B_K1>{},
+                                                Number<B_K1>{},
+                                                Number<1>{}));

     // C[M, N, NumRegWMMA]
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
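Because a_thread_desc_/b_thread_desc_ are now naive (explicitly strided) descriptors over six dimensions, the flat VGPR offset is the plain dot product of the index tuple with the strides. A standalone sketch of that mapping, reusing the length/stride pattern above with assumed sizes:

    // Sketch only: reproduces the offset rule of a naive tensor descriptor for
    // the 6-d a_thread_desc_ above.  Lengths/strides follow the new code, with
    // WmmaK = 16, A_K1 = 8, A_KRow = 2, MRepeat = 2 as assumed example values.
    #include <array>
    #include <cstdio>

    int main()
    {
        constexpr int WmmaK = 16, A_K1 = 8, A_KRow = 2, MRepeat = 2;

        // Order: (K0, MRepeat, MWave, KRow, MPerWmma, K1), as in the new descriptor.
        const std::array<int, 6> lengths = {WmmaK / A_K1 / A_KRow, MRepeat, 1, A_KRow, 1, A_K1};
        const std::array<int, 6> strides = {A_K1 * A_KRow, WmmaK, A_K1 * A_KRow, A_K1, A_K1, 1};

        auto offset = [&](const std::array<int, 6>& idx) {
            int off = 0;
            for(int d = 0; d < 6; ++d)
                off += idx[d] * strides[d]; // naive descriptor: offset = sum(idx * stride)
            return off;
        };

        // Example: repeat m0 = 1, row 1, last element of the K1 vector.
        const std::array<int, 6> idx = {0, 1, 0, 1, 0, A_K1 - 1};
        std::printf("K0 length = %d, example offset = %d\n", lengths[0], offset(idx));
        return 0;
    }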
@@ -452,13 +539,14 @@ struct BlockwiseGemmWMMA
     template <>
     struct AThreadCopySelector<true>
     {
         using type = ThreadwiseTensorSliceTransfer_v4<FloatA,
                                                        FloatA,
                                                        decltype(a_block_desc_k0_m0_m1_m2_k1),
                                                        decltype(a_thread_desc_),
-                                                       Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
-                                                       Sequence<0, 1, 2, 3, 4>,
-                                                       4,
+                                                       Sequence<WmmaK / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
+                                                       Sequence<0, 1, 2, 3, 4, 5>,
+                                                       5,
                                                        A_K1,
                                                        A_K1>;
     };
@@ -472,9 +560,9 @@ struct BlockwiseGemmWMMA
                                                        decltype(a_block_desc_k0_m0_m1_m2_k1),
                                                        decltype(a_thread_desc_),
                                                        tensor_operation::element_wise::PassThrough,
-                                                       Sequence<1, 1, 1, 1, A_K1>,
-                                                       Sequence<0, 1, 2, 3, 4>,
-                                                       4,
+                                                       Sequence<WmmaK / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
+                                                       Sequence<0, 1, 2, 3, 4, 5>,
+                                                       5,
                                                        A_K1,
                                                        0x76543210,
                                                        0xfedcba98,
@@ -487,13 +575,14 @@ struct BlockwiseGemmWMMA
     template <>
     struct BThreadCopySelector<true>
     {
         using type = ThreadwiseTensorSliceTransfer_v4<FloatB,
                                                        FloatB,
                                                        decltype(b_block_desc_k0_n0_n1_n2_k1),
                                                        decltype(b_thread_desc_),
-                                                       Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>,
-                                                       Sequence<0, 1, 2, 3, 4>,
-                                                       4,
+                                                       Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
+                                                       Sequence<0, 1, 2, 3, 4, 5>,
+                                                       5,
                                                        B_K1,
                                                        B_K1>;
     };
@@ -507,9 +596,9 @@ struct BlockwiseGemmWMMA
                                                        decltype(b_block_desc_k0_n0_n1_n2_k1),
                                                        decltype(b_thread_desc_),
                                                        tensor_operation::element_wise::PassThrough,
-                                                       Sequence<1, 1, 1, 1, B_K1>,
-                                                       Sequence<0, 1, 2, 3, 4>,
-                                                       4,
+                                                       Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
+                                                       Sequence<0, 1, 2, 3, 4, 5>,
+                                                       5,
                                                        B_K1,
                                                        0x76543210,
                                                        0xfedcba98,
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
@@ -80,6 +80,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     static constexpr auto I3 = Number<3>{};
     static constexpr auto I4 = Number<4>{};
     static constexpr auto I5 = Number<5>{};
+    static constexpr auto I6 = Number<6>{};
     // K1 = Max Vector Access Pixels
     static constexpr auto K1Number = Number<K1>{};
@@ -136,18 +137,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         }
         else
         {
-            constexpr auto A_KRow = WmmaK / K1;
+            constexpr auto A_KRow      = 2;
+            constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
             const auto A_KWmma         = K / WmmaK;

             const auto M0 = M / MPerBlock;
+            //  0   1     0         1                2        3             4        5          6
+            //  M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
             return transform_tensor_descriptor(
                 a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
+                make_tuple(make_unmerge_transform(
+                               make_tuple(A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
                            make_unmerge_transform(
                                make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
         }
     }
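In this branch the K axis is now always unmerged as A_KWmma x A_K0PerWmma x A_KRow x A_K1, with A_KRow fixed at 2; for example, WmmaK = 16 with K1 = 8 gives A_K0PerWmma = 1, so K = 128 splits into 8 x 1 x 2 x 8. A small sketch of that split without the transform machinery (K and K1 are assumed example values):

    // Sketch only: the K-axis split performed by the unmerge transform above,
    // written as plain integer arithmetic.  WmmaK = 16 and A_KRow = 2 come from
    // the source; K = 128 and K1 = 8 are assumed example values.
    #include <cassert>
    #include <cstdio>

    int main()
    {
        constexpr int WmmaK = 16, A_KRow = 2;
        constexpr int K = 128, K1 = 8;

        constexpr int A_K0PerWmma = WmmaK / A_KRow / K1; // = 1
        constexpr int A_KWmma     = K / WmmaK;           // = 8

        // The unmerged lengths must multiply back to K.
        static_assert(A_KWmma * A_K0PerWmma * A_KRow * K1 == K, "split must cover K");

        // Decompose one flat k index into the four new coordinates.
        const int k      = 77;
        const int k1     = k % K1;
        const int krow   = (k / K1) % A_KRow;
        const int k0wmma = (k / K1 / A_KRow) % A_K0PerWmma;
        const int kwmma  = k / K1 / A_KRow / A_K0PerWmma;
        assert(((kwmma * A_K0PerWmma + k0wmma) * A_KRow + krow) * K1 + k1 == k);

        std::printf("k=%d -> (KWmma=%d, K0PerWmma=%d, KRow=%d, K1=%d)\n",
                    k, kwmma, k0wmma, krow, k1);
        return 0;
    }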
@@ -187,18 +191,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         }
         else
         {
-            constexpr auto B_KRow = WmmaK / K1;
+            constexpr auto B_KRow      = 2;
+            constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
             const auto B_KWmma         = K / WmmaK;

             const auto N0 = N / NPerBlock;
+            //  0   1     0         1                2        3             4        5          6
+            //  M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
             return transform_tensor_descriptor(
                 b_grid_desc_n_k,
-                make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
+                make_tuple(make_unmerge_transform(
+                               make_tuple(B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
                            make_unmerge_transform(
                                make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
         }
     }
@@ -372,7 +379,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
             else
             {
                 return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
-                       arg.a_grid_desc_.GetLength(I5);
+                       arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
             }
         }();
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
@@ -297,15 +297,7 @@ struct GridwiseGemmPipeline_v1<1, false, true>
                                         CThreadBuffer& c_thread_buf,
                                         index_t num_loop)
     {
-#if 0
-        constexpr auto a_block_origin_idx = generate_sequence_v2(
-            []() constexpr {
-                return Number<0>{};
-            },
-            Number<a_block_desc.GetLengths().GetSize()>{});
-#endif
-        constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
+        constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
         auto a_block_buf_switch           = a_block_buf;

         // preload data into LDS
@@ -404,7 +396,7 @@ struct GridwiseGemmPipeline_v1<1, true, false>
                                         CThreadBuffer& c_thread_buf,
                                         index_t num_loop)
     {
-        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
+        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
         auto b_block_buf_switch           = b_block_buf;

         // preload data into LDS
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
@@ -172,10 +172,23 @@ struct GridwiseGemm_Wmma
         else
         {
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            // KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread
+            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, K1),
-                make_tuple(Number<MRepeat>{} * K1, K1, K1, K1, K1, I1));
+                make_tuple(Number<KWmmaPerblock>{},
+                           Number<MRepeat>{},
+                           I1,
+                           Number<K0PerWmma>{},
+                           I1,
+                           I1,
+                           K1),
+                make_tuple(Number<MRepeat>{} * Number<K0PerWmma>{} * K1,
+                           Number<K0PerWmma>{} * K1,
+                           Number<K0PerWmma>{} * K1,
+                           K1,
+                           K1,
+                           K1,
+                           I1));
         }
     }();
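The VGPR-side block copy now sees seven per-thread dimensions, KWmma x MRepeat x MWave(1) x K0PerWmma x KRow(1) x MPerWmma(1) x K1; with the strides above, each thread ends up holding half of every WMMA K slice it repeats over, the other half belonging to the opposite 16-lane row. A standalone sanity check of that element count under assumed sizes (a sketch, not code from the commit):

    // Sketch only: per-thread element count implied by the new 7-d VGPR block
    // descriptor above.  WmmaK = 16 and the "/ 2" (two K rows) come from the
    // source; KPerBlock = 64, MRepeat = 2, K1 = 8 are assumed example values.
    #include <cstdio>

    int main()
    {
        constexpr int WmmaK = 16, KPerBlock = 64, MRepeat = 2, K1 = 8;

        constexpr int KWmmaPerblock = KPerBlock / WmmaK; // 4
        constexpr int K0PerWmma     = WmmaK / 2 / K1;    // 1

        // Lengths: KWmma x MRepeat x MWave(1) x K0PerWmma x KRow(1) x MPerWmma(1) x K1
        constexpr int per_thread = KWmmaPerblock * MRepeat * 1 * K0PerWmma * 1 * 1 * K1;

        // Each thread holds half of every WMMA K slice it repeats over; the other
        // half is owned by the opposite 16-lane row.
        static_assert(per_thread == KPerBlock * MRepeat / 2, "half a K slice per row");

        std::printf("per-thread A elements: %d\n", per_thread);
        return 0;
    }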
@@ -206,10 +219,23 @@ struct GridwiseGemm_Wmma
         else
         {
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            // KWmma->NRepeat->NWave->NRow->NPerWmma->K1 Per Thread
+            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, K1),
-                make_tuple(Number<NRepeat>{} * K1, K1, K1, K1, K1, I1));
+                make_tuple(Number<KWmmaPerblock>{},
+                           Number<NRepeat>{},
+                           I1,
+                           Number<K0PerWmma>{},
+                           I1,
+                           I1,
+                           K1),
+                make_tuple(Number<NRepeat>{} * Number<K0PerWmma>{} * K1,
+                           Number<K0PerWmma>{} * K1,
+                           Number<K0PerWmma>{} * K1,
+                           K1,
+                           K1,
+                           K1,
+                           I1));
         }
     }();
@@ -229,7 +255,7 @@ struct GridwiseGemm_Wmma
         {
             constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;

-            return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
+            return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
         }
     }();
@@ -249,7 +275,7 @@ struct GridwiseGemm_Wmma
         {
             constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;

-            return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
+            return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
         }
     }();
@@ -264,23 +290,26 @@ struct GridwiseGemm_Wmma
         constexpr auto a_wave_desc = [&]() {
             if constexpr(AEnableLds)
             {
-                // AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
+                // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
                 constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
                 constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+                constexpr auto A_KRow = I1;
                 return transform_tensor_descriptor(
                     ABlockDesc_{},
-                    make_tuple(make_pass_through_transform(Number<A_K0>{}),
+                    make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
                                make_unmerge_transform(make_tuple(
                                    Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                                make_pass_through_transform(Number<A_K1>{})),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-                    make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+                    make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
             }
             else
             {
-                // KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
+                // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
                 constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
-                constexpr auto A_K1  = ABlockDesc_{}.GetLength(I5);
+                constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
+                constexpr auto A_KRow    = ABlockDesc_{}.GetLength(I4);
+                constexpr auto A_K1      = ABlockDesc_{}.GetLength(I6);

                 // Err: merge transform cause non-constexpr issue
@@ -301,26 +330,12 @@ struct GridwiseGemm_Wmma
                 //            Sequence<4>{}));
                 // Workaround, Freeze transform
-                return transform_tensor_descriptor(
-                    ABlockDesc_{},
-                    make_tuple(make_freeze_transform(I0),
-                               make_pass_through_transform(Number<KWmma>{}),
-                               make_pass_through_transform(Number<MRepeat>{}),
-                               make_pass_through_transform(I1),
-                               make_pass_through_transform(I1),
-                               make_pass_through_transform(Number<A_K1>{})),
-                    make_tuple(Sequence<3>{},
-                               Sequence<0>{},
-                               Sequence<1>{},
-                               Sequence<2>{},
-                               Sequence<4>{},
-                               Sequence<5>{}),
-                    make_tuple(Sequence<>{},
-                               Sequence<0>{},
-                               Sequence<1>{},
-                               Sequence<2>{},
-                               Sequence<3>{},
-                               Sequence<4>{}));
+                return make_naive_tensor_descriptor_packed(
+                    make_tuple(Number<KWmma * K0PerWmma>{},
+                               Number<MRepeat>{},
+                               I1,
+                               Number<A_KRow>{},
+                               I1,
+                               Number<A_K1>{}));
             }
         }();
@@ -336,42 +351,31 @@ struct GridwiseGemm_Wmma
                 // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
                 constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
                 constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+                constexpr auto B_KRow = I1;
                 return transform_tensor_descriptor(
                     BBlockDesc_{},
-                    make_tuple(make_pass_through_transform(Number<B_K0>{}),
+                    make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
                                make_unmerge_transform(make_tuple(
                                    Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                                make_pass_through_transform(Number<B_K1>{})),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-                    make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+                    make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
             }
             else
             {
-                // KWmma_NRepeat_NWave_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
+                // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
                 constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
-                constexpr auto B_K1  = BBlockDesc_{}.GetLength(I5);
+                constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
+                constexpr auto B_KRow    = BBlockDesc_{}.GetLength(I4);
+                constexpr auto B_K1      = BBlockDesc_{}.GetLength(I6);

                 // Workaround, Freeze transform
-                return transform_tensor_descriptor(
-                    BBlockDesc_{},
-                    make_tuple(make_freeze_transform(I0),
-                               make_pass_through_transform(Number<KWmma>{}),
-                               make_pass_through_transform(Number<NRepeat>{}),
-                               make_pass_through_transform(I1),
-                               make_pass_through_transform(I1),
-                               make_pass_through_transform(Number<B_K1>{})),
-                    make_tuple(Sequence<3>{},
-                               Sequence<0>{},
-                               Sequence<1>{},
-                               Sequence<2>{},
-                               Sequence<4>{},
-                               Sequence<5>{}),
-                    make_tuple(Sequence<>{},
-                               Sequence<0>{},
-                               Sequence<1>{},
-                               Sequence<2>{},
-                               Sequence<3>{},
-                               Sequence<4>{}));
+                return make_naive_tensor_descriptor_packed(
+                    make_tuple(Number<KWmma * K0PerWmma>{},
+                               Number<NRepeat>{},
+                               I1,
+                               Number<B_KRow>{},
+                               I1,
+                               Number<B_K1>{}));
             }
         }();
@@ -415,9 +419,9 @@ struct GridwiseGemm_Wmma
         else
         {
             return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
-                                  a_grid_desc.GetLength(I4),
+                                  a_grid_desc.GetLength(I5),
                               a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
-                                  a_grid_desc.GetLength(I5));
+                                  a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
         }
     };
@@ -430,9 +434,9 @@ struct GridwiseGemm_Wmma
         else
         {
             return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
-                                  b_grid_desc.GetLength(I4),
+                                  b_grid_desc.GetLength(I5),
                               b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
-                                  b_grid_desc.GetLength(I5));
+                                  b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
         }
     };
@@ -599,7 +603,8 @@ struct GridwiseGemm_Wmma
                 return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
             }
             else
             {
-                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I5);
+                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
+                       a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
             }
         }();
@@ -652,6 +657,7 @@ struct GridwiseGemm_Wmma
             // Thread-wise copy
             // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
             constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+            constexpr auto K0PerWmma     = WmmaK / 2 / K1Value;
             auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
                 a_block_desc.GetElementSpaceSize());
@@ -664,11 +670,12 @@ struct GridwiseGemm_Wmma
                                                    Sequence<Number<KWmmaPerBlock>{},
                                                             Number<MRepeat>{},
                                                             I1,
+                                                            Number<K0PerWmma>{},
                                                             I1,
                                                             I1,
                                                             Number<K1Value>{}>,
-                                                   Sequence<0, 1, 2, 3, 4, 5>,
-                                                   5,
+                                                   Sequence<0, 1, 2, 3, 4, 5, 6>,
+                                                   6,
                                                    ABlockTransferSrcScalarPerVector,
                                                    AThreadTransferSrcResetCoordinateAfterRun,
                                                    true>(
@@ -676,6 +683,7 @@ struct GridwiseGemm_Wmma
                 make_multi_index(0,
                                  m_block_data_idx_on_grid / (MWaves * MPerWmma),
                                  get_thread_local_1d_id() / 32,
+                                 0,
                                  (get_thread_local_1d_id() % 32) / 16,
                                  get_thread_local_1d_id() % 16,
                                  0));
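The added 0 in the copy origin fills the new K0PerWmma position; the remaining terms still decompose the flat thread ID into wave, 16-lane row (the KRow coordinate) and lane within the row (the MPerWmma coordinate). A small sketch of that decomposition, assuming the hard-coded wave size of 32 and an example block size of 128 threads:

    // Sketch only: the thread-ID decomposition used by the copy origin above.
    // Wave size 32 and the 16-lane row split are hard-coded in the source; the
    // block size of 128 threads is an assumed example value.
    #include <cstdio>

    int main()
    {
        constexpr int BlockSize = 128;

        for(int tid = 0; tid < BlockSize; tid += 48) // sample a few thread IDs
        {
            const int wave        = tid / 32;        // which wave in the block
            const int row_in_wave = (tid % 32) / 16; // which 16-lane row (KRow axis)
            const int lane_in_row = tid % 16;        // lane within the row (MPerWmma axis)
            std::printf("tid=%3d -> wave=%d row=%d lane=%d\n",
                        tid, wave, row_in_wave, lane_in_row);
        }
        return 0;
    }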
@@ -729,6 +737,7 @@ struct GridwiseGemm_Wmma
             // Thread-wise copy
             // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
             constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+            constexpr auto K0PerWmma     = WmmaK / 2 / K1Value;
             auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
                 b_block_desc.GetElementSpaceSize());
@@ -741,11 +750,12 @@ struct GridwiseGemm_Wmma
                                                    Sequence<Number<KWmmaPerBlock>{},
                                                             Number<NRepeat>{},
                                                             I1,
+                                                            Number<K0PerWmma>{},
                                                             I1,
                                                             I1,
                                                             Number<K1Value>{}>,
-                                                   Sequence<0, 1, 2, 3, 4, 5>,
-                                                   5,
+                                                   Sequence<0, 1, 2, 3, 4, 5, 6>,
+                                                   6,
                                                    BBlockTransferSrcScalarPerVector,
                                                    BThreadTransferSrcResetCoordinateAfterRun,
                                                    true>(
@@ -753,6 +763,7 @@ struct GridwiseGemm_Wmma
                 make_multi_index(0,
                                  n_block_data_idx_on_grid / (NWaves * NPerWmma),
                                  get_thread_local_1d_id() / 32,
+                                 0,
                                  (get_thread_local_1d_id() % 32) / 16,
                                  get_thread_local_1d_id() % 16,
                                  0));
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
@@ -1387,7 +1387,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
         // copy data from src_buf into dst_vector
         static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
-            // src_desc error, non constexpr?
+            // src_desc error, non constexpr, caused by merge transform
             constexpr index_t src_offset = src_desc.CalculateOffset(
                 src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
@@ -1396,8 +1396,6 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
             SrcData v_this_row, v_theother_row;
             // int type temp value due to intrinsic requirement
-            // TODO: This temp value will generate the scratch memory if
-            // IntraRowSwizzlePerm is flase
             int temp = 0;

             // apply element-wise operation
@@ -1419,7 +1417,10 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
                                                  1, 1, 0);
                 v_theother_row = type_convert_sp<SrcData>(temp);

+                // if (get_thread_local_1d_id() == 0){
+                //     printf("src_offset:%d, dst_offset for this row: %d, dst_offset
+                //     for the other row: %d \n",
+                //     src_offset, dst_offset, dst_offset+DstScalarPerVector);}
                 if(get_thread_local_1d_id() % 32 < 16)
                 {
                     // apply type convert