gaoqiong / composable_kernel_ROCM · Commits

Commit ea41fc2f, authored Oct 22, 2024 by aska-0096
Parent: b70bcd86

    Port new layout to v1, v2 pipeline

Showing 4 changed files with 452 additions and 468 deletions (+452 −468):
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp  (+12 −4)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp        (+153 −97)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp        (+201 −161)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp             (+86 −206)
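Summary of the change set: the diffs below thread two new compile-time layout flags, TransposeA and TransposeB, from the gridwise GEMM through the pipeline selector into the blockwise pipelines. In the v1 pipeline the per-repeat ThreadwiseTensorSliceTransfer_v4 thread copies become ThreadwiseTensorSliceTransfer_v5 copies picked by Transpose-specialized selector traits, and the inner MRepeat/NRepeat copy loops collapse into single whole-slice copies. In the gridwise kernel, KPack switches from gcd to lcm of the A/B K1 values, the C-shuffle staging buffer is dropped from the LDS budget, and the epilogue is rewritten as a direct, permuted VGPR-to-global store.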
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp

@@ -66,7 +66,9 @@ constexpr auto BlockGemmPipeline_Selector()
                                                  NPerXDL,
                                                  MRepeat,
                                                  NRepeat,
-                                                 KPack>{};
+                                                 KPack,
+                                                 TransposeA,
+                                                 TransposeB>{};
     }
     else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
     {
@@ -89,7 +91,9 @@ constexpr auto BlockGemmPipeline_Selector()
                                                  NPerXDL,
                                                  MRepeat,
                                                  NRepeat,
-                                                 KPack>{};
+                                                 KPack,
+                                                 TransposeA,
+                                                 TransposeB>{};
     }
     else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
     {
@@ -137,7 +141,9 @@ constexpr auto BlockGemmPipeline_Selector()
                                                  NPerXDL,
                                                  MRepeat,
                                                  NRepeat,
-                                                 KPack>{};
+                                                 KPack,
+                                                 TransposeA,
+                                                 TransposeB>{};
     }
     else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)
     {
@@ -160,7 +166,9 @@ constexpr auto BlockGemmPipeline_Selector()
                                                  NPerXDL,
                                                  MRepeat,
                                                  NRepeat,
-                                                 KPack>{};
+                                                 KPack,
+                                                 TransposeA,
+                                                 TransposeB>{};
     }
     else
     {
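For orientation, BlockGemmPipeline_Selector dispatches on a BlockGemmPipelineVersion non-type template parameter with an if constexpr chain, and each branch now forwards the two extra bools into the chosen pipeline specialization. A minimal sketch of that dispatch pattern, with illustrative names rather than the CK signatures:

    #include <cstdio>

    enum class PipelineVersion { v1, v2 };

    // Primary template, specialized per pipeline version.
    template <PipelineVersion Ver, int KPack, bool TransposeA, bool TransposeB>
    struct Pipeline;

    template <int KPack, bool TransposeA, bool TransposeB>
    struct Pipeline<PipelineVersion::v1, KPack, TransposeA, TransposeB>
    {
        static constexpr const char* name = "v1";
    };

    template <int KPack, bool TransposeA, bool TransposeB>
    struct Pipeline<PipelineVersion::v2, KPack, TransposeA, TransposeB>
    {
        static constexpr const char* name = "v2";
    };

    // Compile-time dispatch: exactly one branch survives instantiation,
    // and the layout flags are forwarded into whichever pipeline is chosen.
    template <PipelineVersion Ver, int KPack, bool TransposeA, bool TransposeB>
    constexpr auto pipeline_selector()
    {
        if constexpr(Ver == PipelineVersion::v1)
        {
            return Pipeline<PipelineVersion::v1, KPack, TransposeA, TransposeB>{};
        }
        else if constexpr(Ver == PipelineVersion::v2)
        {
            return Pipeline<PipelineVersion::v2, KPack, TransposeA, TransposeB>{};
        }
    }

    int main()
    {
        auto pipe = pipeline_selector<PipelineVersion::v2, 8, true, false>();
        std::printf("selected pipeline: %s\n", pipe.name);
    }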
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp

@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPacks>
+          index_t KPacks,
+          bool TransposeA,
+          bool TransposeB>
 struct BlockwiseGemmXdlops_pipeline_v1
 {
 };
@@ -55,7 +57,9 @@ template <index_t BlockSize,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPack
+          index_t KPack,
+          bool TransposeA,
+          bool TransposeB
           // ,bool TransposeC //disable transposec right now...
           >
 struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
-                                       KPack>
+                                       KPack,
+                                       TransposeA,
+                                       TransposeB>
     : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                         ADataType,
                                         BDataType,
@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
-                                       KPack>
+                                       KPack,
+                                       TransposeA,
+                                       TransposeB>
 {
     using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
                                                    NPerXDL,
                                                    MRepeat,
                                                    NRepeat,
-                                                   KPack>;
+                                                   KPack,
+                                                   TransposeA,
+                                                   TransposeB>;
     using Base::I0;
     using Base::KRepeat;
     using Base::xdlops_gemm;
@@ -218,24 +228,21 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

             block_sync_lds();
-            static_for<0, KRepeat, 1>{}([&](auto k) {
-                static_for<0, MRepeat, 1>{}([&](auto m0) {
-                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
-                                       a_block_buf,
-                                       a_thread_desc_,
-                                       make_tuple(m0, I0, k, I0),
-                                       a_thread_buf);
-                    static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                           make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                           b_block_buf,
-                                           b_thread_desc_,
-                                           make_tuple(n0, I0, k, I0),
-                                           b_thread_buf);
-                    });
-                });
-            });
+            static_for<0, KRepeat, 1>{}([&](auto k0) {
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                   make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
+                                   a_block_buf,
+                                   a_thread_desc_,
+                                   make_tuple(I0, I0, k0, I0),
+                                   a_thread_buf);
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                   make_tuple(I0, I0, I0, Number<k0 * BMmaKStride>{}),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(I0, I0, k0, I0),
+                                   b_thread_buf);
+            });

             static_for<0, KRepeat, 1>{}([&](auto k0) {
                 static_for<0, MRepeat, 1>{}([&](auto m0) {
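The restructuring above (repeated for the TailNumber::Full path in the next hunk) deletes the per-m0 and per-n0 static_for loops around the A/B thread copies: the new thread-copy types declare slice lengths of Sequence<MRepeat, 1, 1, ...> and Sequence<NRepeat, 1, 1, ...> (see the selector hunk further down), so one Run per k0 moves what previously took MRepeat (and nested NRepeat) separate runs, and the copy origins degenerate to I0. A toy sketch of the same collapse, with static_for modelled on std::integer_sequence (CK ships its own static_for utility):

    #include <array>
    #include <cstdio>
    #include <utility>

    // Toy stand-in for ck::static_for: calls f(integral_constant<int, I>) for I = 0..N-1.
    template <int N, typename F, int... Is>
    void static_for_impl(F&& f, std::integer_sequence<int, Is...>)
    {
        (f(std::integral_constant<int, Is>{}), ...);
    }

    template <int N, typename F>
    void static_for(F&& f)
    {
        static_for_impl<N>(std::forward<F>(f), std::make_integer_sequence<int, N>{});
    }

    constexpr int MRepeat = 4, KLen = 8;
    using Tile = std::array<float, MRepeat * KLen>;

    // Old shape: one thread-copy Run per m0, each moving a 1 x KLen slice.
    void copy_per_repeat(const Tile& src, Tile& dst)
    {
        static_for<MRepeat>([&](auto m0) {
            for(int k = 0; k < KLen; ++k)
                dst[m0 * KLen + k] = src[m0 * KLen + k];
        });
    }

    // New shape: a single Run whose slice lengths cover all of MRepeat at once.
    void copy_bulk(const Tile& src, Tile& dst)
    {
        for(int i = 0; i < MRepeat * KLen; ++i)
            dst[i] = src[i];
    }

    int main()
    {
        Tile src{}, d1{}, d2{};
        for(int i = 0; i < MRepeat * KLen; ++i)
            src[i] = float(i);
        copy_per_repeat(src, d1);
        copy_bulk(src, d2);
        std::printf("match: %d\n", int(d1 == d2));
    }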
@@ -279,24 +286,21 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
         if constexpr(TailNum == TailNumber::Full)
         {
             block_sync_lds();
-            static_for<0, KRepeat, 1>{}([&](auto k) {
-                static_for<0, MRepeat, 1>{}([&](auto m0) {
-                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
-                                       a_block_buf,
-                                       a_thread_desc_,
-                                       make_tuple(m0, I0, k, I0),
-                                       a_thread_buf);
-                    static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                           make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                           b_block_buf,
-                                           b_thread_desc_,
-                                           make_tuple(n0, I0, k, I0),
-                                           b_thread_buf);
-                    });
-                });
-            });
+            static_for<0, KRepeat, 1>{}([&](auto k0) {
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                   make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
+                                   a_block_buf,
+                                   a_thread_desc_,
+                                   make_tuple(I0, I0, k0, I0),
+                                   a_thread_buf);
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                   make_tuple(I0, I0, I0, Number<k0 * BMmaKStride>{}),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(I0, I0, k0, I0),
+                                   b_thread_buf);
+            });

             static_for<0, KRepeat, 1>{}([&](auto k0) {
                 static_for<0, MRepeat, 1>{}([&](auto m0) {
@@ -354,7 +358,9 @@ template <index_t BlockSize,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPack
+          index_t KPack,
+          bool TransposeA,
+          bool TransposeB
           // ,bool TransposeC //disable transposec right now...
           >
 struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
@@ -376,7 +382,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
-                                       KPack>
+                                       KPack,
+                                       TransposeA,
+                                       TransposeB>
     : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                         ADataType,
                                         BDataType,
@@ -395,7 +403,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
-                                       KPack>
+                                       KPack,
+                                       TransposeA,
+                                       TransposeB>
 {
     using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -416,7 +426,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
                                                    NPerXDL,
                                                    MRepeat,
                                                    NRepeat,
-                                                   KPack>;
+                                                   KPack,
+                                                   TransposeA,
+                                                   TransposeB>;
     using Base::A_K1;
     using Base::B_K1;
     using Base::I0;
@@ -520,22 +532,18 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
             block_sync_lds();

             static_for<0, KRepeat, 1>{}([&](auto k0) {
-                static_for<0, MRepeat, 1>{}([&](auto m0) {
-                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                       make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
-                                       a_block_buf,
-                                       a_thread_desc_,
-                                       make_tuple(m0, I0, k0, I0),
-                                       a_thread_buf);
-                    static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                           make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
-                                           b_block_buf,
-                                           b_thread_desc_,
-                                           make_tuple(n0, I0, k0, I0),
-                                           b_thread_buf);
-                    });
-                });
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                   make_tuple(I0, I0, I0, Number<k0 * KPerInnerLoop>{}),
+                                   a_block_buf,
+                                   a_thread_desc_,
+                                   make_tuple(I0, I0, k0, I0),
+                                   a_thread_buf);
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                   make_tuple(I0, I0, I0, Number<k0 * KPerInnerLoop>{}),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(I0, I0, k0, I0),
+                                   b_thread_buf);
                 __builtin_amdgcn_sched_barrier(0);
                 // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
                 // but except the first, as we can shorten non-MAC cluster a bit and there's no
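__builtin_amdgcn_sched_barrier is the LLVM AMDGPU scheduling-barrier intrinsic: the mask names which instruction classes the backend scheduler may move across the point, and mask 0 pins everything, which is how this pipeline keeps the data-movement cluster and the MFMA cluster from being interleaved by the compiler. A HIP-style sketch of the idea (not the pipeline's actual kernel):

    // Sketch of how a scheduling fence partitions a kernel body (HIP/Clang on AMDGPU).
    __global__ void fenced(float* out, const float* a, const float* b)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;

        // Cluster 1: loads. The scheduler may reorder freely within the cluster.
        float x = a[i];
        float y = b[i];

        // Mask 0: no instruction may be moved across this point by the scheduler.
        __builtin_amdgcn_sched_barrier(0);

        // Cluster 2: math. Kept after the loads, as written.
        out[i] = x * y + x;
    }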
@@ -614,22 +622,18 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
         {
             block_sync_lds();

             static_for<0, KRepeat, 1>{}([&](auto k0) {
-                static_for<0, MRepeat, 1>{}([&](auto m0) {
-                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                       make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
-                                       a_block_buf,
-                                       a_thread_desc_,
-                                       make_tuple(m0, I0, k0, I0),
-                                       a_thread_buf);
-                    static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                           make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
-                                           b_block_buf,
-                                           b_thread_desc_,
-                                           make_tuple(n0, I0, k0, I0),
-                                           b_thread_buf);
-                    });
-                });
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                   make_tuple(I0, I0, I0, Number<k0 * KPerInnerLoop>{}),
+                                   a_block_buf,
+                                   a_thread_desc_,
+                                   make_tuple(I0, I0, k0, I0),
+                                   a_thread_buf);
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                   make_tuple(I0, I0, I0, Number<k0 * KPerInnerLoop>{}),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(I0, I0, k0, I0),
+                                   b_thread_buf);
                 __builtin_amdgcn_sched_barrier(0);

                 if constexpr(k0.value != 0 || KRepeat == 1)
@@ -703,28 +707,80 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
                        Number<NRepeat * KPerInnerLoop>{},
                        I1));

-    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
-                                                         ComputeDataType,
-                                                         decltype(a_block_desc_m0_m1_m2_k),
-                                                         decltype(a_thread_desc_),
-                                                         Sequence<1, 1, 1, KPerInnerLoop>,
-                                                         Sequence<0, 1, 2, 3>,
-                                                         3,
-                                                         A_K1,
-                                                         A_K1>;
+    template <bool Transpose>
+    struct AThreadCopySelector;
+
+    template <>
+    struct AThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
+                                                      ComputeDataType,
+                                                      decltype(a_block_desc_m0_m1_m2_k),
+                                                      decltype(a_thread_desc_),
+                                                      Sequence<MRepeat, 1, 1, KPerInnerLoop>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      3,
+                                                      3,
+                                                      A_K1,
+                                                      A_K1>;
+    };
+
+    template <>
+    struct AThreadCopySelector<true>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
+                                                      ComputeDataType,
+                                                      decltype(a_block_desc_m0_m1_m2_k),
+                                                      decltype(a_thread_desc_),
+                                                      Sequence<MRepeat, 1, 1, KPerInnerLoop>,
+                                                      Sequence<3, 1, 2, 0>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      0,
+                                                      3,
+                                                      MRepeat,
+                                                      A_K1>;
+    };

-    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
-                                                         ComputeDataType,
-                                                         decltype(b_block_desc_n0_n1_n2_k),
-                                                         decltype(b_thread_desc_),
-                                                         Sequence<1, 1, 1, KPerInnerLoop>,
-                                                         Sequence<0, 1, 2, 3>,
-                                                         3,
-                                                         B_K1,
-                                                         B_K1>;
+    template <bool Transpose>
+    struct BThreadCopySelector;
+
+    template <>
+    struct BThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
+                                                      ComputeDataType,
+                                                      decltype(b_block_desc_n0_n1_n2_k),
+                                                      decltype(b_thread_desc_),
+                                                      Sequence<NRepeat, 1, 1, KPerInnerLoop>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      3,
+                                                      3,
+                                                      B_K1,
+                                                      B_K1>;
+    };
+
+    template <>
+    struct BThreadCopySelector<true>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
+                                                      ComputeDataType,
+                                                      decltype(b_block_desc_n0_n1_n2_k),
+                                                      decltype(b_thread_desc_),
+                                                      Sequence<NRepeat, 1, 1, KPerInnerLoop>,
+                                                      Sequence<3, 1, 2, 0>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      0,
+                                                      3,
+                                                      NRepeat,
+                                                      B_K1>;
+    };

-    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
-    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
+    typename AThreadCopySelector<TransposeA>::type a_thread_copy_{
+        Base::CalculateAThreadOriginDataIndex()};
+    typename BThreadCopySelector<TransposeB>::type b_thread_copy_{
+        Base::CalculateBThreadOriginDataIndex()};

     using Base::c_thread_desc_;
 };
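The replacement above is the standard trait/selector idiom: a primary template that is declared but never defined, plus full specializations for false and true that each expose a type alias; the member is then declared as typename Selector<Flag>::type, so an unsupported flag value fails at compile time. A self-contained sketch of the idiom (the copy types here are placeholders, not ThreadwiseTensorSliceTransfer_v5):

    #include <cstdio>

    struct RowMajorCopy   { static constexpr const char* name = "row-major vector read"; };
    struct TransposedCopy { static constexpr const char* name = "transposed vector read"; };

    // Primary template: declared only, so an unhandled flag fails to compile.
    template <bool Transpose>
    struct ThreadCopySelector;

    template <>
    struct ThreadCopySelector<false> { using type = RowMajorCopy; };

    template <>
    struct ThreadCopySelector<true>  { using type = TransposedCopy; };

    template <bool TransposeA>
    struct PipelineFragment
    {
        // The same declaration shape used in the diff:
        typename ThreadCopySelector<TransposeA>::type a_thread_copy_{};
    };

    int main()
    {
        PipelineFragment<true> frag;
        std::printf("%s\n", decltype(frag.a_thread_copy_)::name);
    }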
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp

(Diff collapsed on the original page and not captured here; per the commit message, the same layout port is applied to the v2 pipeline.)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

@@ -146,7 +146,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
     static constexpr auto BK1Number = Number<BK1Value>{};

     static constexpr index_t KPack =
-        math::max(math::gcd(AK1Number, BK1Number),
+        math::max(math::lcm(AK1Number, BK1Number),
                   MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
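The functional change in this hunk is math::gcd -> math::lcm, which can only keep KPack the same or raise it: with AK1 = 8 and BK1 = 4, gcd gives 4 where lcm gives 8, so K is packed in coarser units (still floored by the selected MFMA's k_per_blk through the outer max). A quick constexpr check of that arithmetic, using std::gcd/std::lcm in place of CK's math helpers:

    #include <cstdio>
    #include <numeric> // std::gcd, std::lcm

    constexpr int k_pack(int ak1, int bk1, int k_per_blk, bool use_lcm)
    {
        const int combined = use_lcm ? std::lcm(ak1, bk1) : std::gcd(ak1, bk1);
        return combined > k_per_blk ? combined : k_per_blk; // math::max equivalent
    }

    static_assert(k_pack(8, 4, 4, /*use_lcm=*/false) == 4); // old: gcd(8, 4) = 4
    static_assert(k_pack(8, 4, 4, /*use_lcm=*/true) == 8);  // new: lcm(8, 4) = 8

    int main() { std::puts("KPack arithmetic checks passed"); }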
@@ -1016,16 +1016,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
         constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
             b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

-        // LDS allocation for C shuffle in LDS
-        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
-            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
-
-        constexpr auto c_block_size =
-            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
-
-        return math::max((a_block_space_size_aligned * sizeof(ADataType) +
-                          b_block_space_size_aligned * sizeof(BDataType)),
-                         c_block_size * sizeof(CShuffleDataType));
+        return a_block_space_size_aligned * sizeof(ADataType) +
+               b_block_space_size_aligned * sizeof(BDataType);
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
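Dropping the C-shuffle term changes the kernel's LDS requirement from max(A tile + B tile, C-shuffle block) to just A tile + B tile, consistent with the epilogue rewrite below, where C never passes through LDS. A sketch of the two sizing rules, with hypothetical tile sizes and element types:

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    // Hypothetical element counts for the LDS-resident tiles.
    constexpr std::size_t a_tile_elems     = 128 * 32;  // MPerBlock x KPerBlock
    constexpr std::size_t b_tile_elems     = 128 * 32;  // NPerBlock x KPerBlock
    constexpr std::size_t c_shuffle_elems  = 128 * 128; // epilogue staging block

    using ABType       = unsigned short; // e.g. a 16-bit input type
    using CShuffleType = float;

    // Old rule: LDS must also be able to hold the C-shuffle staging buffer.
    constexpr std::size_t lds_old =
        std::max(a_tile_elems * sizeof(ABType) + b_tile_elems * sizeof(ABType),
                 c_shuffle_elems * sizeof(CShuffleType));

    // New rule: C goes straight from registers to global memory.
    constexpr std::size_t lds_new =
        a_tile_elems * sizeof(ABType) + b_tile_elems * sizeof(ABType);

    int main()
    {
        std::printf("old LDS bytes: %zu, new LDS bytes: %zu\n", lds_old, lds_new);
    }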
@@ -1713,63 +1705,34 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                           c_thread_buf,
                                           num_k_block_main_loop);

-        // shuffle C and write out
+        // Epilogue
         {
-            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
-                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
-                          "wrong!");
-
-            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
-            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
-
-            // TODO: hacky, fix it!
-            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
-                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-
-            // TODO: hacky, fix it!
-            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
-            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
-                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-
-            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
-            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
-            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
-            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
-            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
-            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
-            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
-            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
-
-            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
-                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
-
-            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-                static_cast<CShuffleDataType*>(p_shared_0),
-                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-
-            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
-                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
-                make_tuple(
-                    make_freeze_transform(I0),
-                    make_unmerge_transform(make_tuple(
-                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
-                        M1,                                      // M1 = MWave
-                        M2,                                      // M2 * M3 * M4 = MPerXdl
-                        M3,
-                        M4)),
-                    make_freeze_transform(I0),
-                    make_unmerge_transform(make_tuple(
-                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
-                        N1,                                      // N1 = NWave
-                        N2))),                                   // N2 = NPerXdl
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-                make_tuple(
-                    Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
+            constexpr auto c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
+                blockwise_gemm_pipeline.GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
+
+            constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 =
+                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4();
+
+            constexpr auto M0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<0>{});
+            constexpr auto M1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<1>{});
+            constexpr auto N0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
+            constexpr auto M2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
+            constexpr auto M3 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
+            constexpr auto N1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
+            constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
+            constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});
+
+            const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
+                transform_tensor_descriptor(
+                    c_grid_desc_mblock_mperblock_nblock_nperblock,
+                    make_tuple(make_pass_through_transform(problem.MBlock),
+                               make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
+                               make_pass_through_transform(problem.NBlock),
+                               make_unmerge_transform(make_tuple(N0, N1, N2))),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                    make_tuple(Sequence<0>{},
+                               Sequence<2, 3, 5, 6, 9>{},
+                               Sequence<1>{},
+                               Sequence<4, 7, 8>{}));

             // calculate origin of thread output tensor on global memory
             // blockwise GEMM c matrix starting index
             const auto c_thread_mtx_on_block =
-                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
+                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexContiguous(I0, I0, I0, I0);

             const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
             const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
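The new epilogue re-views the (MBlock, MPerBlock, NBlock, NPerBlock) grid descriptor through unmerge transforms, splitting MPerBlock into (M0, M1, M2, M3, M4) and NPerBlock into (N0, N1, N2) so each thread can address its MFMA output fragment directly in global memory. An unmerge transform is mixed-radix index splitting; a small sketch of the index math it encodes (assuming the last dimension is fastest-varying):

    #include <array>
    #include <cstddef>
    #include <cstdio>

    // Split a flat index into coordinates for lengths[0..N-1], last dim fastest,
    // i.e. flat = ((c0 * L1 + c1) * L2 + c2) * ... , which is what an unmerge encodes.
    template <std::size_t N>
    constexpr std::array<int, N> unmerge(int flat, const std::array<int, N>& lengths)
    {
        std::array<int, N> coord{};
        for(std::size_t d = N; d-- > 0;)
        {
            coord[d] = flat % lengths[d];
            flat /= lengths[d];
        }
        return coord;
    }

    int main()
    {
        // e.g. MPerBlock = M0 * M1 * M2 * M3 * M4, with illustrative lengths
        constexpr std::array<int, 5> m_lengths{2, 4, 2, 4, 4}; // product = 256
        const auto c = unmerge(135, m_lengths);
        std::printf("(m0,m1,m2,m3,m4) = (%d,%d,%d,%d,%d)\n", c[0], c[1], c[2], c[3], c[4]);
    }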
@@ -1784,8 +1747,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
                 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                     make_multi_index(m_thread_data_on_block));

             const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                 make_single_stage_tensor_adaptor(
                     make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                     make_tuple(Sequence<0, 1, 2>{}),
                     make_tuple(Sequence<0>{}));
@@ -1794,121 +1756,39 @@ struct GridwiseGemm_xdl_cshuffle_v3
                 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                     make_multi_index(n_thread_data_on_block));

-            // shuffle: threadwise copy C from VGPR to LDS
-            auto c_thread_copy_vgpr_to_lds =
-                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
-                                                   CShuffleDataType,
-                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-                                                   ck::tensor_operation::element_wise::PassThrough,
-                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
-                                                            CShuffleNXdlPerWavePerShuffle,
-                                                            I1,
-                                                            I1,
-                                                            M2,
-                                                            I1,
-                                                            M4,
-                                                            I1>,
-                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
-                                                   7,
-                                                   1,
-                                                   InMemoryDataOperationEnum::Set,
-                                                   1,
-                                                   true>{
-                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                    make_multi_index(0,
-                                     0,
-                                     m_thread_data_on_block_idx[I1],
-                                     n_thread_data_on_block_idx[I1],
-                                     m_thread_data_on_block_idx[I2],
-                                     m_thread_data_on_block_idx[I3],
-                                     m_thread_data_on_block_idx[I4],
-                                     n_thread_data_on_block_idx[I2]),
-                    ck::tensor_operation::element_wise::PassThrough{}};
-
-            // shuffle: blockwise copy C from LDS to global
-            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
-                ThisThreadBlock,            // ThreadGroup
-                CElementwiseOperation,      // ElementwiseOperation,
-                CGlobalMemoryDataOperation, // DstInMemOp,
-                Sequence<1,
-                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                         1,
-                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-                CShuffleDataType,     // typename SrcData,
-                CDataType,            // typename DstData,
-                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
-                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
-                Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
-                3,                                              // index_t VectorDim,
-                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
-                true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
-                false> // bool ThreadTransferDstResetCoordinateAfterRun>
-                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
-                 make_multi_index(0, 0, 0, 0),
-                 c_grid_desc_mblock_mperblock_nblock_nperblock,
-                 make_multi_index(block_m_id, 0, block_n_id, 0),
-                 c_element_op};
-
-            // space filling curve for threadwise C in VGPR
-            constexpr auto sfc_c_vgpr =
-                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
-                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
-                                  Sequence<CShuffleMXdlPerWavePerShuffle,
-                                           CShuffleNXdlPerWavePerShuffle,
-                                           1,
-                                           1,
-                                           M2,
-                                           1,
-                                           M4,
-                                           1>>{};
-
-            // space filling curve for shuffled blockwise C in global mem
-            constexpr auto sfc_c_global =
-                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
-                                  Sequence<0, 2, 1, 3>,
-                                  Sequence<1,
-                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                                           1,
-                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
-
-            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
-
-            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
-
-            static_for<0, num_access, 1>{}([&](auto access_id) {
-                // make sure it's safe to write to LDS
-                block_sync_lds();
-
-                // each thread write its data from VGPR to LDS
-                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
-                                              c_thread_buf,
-                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                              c_shuffle_block_buf);
-
-                // make sure it's safe to read from LDS
-                block_sync_lds();
-
-                // each block copy its data from LDS to global
-                c_shuffle_block_copy_lds_to_global.Run(
-                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
-                    c_shuffle_block_buf,
-                    c_grid_desc_mblock_mperblock_nblock_nperblock,
-                    c_grid_buf);
-
-                if constexpr(access_id < num_access - 1)
-                {
-                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
-
-                    // move on C
-                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
-                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
-                }
-            });
+            // Typecast -> Permute -> Coalesced vector store
+            auto c_thread_copy_vgpr_to_global =
+                ThreadwiseTensorSliceTransfer_v1r4<
+                    AccDataType,
+                    CDataType,
+                    decltype(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
+                    decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
+                    CElementwiseOperation,
+                    Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>,
+                    Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
+                    Sequence<0, 1, 2, 3, 4, 5, 6, 7, 9, 8>,
+                    9,
+                    8,
+                    M4,
+                    N2,
+                    InMemoryDataOperationEnum::Set,
+                    1,
+                    false>{
+                    c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
+                    make_multi_index(block_m_id,
+                                     block_n_id,
+                                     m_thread_data_on_block_idx[I0],
+                                     m_thread_data_on_block_idx[I1],
+                                     n_thread_data_on_block_idx[I0],
+                                     m_thread_data_on_block_idx[I2],
+                                     m_thread_data_on_block_idx[I3],
+                                     n_thread_data_on_block_idx[I1],
+                                     n_thread_data_on_block_idx[I2],
+                                     m_thread_data_on_block_idx[I4]),
+                    c_element_op};
+
+            c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
+                                             make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                                             c_thread_buf,
+                                             c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
+                                             c_grid_buf);
         }
     }

     template <bool HasMainKBlockLoop,
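Net effect of this hunk: instead of round-tripping accumulators through LDS (a threadwise VGPR-to-LDS copy, then a blockwise LDS-to-global copy driven by two space-filling curves), each thread typecasts its accumulator fragment, permutes the two innermost dimensions (the Sequence<0, ..., 9, 8> destination access order), and issues coalesced vector stores straight to global memory in a single Run. A plain-C++ schematic of that typecast -> permute -> vector-store flow for one thread's fragment, with hypothetical fragment shapes:

    #include <cstdio>

    // One thread owns an (N2 x M4) accumulator fragment in registers, m4 innermost.
    // The store wants n2 innermost in memory, so the two axes are permuted while
    // typecasting, then each row goes out as one contiguous "vector" store.
    constexpr int N2 = 4, M4 = 4;

    using CType = unsigned short; // stand-in for a 16-bit output type

    inline CType typecast(float x) { return static_cast<CType>(x); } // placeholder cast

    void store_fragment(const float (&acc)[N2][M4], CType* out, int row_stride)
    {
        CType staged[M4][N2]; // permuted, typecast copy, still "in registers"

        for(int n = 0; n < N2; ++n)
            for(int m = 0; m < M4; ++m)
                staged[m][n] = typecast(acc[n][m]); // typecast -> permute

        for(int m = 0; m < M4; ++m)
        {
            // One contiguous run of N2 elements per row: the coalesced vector store.
            for(int n = 0; n < N2; ++n)
                out[m * row_stride + n] = staged[m][n];
        }
    }

    int main()
    {
        float acc[N2][M4] = {};
        acc[1][2] = 3.0f; // lands at permuted position (m=2, n=1)
        CType out[M4 * N2] = {};
        store_fragment(acc, out, N2);
        std::printf("out[2 * %d + 1] = %u\n", N2, unsigned(out[2 * N2 + 1]));
    }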