Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
23ce8e68
Commit
23ce8e68
authored
Jun 13, 2022
by
wangshaojie6
Browse files
add prefetch 3 for pipeline v2
parent
56598b1b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
419 additions
and
1 deletion
+419
-1
example/01_gemm/gemm_xdl_fp16_splitk.cpp
example/01_gemm/gemm_xdl_fp16_splitk.cpp
+4
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
+414
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+1
-1
No files found.
example/01_gemm/gemm_xdl_fp16_splitk.cpp
View file @
23ce8e68
...
...
@@ -162,6 +162,10 @@ int main(int argc, char* argv[])
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
std
::
cout
<<
"a device buf: "
<<
a_m_k_device_buf
.
GetDeviceBuffer
()
<<
std
::
endl
;
std
::
cout
<<
"b device buf: "
<<
b_k_n_device_buf
.
GetDeviceBuffer
()
<<
std
::
endl
;
std
::
cout
<<
"c device buf: "
<<
c_m_n_device_buf
.
GetDeviceBuffer
()
<<
std
::
endl
;
// do GEMM
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
View file @
23ce8e68
...
...
@@ -286,5 +286,419 @@ struct GridwiseGemmPipeline_v2<2>
}
};
// 3-stage prefetch
template
<
>
struct
GridwiseGemmPipeline_v2
<
3
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
num_loop
)
{
// TODO: improve applicability
return
num_loop
>
3
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainLoop
(
index_t
num_loop
)
{
return
num_loop
>
3
;
}
template
<
bool
HasMainLoop
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
BlockwiseGemm
,
typename
CThreadBuffer
>
__device__
static
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
{
static_for
<
0
,
3
,
1
>
{}([
&
](
auto
i_pre
){
// global read i_pre
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
Number
<
i_pre
>
{});
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
Number
<
i_pre
>
{});
// move to i_pre + 1
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
});
// Initialize C
c_thread_buf
.
Clear
();
index_t
i
=
0
;
// main body
if
constexpr
(
HasMainLoop
)
{
do
{
static_for
<
0
,
3
,
1
>
{}([
&
](
auto
i_main
){
// LDS write i_main
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
i_main
>
{});
// global Read i_main + 3
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
Number
<
i_main
>
{});
// LDS write i_main
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
i_main
>
{});
// global Read i_main + 3
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
Number
<
i_main
>
{});
// move to i_main + 3
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
block_sync_lds
();
// GEMM i_main
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
});
i
+=
3
;
}
while
(
i
<
(
num_loop
-
3
));
}
// tail
if
(
i
==
num_loop
-
3
)
{
static_for
<
0
,
I3
,
1
>
{}([
&
](
auto
i_res
){
// Write num_loop - 3
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
i_res
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
i_res
>
{});
block_sync_lds
();
// GEMM num_loop - 3
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
});
}
// tail
else
if
(
i
==
num_loop
-
2
)
{
static_for
<
0
,
I2
,
1
>
{}([
&
](
auto
i_res
){
// Write num_loop
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
i_res
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
i_res
>
{});
block_sync_lds
();
// GEMM num_loop
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
});
}
// tail
else
if
(
i
==
num_loop
-
1
)
{
static_for
<
0
,
I1
,
1
>
{}([
&
](
auto
i_res
){
// Write num_loop
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
i_res
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
i_res
>
{});
block_sync_lds
();
// GEMM num_loop
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
});
}
}
};
// 4-stage prefetch
template
<
>
struct
GridwiseGemmPipeline_v2
<
4
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
num_loop
)
{
// TODO: improve applicability
return
num_loop
>
4
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainLoop
(
index_t
num_loop
)
{
return
num_loop
>
4
;
}
template
<
bool
HasMainLoop
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
BlockwiseGemm
,
typename
CThreadBuffer
>
__device__
static
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
{
// global read 0
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
// move to 1
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// global read 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I1
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I1
);
// move to 2
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// global read 2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I2
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I2
);
// move to 3
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// global read 3
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I3
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I3
);
// Initialize C
c_thread_buf
.
Clear
();
index_t
i
=
0
;
// main body
if
constexpr
(
HasMainLoop
)
{
do
{
// move to i + 4
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// LDS write i
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
I0
);
// global Read i + 4
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
// LDS write i
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
I0
);
// global Read i + 4
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
block_sync_lds
();
// GEMM i
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
// move to i + 5
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// LDS write i + 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
I1
);
// global read i + 5
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I1
);
// LDS write i + 1
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
I1
);
// global read i + 5
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I1
);
block_sync_lds
();
// GEMM i + 1
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
// move to i + 6
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// LDS write i + 2
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
I2
);
// global read i + 6
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I2
);
// LDS write i + 2
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
I2
);
// global read i + 6
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I2
);
block_sync_lds
();
// GEMM i + 2
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
// move to i + 7
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// LDS write i + 3
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
I3
);
// global read i + 7
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I3
);
// LDS write i + 3
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
I3
);
// global read i + 7
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I3
);
block_sync_lds
();
// GEMM i + 3
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
i
+=
4
;
}
while
(
i
<
(
num_loop
-
4
));
}
// tail
if
(
i
==
num_loop
-
4
)
{
static_for
<
0
,
I4
,
1
>
{}([
&
](
auto
i_res
){
// Write num_loop - 3
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
i_res
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
i_res
>
{});
block_sync_lds
();
// GEMM num_loop - 3
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
});
}
// tail
if
(
i
==
num_loop
-
3
)
{
static_for
<
0
,
I3
,
1
>
{}([
&
](
auto
i_res
){
// Write num_loop - 3
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
i_res
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
i_res
>
{});
block_sync_lds
();
// GEMM num_loop - 3
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
});
}
// tail
else
if
(
i
==
num_loop
-
2
)
{
static_for
<
0
,
I2
,
1
>
{}([
&
](
auto
i_res
){
// Write num_loop
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
i_res
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
i_res
>
{});
block_sync_lds
();
// GEMM num_loop
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
});
}
// tail
else
if
(
i
==
num_loop
-
1
)
{
static_for
<
0
,
I1
,
1
>
{}([
&
](
auto
i_res
){
// Write num_loop
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
i_res
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
i_res
>
{});
block_sync_lds
();
// GEMM num_loop
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
});
}
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
23ce8e68
...
...
@@ -111,7 +111,7 @@ template <index_t BlockSize,
index_t
CShuffleNRepeatPerShuffle
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXDL
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
NumGemmKPrefetchStage
=
2
>
index_t
NumGemmKPrefetchStage
=
3
>
struct
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment