Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1a2a1679
Commit
1a2a1679
authored
May 08, 2023
by
carlushuang
Browse files
initial stream-k implementation with example
parent
8bb2bb4a
Changes
16
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
2500 additions
and
7 deletions
+2500
-7
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+1
-0
example/01_gemm/gemm_xdl_fp16.cpp
example/01_gemm/gemm_xdl_fp16.cpp
+3
-1
example/01_gemm/gemm_xdl_streamk.cpp
example/01_gemm/gemm_xdl_streamk.cpp
+41
-0
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+25
-2
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
+15
-0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp
...n/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp
+134
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
...sor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
+280
-0
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+233
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp
+89
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
...ensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
+856
-0
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp
...on/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp
+213
-0
include/ck/utility/magic_division.hpp
include/ck/utility/magic_division.hpp
+38
-0
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+4
-4
test/block_swizzle_test/block_swizzle_test.cpp
test/block_swizzle_test/block_swizzle_test.cpp
+406
-0
test/block_swizzle_test/rebuild.sh
test/block_swizzle_test/rebuild.sh
+3
-0
test/block_swizzle_test/simple_args.h
test/block_swizzle_test/simple_args.h
+159
-0
No files found.
example/01_gemm/CMakeLists.txt
View file @
1a2a1679
...
@@ -44,3 +44,4 @@ if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS
...
@@ -44,3 +44,4 @@ if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS
add_dependencies
(
example_gemm_wmma example_gemm_wmma_fp16
)
add_dependencies
(
example_gemm_wmma example_gemm_wmma_fp16
)
endif
()
endif
()
add_example_executable
(
example_gemm_xdl_streamk gemm_xdl_streamk.cpp
)
example/01_gemm/gemm_xdl_fp16.cpp
View file @
1a2a1679
...
@@ -39,7 +39,9 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl
...
@@ -39,7 +39,9 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
LoopScheduler
::
Default
,
ck
::
PipelineVersion
::
v1
>
;
// DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
// clang-format on
// clang-format on
using
DeviceGemmInstance
=
DeviceGemmInstance1
;
using
DeviceGemmInstance
=
DeviceGemmInstance1
;
...
...
example/01_gemm/gemm_xdl_streamk.cpp
0 → 100644
View file @
1a2a1679
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp"
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
using
CDataType
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
// clang-format off
using
DeviceGemmStreamK
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlStreamK
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// // clang-format on
// clang-format on
using
DeviceGemmInstance
=
DeviceGemmStreamK
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
#include "run_gemm_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
example/01_gemm/run_gemm_example.inc
View file @
1a2a1679
...
@@ -11,7 +11,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -11,7 +11,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
using
namespace
ck
::
literals
;
using
namespace
ck
::
literals
;
auto
&
[
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
]
=
problem_size
;
auto
[
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
]
=
problem_size
;
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
...
@@ -25,12 +25,35 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -25,12 +25,35 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
}
}
};
};
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
stride
==
0
)
{
// give a chance if stride is zero, return a defalt packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
col
;
}
else
{
return
row
;
}
}
else
return
stride
;
};
StrideA
=
f_get_default_stride
(
M
,
K
,
StrideA
,
ALayout
{});
StrideB
=
f_get_default_stride
(
K
,
N
,
StrideB
,
BLayout
{});
StrideC
=
f_get_default_stride
(
M
,
N
,
StrideC
,
CLayout
{});
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
switch
(
config
.
init_method
)
switch
(
config
.
init_method
)
{
{
case
0
:
break
;
case
0
:
ck
::
utils
::
FillConstant
<
ADataType
>
{
1.
f
}(
a_m_k
);
ck
::
utils
::
FillConstant
<
BDataType
>
{
1.
f
}(
b_k_n
);
break
;
case
1
:
case
1
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
View file @
1a2a1679
...
@@ -94,6 +94,21 @@ struct ThreadGroupTensorSliceTransfer_v4r1
...
@@ -94,6 +94,21 @@ struct ThreadGroupTensorSliceTransfer_v4r1
}
}
}
}
__device__
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
threadwise_transfer_
.
SetSrcSliceOrigin
(
src_desc
,
src_block_slice_origin
+
thread_data_idx_begin
);
}
}
template
<
typename
SrcBuffer
,
index_t
ThreadScratchId
=
0
>
template
<
typename
SrcBuffer
,
index_t
ThreadScratchId
=
0
>
__device__
void
RunRead
(
const
SrcDesc
&
src_desc
,
__device__
void
RunRead
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
const
SrcBuffer
&
src_buf
,
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp
0 → 100644
View file @
1a2a1679
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp"
namespace
ck
{
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template
<
typename
ThreadGroup
,
typename
ElementwiseOperation
,
typename
SliceLengths
,
typename
ThreadClusterLengths
,
typename
ThreadClusterArrangeOrder
,
typename
SrcData
,
typename
DstData
,
typename
SrcDesc
,
typename
DstDesc
,
typename
DimAccessOrder
,
index_t
VectorDim
,
index_t
ScalarPerVector
,
bool
ThreadTransferSrcResetCoordinateAfterRun
,
bool
ThreadTransferDstResetCoordinateAfterRun
>
struct
ThreadGroupTensorSliceTransfer_v6r1r2
{
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
static
constexpr
auto
thread_slice_lengths
=
SliceLengths
{}
/
ThreadClusterLengths
{};
using
Index
=
MultiIndex
<
nDim
>
;
__device__
constexpr
ThreadGroupTensorSliceTransfer_v6r1r2
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_block_slice_origin
,
const
ElementwiseOperation
&
element_op
)
:
threadwise_transfer_
(
src_desc
,
make_zero_multi_index
<
nDim
>
(),
dst_desc
,
make_zero_multi_index
<
nDim
>
(),
element_op
)
{
static_assert
(
nDim
==
remove_cvref_t
<
SrcDesc
>::
GetNumOfDimension
()
&&
nDim
==
remove_cvref_t
<
DstDesc
>::
GetNumOfDimension
()
&&
nDim
==
ThreadClusterLengths
::
Size
()
&&
nDim
==
ThreadClusterArrangeOrder
::
Size
()
&&
nDim
==
DimAccessOrder
::
Size
(),
"wrong! nDim not consistent"
);
static_assert
(
is_same
<
SliceLengths
,
decltype
(
thread_slice_lengths
*
ThreadClusterLengths
{})
>
{},
"wrong! threads should be mapped to cover entire slicing window"
);
static_assert
(
ThreadGroup
::
GetNumOfThread
()
>=
thread_cluster_desc_
.
GetElementSize
(),
"wrong! ThreadGroup::GetNumOfThread() too small"
);
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
threadwise_transfer_
.
SetSrcSliceOrigin
(
src_desc
,
src_block_slice_origin
+
thread_data_idx_begin
);
threadwise_transfer_
.
SetDstSliceOrigin
(
dst_desc
,
dst_block_slice_origin
+
thread_data_idx_begin
);
}
}
template
<
typename
SrcBuffer
,
typename
DstBuffer
,
InMemoryDataOperationEnum
DstInMemOp
>
__device__
void
Run
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
template
Run
<
SrcBuffer
,
DstBuffer
,
DstInMemOp
>(
src_desc
,
src_buf
,
dst_desc
,
dst_buf
);
}
}
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveSrcSliceWindow
(
src_desc
,
step
);
}
}
__device__
void
MoveDstSliceWindow
(
const
DstDesc
&
dst_desc
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveDstSliceWindow
(
dst_desc
,
step
);
}
}
private:
static
constexpr
auto
thread_cluster_desc_
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
ThreadwiseTensorSliceTransfer_v6r1r2
<
SrcData
,
DstData
,
SrcDesc
,
DstDesc
,
ElementwiseOperation
,
decltype
(
thread_slice_lengths
),
DimAccessOrder
,
VectorDim
,
ScalarPerVector
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
0 → 100644
View file @
1a2a1679
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerXDL
,
ck
::
index_t
NPerXDL
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
ck
::
index_t
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
ck
::
index_t
BBlockLdsAddExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXDL
>
struct
DeviceGemmXdlStreamK
:
public
DeviceGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
<
BlockSize
,
BlockToCTileMap_GemmStreamK
<
MPerBlock
,
NPerBlock
,
K0PerBlock
*
K1
>
,
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CBlockTransferScalarPerVector_NWaveNPerXDL
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
// Invoker
struct
Invoker
:
public
BaseInvoker
{
void
Print
(
const
Argument
&
karg
)
{
karg
.
Print
();
}
float
Run
(
const
Argument
&
karg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
Print
(
karg
);
}
if
(
!
GridwiseGemm
::
CheckValidity
(
karg
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
"setting"
);
}
dim3
grid_dims
=
karg
.
block_mapping
.
get_grid_dims
();
float
ave_time
=
0
;
const
auto
kernel
=
kernel_gemm_xdlops_streamk
<
GridwiseGemm
>
;
// TODO: remove clear buffer for streamk kernels
hipGetErrorString
(
hipMemset
(
karg
.
p_c_grid
,
0
,
karg
.
M
*
karg
.
N
*
sizeof
(
CDataType
)));
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
grid_dims
,
dim3
(
BlockSize
),
0
,
karg
);
return
ave_time
;
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
karg
)
{
return
GridwiseGemm
::
CheckValidity
(
karg
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
)
{
const
auto
kernel
=
kernel_gemm_xdlops_streamk
<
GridwiseGemm
>
;
int
occupancy
,
num_cu
;
hipError_t
rtn
;
rtn
=
hipOccupancyMaxActiveBlocksPerMultiprocessor
(
&
occupancy
,
kernel
,
BlockSize
,
GridwiseGemm
::
GetSharedMemoryNumberOfByte
());
hip_check_error
(
rtn
);
hipDeviceProp_t
dev_prop
;
hipDevice_t
dev
;
rtn
=
hipGetDevice
(
&
dev
);
hip_check_error
(
rtn
);
rtn
=
hipGetDeviceProperties
(
&
dev_prop
,
dev
);
hip_check_error
(
rtn
);
num_cu
=
dev_prop
.
multiProcessorCount
;
printf
(
"XXX occupancy:%d, num_cu:%d, BLOCK_SIZE:%d, LDS:%d
\n
"
,
occupancy
,
num_cu
,
BlockSize
,
GridwiseGemm
::
GetSharedMemoryNumberOfByte
());
return
Argument
{
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
static_cast
<
uint32_t
>
(
num_cu
),
static_cast
<
uint32_t
>
(
occupancy
)};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
)
override
{
const
auto
kernel
=
kernel_gemm_xdlops_streamk
<
GridwiseGemm
>
;
int
occupancy
,
num_cu
;
hipError_t
rtn
;
rtn
=
hipOccupancyMaxActiveBlocksPerMultiprocessor
(
&
occupancy
,
kernel
,
BlockSize
,
GridwiseGemm
::
GetSharedMemoryNumberOfByte
());
hip_check_error
(
rtn
);
hipDeviceProp_t
dev_prop
;
hipDevice_t
dev
;
rtn
=
hipGetDevice
(
&
dev
);
hip_check_error
(
rtn
);
rtn
=
hipGetDeviceProperties
(
&
dev_prop
,
dev
);
hip_check_error
(
rtn
);
num_cu
=
dev_prop
.
multiProcessorCount
;
return
std
::
make_unique
<
Argument
>
(
reinterpret_cast
<
const
ADataType
*>
(
p_a
),
reinterpret_cast
<
const
BDataType
*>
(
p_b
),
reinterpret_cast
<
CDataType
*>
(
p_c
),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
static_cast
<
uint32_t
>
(
num_cu
),
static_cast
<
uint32_t
>
(
occupancy
));
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
return
GridwiseGemm
::
GetTypeString
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
1a2a1679
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include "ck/utility/number.hpp"
#include "ck/utility/number.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include <limits>
namespace
ck
{
namespace
ck
{
...
@@ -635,4 +636,236 @@ struct BlockToCTileMap_3DGrid_KSplit
...
@@ -635,4 +636,236 @@ struct BlockToCTileMap_3DGrid_KSplit
}
}
};
};
template
<
uint32_t
MPerBlock_
,
uint32_t
NPerBlock_
,
uint32_t
KPerBlock_
>
struct
BlockToCTileMap_GemmStreamK
{
static
constexpr
uint32_t
min_k_iters_per_sk_block
=
2
;
static
constexpr
uint32_t
MPerBlock
=
MPerBlock_
;
static
constexpr
uint32_t
NPerBlock
=
NPerBlock_
;
static
constexpr
uint32_t
KPerBlock
=
KPerBlock_
;
//--------------------------------------
// pass to device
uint32_t
sk_num_blocks
;
uint32_t
sk_num_big_blocks
;
uint32_t
sk_total_iters
;
uint32_t
dp_start_block_idx
;
uint32_t
dp_iters_per_block
;
uint32_t
dp_num_blocks
;
uint32_t
k_iters_per_big_block
;
MDiv
k_iters_per_tile
;
MDiv
n_tiles
;
//--------------------------------------
// prefer construct on host
BlockToCTileMap_GemmStreamK
(
uint32_t
m
,
uint32_t
n
,
uint32_t
k
,
uint32_t
num_cu
,
uint32_t
occupancy
)
{
uint32_t
num_tiles
=
math
::
integer_divide_ceil
(
m
,
MPerBlock
)
*
math
::
integer_divide_ceil
(
n
,
NPerBlock
);
k_iters_per_tile
=
MDiv
(
math
::
integer_divide_ceil
(
k
,
KPerBlock
));
// one cu can hold one wg at one time, from the whole chip's point of view
// if number of wg is same as num_cu, we call it 1 dispatch
// if number of wg is 2x num_cu, we call it 2 dispatches.
// one dispatch can deliever wg same as num_cu (full dispatch), or less than num_cu (partial
// dispatch)
//
uint32_t
full_dispatches
=
num_tiles
/
num_cu
;
uint32_t
full_dispatch_tiles
=
full_dispatches
*
num_cu
;
uint32_t
partial_dispatche_tiles
=
num_tiles
-
full_dispatch_tiles
;
uint32_t
sk_occupancy
=
occupancy
;
uint32_t
dp_tiles
=
full_dispatch_tiles
;
uint32_t
sk_tiles
=
partial_dispatche_tiles
;
if
(
full_dispatches
<
occupancy
)
{
// in this case, we allocate all blocks as sk blocks
// sk_occupancy = occupancy - full_dispatches;
sk_occupancy
=
1
;
// TODO: single occ seems better
dp_tiles
=
full_dispatch_tiles
;
sk_tiles
=
partial_dispatche_tiles
;
}
else
if
((
occupancy
>
1
)
&&
(
full_dispatches
%
occupancy
==
occupancy
-
1
))
{
// e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
// occupancy = 3, full_dispatches = 5, 8, 11 ...
// occupancy = 4, full_dispatches = 7, 11 ...
sk_occupancy
=
1
;
// left 1 slot for sk occupancy
dp_tiles
=
full_dispatch_tiles
;
sk_tiles
=
partial_dispatche_tiles
;
}
else
{
// others, we reduce 1 dispatch from dp, together with partial dispatch,
// to construct sk dispatch
sk_occupancy
=
occupancy
-
((
full_dispatches
-
1
)
%
occupancy
);
dp_tiles
=
full_dispatch_tiles
-
num_cu
;
sk_tiles
=
partial_dispatche_tiles
+
num_cu
;
}
// dp_num_blocks = dp_tiles;
// dp_start_block_idx = num_cu * sk_occupancy;
dp_iters_per_block
=
k_iters_per_tile
.
get
();
sk_total_iters
=
k_iters_per_tile
.
get
()
*
sk_tiles
;
// printf("num_tiles:%d, full_dispatches:%d, full_dispatch_tiles:%d,
// partial_dispatche_tiles:%d\n",
// num_tiles, full_dispatches, full_dispatch_tiles, partial_dispatche_tiles);
{
uint32_t
min_sk_tiles
=
(
sk_tiles
>=
num_cu
)
?
num_cu
:
(
sk_tiles
+
1
);
uint32_t
max_sk_tiles
=
(
sk_tiles
>=
num_cu
)
?
num_cu
*
sk_occupancy
:
math
::
min
(
num_cu
,
sk_total_iters
/
min_k_iters_per_sk_block
);
// if use dp for sk-block, how many iters do we need
uint32_t
dp_for_sk_iters
=
k_iters_per_tile
.
get
();
uint32_t
best_sk_score
=
std
::
numeric_limits
<
int
>::
max
();
// we need to find the smallest sk iters
for
(
uint32_t
tentative_sk_blocks
=
min_sk_tiles
;
tentative_sk_blocks
<
max_sk_tiles
;
tentative_sk_blocks
++
)
{
uint32_t
tentative_sk_iters_per_block
=
(
sk_total_iters
+
tentative_sk_blocks
-
1
)
/
tentative_sk_blocks
;
uint32_t
tentative_sk_iters
=
tentative_sk_iters_per_block
;
uint32_t
sk_blocks_per_tile
=
(
tentative_sk_blocks
+
sk_tiles
-
1
)
/
sk_tiles
;
// TODO: carefully adjust this parameter
// the more sk_blocks_per_tile, the worse the overhead
uint32_t
cross_sk_blocks_overhead
=
sk_blocks_per_tile
;
if
(
tentative_sk_blocks
%
sk_tiles
!=
0
)
{
// penalty for uneven divide
cross_sk_blocks_overhead
+=
sk_blocks_per_tile
*
tentative_sk_iters_per_block
/
50
;
}
uint32_t
tentative_sk_score
=
tentative_sk_iters
+
cross_sk_blocks_overhead
;
if
(
tentative_sk_score
<
best_sk_score
)
{
best_sk_score
=
tentative_sk_score
;
sk_num_blocks
=
tentative_sk_blocks
;
}
}
if
(
best_sk_score
>=
dp_for_sk_iters
)
{
sk_num_blocks
=
0
;
}
if
(
sk_num_blocks
==
0
)
{
sk_num_big_blocks
=
0
;
k_iters_per_big_block
=
0
;
dp_num_blocks
=
num_tiles
;
// all tile to be dp block
dp_start_block_idx
=
0
;
sk_total_iters
=
0
;
// clear this tiles
}
else
{
// k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
// we need to decide how many iters for each sk block
// let m = k_iters_per_sk_block
// some of the sk block (little) will cover m iters, some (big) will cover m+1
// we have
// 1) l + b = sk_blocks
// 2) l * m + b * (m + 1) = sk_total_iters
// => (l + b) * m + b = sk_total_iters
// => sk_blocks * m + b = sk_total_iters
// => b = sk_total_iters - m * sk_blocks
// NOTE: big could be zero
uint32_t
k_iters_per_sk_block
=
sk_total_iters
/
sk_num_blocks
;
sk_num_big_blocks
=
sk_total_iters
-
k_iters_per_sk_block
*
sk_num_blocks
;
k_iters_per_big_block
=
k_iters_per_sk_block
+
1
;
dp_num_blocks
=
dp_tiles
;
dp_start_block_idx
=
(
sk_num_blocks
+
num_cu
-
1
)
/
num_cu
*
num_cu
;
}
}
n_tiles
=
MDiv
(
math
::
integer_divide_ceil
(
n
,
NPerBlock
));
printf
(
"cu:%d, occupancy:%d, grids:%d, sk_num_big_blocks:%d, sk_num_blocks:%d, "
"sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
"k_iters_per_tile:%d, k_iters_per_big_block:%d
\n
"
,
num_cu
,
occupancy
,
get_grid_dims
().
x
,
sk_num_big_blocks
,
sk_num_blocks
,
sk_total_iters
,
dp_start_block_idx
,
dp_iters_per_block
,
dp_num_blocks
,
k_iters_per_tile
.
get
(),
k_iters_per_big_block
);
}
__host__
__device__
dim3
get_grid_dims
()
const
{
return
dim3
(
dp_start_block_idx
+
dp_num_blocks
,
1
,
1
);
}
__device__
uint32_t
get_block_idx
()
const
{
// TODO: swizzle block index for better locality
return
__builtin_amdgcn_readfirstlane
(
blockIdx
.
x
);
}
__device__
void
get_block_itr
(
uint32_t
block_idx
,
uint32_t
&
iter_start
,
uint32_t
&
iter_end
)
const
{
if
(
block_idx
<
sk_num_big_blocks
)
{
iter_start
=
block_idx
*
k_iters_per_big_block
;
iter_end
=
iter_start
+
k_iters_per_big_block
;
}
else
if
(
block_idx
<
sk_num_blocks
)
{
iter_start
=
(
sk_num_big_blocks
*
k_iters_per_big_block
)
+
(
block_idx
-
sk_num_big_blocks
)
*
(
k_iters_per_big_block
-
1
);
iter_end
=
iter_start
+
(
k_iters_per_big_block
-
1
);
}
else
if
(
block_idx
>=
dp_start_block_idx
)
{
iter_start
=
sk_total_iters
+
(
block_idx
-
dp_start_block_idx
)
*
dp_iters_per_block
;
iter_end
=
iter_start
+
dp_iters_per_block
;
}
}
__device__
uint32_t
get_current_iter_length
(
uint32_t
iter_start
,
uint32_t
iter_end
,
uint32_t
total_iter_length
)
const
{
uint32_t
iter_length_mod
,
iter_length_quo
/*unused*/
;
k_iters_per_tile
.
divmod
(
iter_end
,
iter_length_quo
,
iter_length_mod
);
uint32_t
current_iter_length
=
math
::
min
(
iter_length_mod
==
0
?
(
iter_end
-
iter_start
)
:
iter_length_mod
,
total_iter_length
);
return
current_iter_length
;
}
__device__
uint32_t
get_tile_idx
(
uint32_t
iter
)
const
{
return
k_iters_per_tile
.
div
(
iter
);
}
__device__
void
get_tile_idx_with_offset
(
uint32_t
iter
,
uint32_t
&
tile_idx
,
uint32_t
&
iter_offset
)
const
{
k_iters_per_tile
.
divmod
(
iter
,
tile_idx
,
iter_offset
);
}
__device__
auto
tile_to_spatial
(
uint32_t
tile_idx
)
const
{
// TODO:
uint32_t
m_tile_idx
,
n_tile_idx
;
n_tiles
.
divmod
(
tile_idx
,
m_tile_idx
,
n_tile_idx
);
return
make_tuple
(
m_tile_idx
,
n_tile_idx
);
}
};
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp
0 → 100644
View file @
1a2a1679
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
namespace
ck
{
struct
GridwiseGemmPipeline_v3
{
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
)
{
// TODO: improve applicability
return
true
;
}
template
<
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
BlockwiseGemm
,
typename
CThreadBuffer
>
__device__
static
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
{
// global read 0
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// LDS write 0
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
num_loop
--
;
while
(
num_loop
>
0
)
{
block_sync_lds
();
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
num_loop
--
;
}
// tail
{
block_sync_lds
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
0 → 100644
View file @
1a2a1679
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp
0 → 100644
View file @
1a2a1679
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
namespace
ck
{
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
// instead
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
template
<
typename
SrcData
,
typename
DstData
,
typename
SrcDesc
,
typename
DstDesc
,
typename
ElementwiseOperation
,
typename
SliceLengths
,
typename
DimAccessOrder
,
index_t
VectorDim
,
index_t
ScalarPerVector
,
bool
SrcResetCoordinateAfterRun
,
bool
DstResetCoordinateAfterRun
>
struct
ThreadwiseTensorSliceTransfer_v6r1r2
{
static
constexpr
index_t
nDim
=
SliceLengths
::
Size
();
using
Index
=
MultiIndex
<
nDim
>
;
using
SrcCoord
=
decltype
(
make_tensor_coordinate
(
SrcDesc
{},
Index
{}));
using
DstCoord
=
decltype
(
make_tensor_coordinate
(
DstDesc
{},
Index
{}));
static
constexpr
auto
I0
=
Number
<
0
>
{};
__device__
constexpr
ThreadwiseTensorSliceTransfer_v6r1r2
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin
,
const
ElementwiseOperation
&
element_op
)
:
src_coord_
(
make_tensor_coordinate
(
src_desc
,
src_slice_origin
)),
dst_coord_
(
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin
)),
element_op_
(
element_op
)
{
static_assert
(
SliceLengths
::
At
(
Number
<
VectorDim
>
{})
%
ScalarPerVector
==
0
,
"wrong! cannot evenly divide"
);
}
__device__
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_idx
)
{
src_coord_
=
make_tensor_coordinate
(
src_desc
,
src_slice_origin_idx
);
}
__device__
void
SetDstSliceOrigin
(
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin_idx
)
{
dst_coord_
=
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin_idx
);
}
template
<
typename
SrcBuffer
,
typename
DstBuffer
,
InMemoryDataOperationEnum
DstInMemOp
>
__device__
void
Run
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
)
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr
auto
scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
VectorDim
,
ScalarPerVector
>
{},
Number
<
nDim
>
{});
using
SpaceFillingCurve
=
SpaceFillingCurve
<
SliceLengths
,
DimAccessOrder
,
remove_cv_t
<
decltype
(
scalar_per_access
)
>>
;
// loop over space-filling curve
constexpr
auto
num_access
=
SpaceFillingCurve
::
GetNumOfAccess
();
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
idx_1d
)
{
using
src_vector_type
=
vector_type_maker_t
<
SrcData
,
ScalarPerVector
>
;
using
src_vector_t
=
typename
src_vector_type
::
type
;
using
dst_vector_type
=
vector_type_maker_t
<
DstData
,
ScalarPerVector
>
;
using
dst_vector_t
=
typename
dst_vector_type
::
type
;
const
bool
is_src_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
src_desc
,
src_coord_
);
// copy data from src_buf into src_vector_container
auto
src_vector_container
=
src_vector_type
{
src_buf
.
template
Get
<
src_vector_t
>(
src_coord_
.
GetOffset
(),
is_src_valid
)};
auto
dst_vector_container
=
dst_vector_type
{};
// apply pointwise operation
static_for
<
0
,
ScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
SrcData
v
;
// apply element-wise operation
element_op_
(
v
,
src_vector_container
.
template
AsType
<
SrcData
>()[
i
]);
// apply type convert
dst_vector_container
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
v
);
});
const
bool
is_dst_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
dst_desc
,
dst_coord_
);
// copy data from dst_vector into dst_buf
dst_buf
.
template
Update
<
DstInMemOp
,
dst_vector_t
>(
dst_coord_
.
GetOffset
(),
is_dst_valid
,
dst_vector_container
.
template
AsType
<
dst_vector_t
>()[
I0
]);
// move coordinate
if
constexpr
(
idx_1d
.
value
!=
num_access
-
1
)
{
constexpr
auto
forward_step
=
SpaceFillingCurve
::
GetForwardStep
(
idx_1d
);
move_tensor_coordinate
(
src_desc
,
src_coord_
,
make_tensor_coordinate_step
(
src_desc
,
forward_step
));
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
make_tensor_coordinate_step
(
dst_desc
,
forward_step
));
}
});
// move coordinate back to slice origin (or not)
if
constexpr
(
SrcResetCoordinateAfterRun
)
{
const
auto
src_reset_step
=
make_tensor_coordinate_step
(
src_desc
,
GetCoordinateResetStep
());
move_tensor_coordinate
(
src_desc
,
src_coord_
,
src_reset_step
);
}
if
constexpr
(
DstResetCoordinateAfterRun
)
{
const
auto
dst_reset_step
=
make_tensor_coordinate_step
(
dst_desc
,
GetCoordinateResetStep
());
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
dst_reset_step
);
}
}
__device__
static
constexpr
auto
GetCoordinateResetStep
()
{
constexpr
auto
scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
VectorDim
,
ScalarPerVector
>
{},
Number
<
nDim
>
{});
using
SpaceFillingCurve
=
SpaceFillingCurve
<
SliceLengths
,
DimAccessOrder
,
remove_cv_t
<
decltype
(
scalar_per_access
)
>>
;
constexpr
auto
num_access
=
SpaceFillingCurve
::
GetNumOfAccess
();
if
constexpr
(
num_access
==
0
)
{
return
typename
SpaceFillingCurve
::
Index
{};
}
else
{
constexpr
auto
reset_step
=
SpaceFillingCurve
::
GetStepBetween
(
Number
<
num_access
-
1
>
{},
Number
<
0
>
{});
return
reset_step
;
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_step_idx
)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const
auto
adjusted_step_idx
=
SrcResetCoordinateAfterRun
?
src_slice_origin_step_idx
:
src_slice_origin_step_idx
+
GetCoordinateResetStep
();
// is it OK to construct a new step every time?
const
auto
adjusted_step
=
make_tensor_coordinate_step
(
src_desc
,
adjusted_step_idx
);
move_tensor_coordinate
(
src_desc
,
src_coord_
,
adjusted_step
);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__
void
MoveDstSliceWindow
(
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin_step_idx
)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const
auto
adjusted_step_idx
=
DstResetCoordinateAfterRun
?
dst_slice_origin_step_idx
:
dst_slice_origin_step_idx
+
GetCoordinateResetStep
();
// is it OK to construct a new step every time?
const
auto
adjusted_step
=
make_tensor_coordinate_step
(
dst_desc
,
adjusted_step_idx
);
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
adjusted_step
);
}
private:
SrcCoord
src_coord_
;
DstCoord
dst_coord_
;
const
ElementwiseOperation
element_op_
;
};
}
// namespace ck
include/ck/utility/magic_division.hpp
View file @
1a2a1679
...
@@ -157,4 +157,42 @@ struct MagicDivision
...
@@ -157,4 +157,42 @@ struct MagicDivision
}
}
};
};
struct
MDiv
{
// 1 dword -> 3 dword storage
uint32_t
divisor
;
uint32_t
multiplier
;
uint32_t
shift
;
// prefer construct on host
__host__
__device__
MDiv
(
uint32_t
divisor_
)
:
divisor
(
divisor_
)
{
ck
::
tie
(
multiplier
,
shift
)
=
MagicDivision
::
CalculateMagicNumbers
(
divisor_
);
}
__host__
__device__
MDiv
()
:
divisor
(
0
),
multiplier
(
0
),
shift
(
0
)
{}
__host__
__device__
void
update
(
uint32_t
divisor_
)
{
divisor
=
divisor_
;
ck
::
tie
(
multiplier
,
shift
)
=
MagicDivision
::
CalculateMagicNumbers
(
divisor_
);
}
__host__
__device__
uint32_t
div
(
uint32_t
dividend
)
const
{
return
MagicDivision
::
DoMagicDivision
(
dividend
,
multiplier
,
shift
);
}
__host__
__device__
void
divmod
(
uint32_t
dividend
,
uint32_t
&
quotient
,
uint32_t
&
remainder
)
const
{
quotient
=
div
(
dividend
);
remainder
=
dividend
-
(
quotient
*
divisor
);
}
__host__
__device__
uint32_t
operator
/
(
uint32_t
dividend
)
const
{
return
div
(
dividend
);
}
__host__
__device__
uint32_t
get
()
const
{
return
divisor
;
}
};
}
// namespace ck
}
// namespace ck
library/include/ck/library/utility/check_err.hpp
View file @
1a2a1679
...
@@ -55,7 +55,7 @@ check_err(const Range& out,
...
@@ -55,7 +55,7 @@ check_err(const Range& out,
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
>
0
)
{
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
@@ -102,7 +102,7 @@ check_err(const Range& out,
...
@@ -102,7 +102,7 @@ check_err(const Range& out,
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
>
0
)
{
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
@@ -148,7 +148,7 @@ check_err(const Range& out,
...
@@ -148,7 +148,7 @@ check_err(const Range& out,
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
>
0
)
{
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
@@ -199,7 +199,7 @@ check_err(const Range& out,
...
@@ -199,7 +199,7 @@ check_err(const Range& out,
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
>
0
)
{
{
std
::
cerr
<<
msg
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
std
::
cerr
<<
msg
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
std
::
endl
;
...
...
test/block_swizzle_test/block_swizzle_test.cpp
0 → 100644
View file @
1a2a1679
#include <stdio.h>
#include <string>
#include <algorithm>
#include <vector>
#include <limits>
#include "simple_args.h"
simple_args_t
create_arg
(
int
argc
,
char
**
argv
)
{
simple_args_t
args
;
args
.
insert
(
"m"
,
"1024"
,
"matrix m"
)
.
insert
(
"n"
,
"1024"
,
"matrix n"
)
.
insert
(
"k"
,
"1024"
,
"matrix k"
)
.
insert
(
"m_per_block"
,
"128"
,
"m_per_block"
)
.
insert
(
"n_per_block"
,
"128"
,
"n_per_block"
)
.
insert
(
"k_per_block"
,
"32"
,
"k_per_block"
)
.
insert
(
"num_cu"
,
"104"
,
"num cu"
)
.
insert
(
"occupancy"
,
"2"
,
"occupancy"
)
.
parse
(
argc
,
argv
);
return
args
;
}
namespace
impl
{
template
<
typename
T
>
T
integer_divide_ceil
(
T
n
,
T
d
)
{
return
(
n
+
d
-
1
)
/
d
;
}
template
<
typename
T
>
T
min
(
T
a
,
T
b
)
{
return
a
>
b
?
b
:
a
;
}
template
<
typename
T
>
T
max
(
T
a
,
T
b
)
{
return
a
>
b
?
a
:
b
;
}
}
// namespace impl
struct
block_dispatcher_t
{
public:
uint32_t
m_per_block
;
uint32_t
n_per_block
;
uint32_t
k_per_block
;
uint32_t
num_cu
;
uint32_t
occupancy
;
uint32_t
m
;
uint32_t
n
;
uint32_t
k
;
//--------------------------------------
uint32_t
sk_num_blocks
;
uint32_t
sk_num_big_blocks
;
uint32_t
sk_total_iters
;
// uint32_t sk_num_blocks_per_tile; // how many
uint32_t
dp_start_block_idx
;
uint32_t
dp_iters_per_block
;
uint32_t
dp_num_blocks
;
uint32_t
k_iters_per_tile
;
uint32_t
k_iters_per_big_block
;
//--------------------------------------
static
constexpr
uint32_t
min_k_iters_per_sk_block
=
1
;
void
dump
()
{
printf
(
"%dx%dx%d(%dx%dx%d), cu:%d, occ:%d, grids:%d, sk_num_big_blocks:%d, "
"sk_num_blocks:%d, sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, "
"dp_num_blocks:%d, k_iters_per_tile:%d, k_iters_per_big_block:%d
\n
"
,
m
,
n
,
k
,
m_per_block
,
n_per_block
,
k_per_block
,
num_cu
,
occupancy
,
get_grid_dims_x
(),
sk_num_big_blocks
,
sk_num_blocks
,
sk_total_iters
,
dp_start_block_idx
,
dp_iters_per_block
,
dp_num_blocks
,
k_iters_per_tile
,
k_iters_per_big_block
);
}
block_dispatcher_t
(
uint32_t
m_per_block_
,
uint32_t
n_per_block_
,
uint32_t
k_per_block_
,
uint32_t
num_cu_
,
uint32_t
occupancy_
,
uint32_t
m_
,
uint32_t
n_
,
uint32_t
k_
)
:
m_per_block
(
m_per_block_
),
n_per_block
(
n_per_block_
),
k_per_block
(
k_per_block_
),
num_cu
(
num_cu_
),
occupancy
(
occupancy_
),
m
(
m_
),
n
(
n_
),
k
(
k_
)
{
init
();
}
uint32_t
get_grid_dims_x
()
{
return
dp_start_block_idx
+
dp_num_blocks
;
}
uint32_t
get_block_idx
(
uint32_t
bid
)
{
// block id is linearily allocated along sk blocks (dp blocks are fine)
// this function will compute blockIdx.x and the linear sk block mapping
// uint32_t block_idx = 0;
// if(bid < sk_num_big_blocks) {
// uint32_t current_k_iter = bid * k_iters_per_big_block;
// tile_idx = current_k_iter / k_iters_per_tile;
// }
return
bid
;
}
uint32_t
get_current_itr
(
uint32_t
block_idx
)
{
uint32_t
current_itr
=
0
;
if
(
block_idx
<
sk_num_big_blocks
)
{
current_itr
=
block_idx
*
k_iters_per_big_block
;
}
else
if
(
block_idx
<
sk_num_blocks
)
{
current_itr
=
(
sk_num_big_blocks
*
k_iters_per_big_block
)
+
(
block_idx
-
sk_num_big_blocks
)
*
(
k_iters_per_big_block
-
1
);
}
else
if
(
block_idx
>=
dp_start_block_idx
)
{
current_itr
=
sk_total_iters
+
(
block_idx
-
dp_start_block_idx
)
*
dp_iters_per_block
;
}
return
current_itr
;
}
void
get_block_itr
(
uint32_t
block_idx
,
uint32_t
&
iter_start
,
uint32_t
&
iter_end
)
{
if
(
block_idx
<
sk_num_big_blocks
)
{
iter_start
=
block_idx
*
k_iters_per_big_block
;
iter_end
=
iter_start
+
k_iters_per_big_block
;
}
else
if
(
block_idx
<
sk_num_blocks
)
{
iter_start
=
(
sk_num_big_blocks
*
k_iters_per_big_block
)
+
(
block_idx
-
sk_num_big_blocks
)
*
(
k_iters_per_big_block
-
1
);
iter_end
=
iter_start
+
(
k_iters_per_big_block
-
1
);
}
else
if
(
block_idx
>=
dp_start_block_idx
)
{
iter_start
=
sk_total_iters
+
(
block_idx
-
dp_start_block_idx
)
*
dp_iters_per_block
;
iter_end
=
iter_start
+
dp_iters_per_block
;
}
}
private:
void
init
()
{
uint32_t
num_tiles
=
impl
::
integer_divide_ceil
(
m
,
m_per_block
)
*
impl
::
integer_divide_ceil
(
n
,
n_per_block
);
k_iters_per_tile
=
impl
::
integer_divide_ceil
(
k
,
k_per_block
);
// one cu can hold one wg at one time, from the whole chip's point of view
// if number of wg is same as num_cu, we call it 1 dispatch
// if number of wg is 2x num_cu, we call it 2 dispatches.
// one dispatch can deliever wg same as num_cu (full dispatch), or less than num_cu (partial
// dispatch)
//
uint32_t
full_dispatches
=
num_tiles
/
num_cu
;
uint32_t
full_dispatch_tiles
=
full_dispatches
*
num_cu
;
uint32_t
partial_dispatche_tiles
=
num_tiles
-
full_dispatch_tiles
;
uint32_t
sk_occupancy
=
occupancy
;
uint32_t
dp_tiles
=
full_dispatch_tiles
;
uint32_t
sk_tiles
=
partial_dispatche_tiles
;
if
(
full_dispatches
<
occupancy
)
{
// in this case, we allocate all blocks as sk blocks
// sk_occupancy = occupancy - full_dispatches;
sk_occupancy
=
1
;
// TODO: single occ seems better
dp_tiles
=
full_dispatch_tiles
;
sk_tiles
=
partial_dispatche_tiles
;
}
else
if
((
occupancy
>
1
)
&&
(
full_dispatches
%
occupancy
==
occupancy
-
1
))
{
// e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
// occupancy = 3, full_dispatches = 5, 8, 11 ...
// occupancy = 4, full_dispatches = 7, 11 ...
sk_occupancy
=
1
;
// left 1 slot for sk occupancy
dp_tiles
=
full_dispatch_tiles
;
sk_tiles
=
partial_dispatche_tiles
;
}
else
{
// others, we reduce 1 dispatch from dp, together with partial dispatch,
// to construct sk dispatch
sk_occupancy
=
occupancy
-
((
full_dispatches
-
1
)
%
occupancy
);
dp_tiles
=
full_dispatch_tiles
-
num_cu
;
sk_tiles
=
partial_dispatche_tiles
+
num_cu
;
}
// dp_num_blocks = dp_tiles;
// dp_start_block_idx = num_cu * sk_occupancy;
dp_iters_per_block
=
k_iters_per_tile
;
sk_total_iters
=
k_iters_per_tile
*
sk_tiles
;
// printf("num_tiles:%d, full_dispatches:%d, full_dispatch_tiles:%d,
// partial_dispatche_tiles:%d\n",
// num_tiles, full_dispatches, full_dispatch_tiles, partial_dispatche_tiles);
{
uint32_t
min_sk_tiles
=
(
sk_tiles
>=
num_cu
)
?
num_cu
:
(
sk_tiles
+
1
);
uint32_t
max_sk_tiles
=
(
sk_tiles
>=
num_cu
)
?
num_cu
*
sk_occupancy
:
impl
::
min
(
num_cu
,
sk_total_iters
/
min_k_iters_per_sk_block
);
// if use dp for sk-block, how many iters do we need
uint32_t
dp_for_sk_iters
=
k_iters_per_tile
;
uint32_t
best_sk_score
=
std
::
numeric_limits
<
int
>::
max
();
// we need to find the smallest sk iters
for
(
uint32_t
tentative_sk_blocks
=
min_sk_tiles
;
tentative_sk_blocks
<
max_sk_tiles
;
tentative_sk_blocks
++
)
{
uint32_t
tentative_sk_iters_per_block
=
(
sk_total_iters
+
tentative_sk_blocks
-
1
)
/
tentative_sk_blocks
;
uint32_t
tentative_sk_iters
=
tentative_sk_iters_per_block
;
uint32_t
sk_blocks_per_tile
=
(
tentative_sk_blocks
+
sk_tiles
-
1
)
/
sk_tiles
;
// TODO: carefully adjust this parameter
// the more sk_blocks_per_tile, the worse the overhead
uint32_t
cross_sk_blocks_overhead
=
sk_blocks_per_tile
;
if
(
tentative_sk_blocks
%
sk_tiles
!=
0
)
{
// penalty for uneven divide
cross_sk_blocks_overhead
+=
sk_blocks_per_tile
*
tentative_sk_iters_per_block
/
50
;
}
uint32_t
tentative_sk_score
=
tentative_sk_iters
+
cross_sk_blocks_overhead
;
if
(
tentative_sk_score
<
best_sk_score
)
{
best_sk_score
=
tentative_sk_score
;
sk_num_blocks
=
tentative_sk_blocks
;
}
}
if
(
best_sk_score
>=
dp_for_sk_iters
)
{
sk_num_blocks
=
0
;
}
if
(
sk_num_blocks
==
0
)
{
sk_num_big_blocks
=
0
;
k_iters_per_big_block
=
0
;
dp_num_blocks
=
num_tiles
;
// all tile to be dp block
dp_start_block_idx
=
0
;
sk_total_iters
=
0
;
// clear this tiles
}
else
{
uint32_t
k_iters_per_sk_block
=
sk_total_iters
/
sk_num_blocks
;
sk_num_big_blocks
=
sk_total_iters
-
k_iters_per_sk_block
*
sk_num_blocks
;
k_iters_per_big_block
=
k_iters_per_sk_block
+
1
;
dp_num_blocks
=
dp_tiles
;
dp_start_block_idx
=
(
sk_num_blocks
+
num_cu
-
1
)
/
num_cu
*
num_cu
;
}
}
}
};
struct
tile_work_t
{
uint32_t
tile_idx
;
uint32_t
iter_begin
;
uint32_t
k_begin
;
uint32_t
k_end
;
uint32_t
k_iters_remaining
;
};
int
main
(
int
argc
,
char
**
argv
)
{
simple_args_t
arg
=
create_arg
(
argc
,
argv
);
block_dispatcher_t
block_dispatcher
{
arg
.
get_uint32
(
"m_per_block"
),
arg
.
get_uint32
(
"n_per_block"
),
arg
.
get_uint32
(
"k_per_block"
),
arg
.
get_uint32
(
"num_cu"
),
arg
.
get_uint32
(
"occupancy"
),
arg
.
get_uint32
(
"m"
),
arg
.
get_uint32
(
"n"
),
arg
.
get_uint32
(
"k"
)};
block_dispatcher
.
dump
();
// simulate actual kernel launch
uint32_t
dim_x
=
block_dispatcher
.
get_grid_dims_x
();
uint32_t
total_k_iters
=
impl
::
integer_divide_ceil
(
arg
.
get_uint32
(
"k"
),
arg
.
get_uint32
(
"k_per_block"
));
uint32_t
num_tiles
=
impl
::
integer_divide_ceil
(
arg
.
get_uint32
(
"m"
),
arg
.
get_uint32
(
"m_per_block"
))
*
impl
::
integer_divide_ceil
(
arg
.
get_uint32
(
"n"
),
arg
.
get_uint32
(
"n_per_block"
));
std
::
vector
<
int
>
valid_tile_record
(
num_tiles
*
total_k_iters
);
for
(
uint32_t
bid
=
0
;
bid
<
dim_x
;
bid
++
)
{
uint32_t
block_idx
=
block_dispatcher
.
get_block_idx
(
bid
);
bool
is_sk_block
=
block_idx
<
(
block_dispatcher
.
sk_num_blocks
);
bool
is_dp_block
=
block_idx
>=
block_dispatcher
.
dp_start_block_idx
;
uint32_t
iter_start
,
iter_end
;
block_dispatcher
.
get_block_itr
(
block_idx
,
iter_start
,
iter_end
);
uint32_t
total_iter_length
=
iter_end
-
iter_start
;
while
(
true
)
{
uint32_t
iter_length_mod
=
iter_end
%
block_dispatcher
.
k_iters_per_tile
;
uint32_t
current_iter_length
=
impl
::
min
(
iter_length_mod
==
0
?
(
iter_end
-
iter_start
)
:
iter_length_mod
,
total_iter_length
);
uint32_t
tile_idx
=
(
iter_end
-
1
)
/
block_dispatcher
.
k_iters_per_tile
;
uint32_t
tile_iter_start
=
((
iter_end
-
1
)
%
block_dispatcher
.
k_iters_per_tile
)
-
current_iter_length
+
1
;
if
(
is_sk_block
)
{
printf
(
"[sk_block] bid:%3d, block_idx:%3d, tile_idx:%3d, iter_start:%d(%d | %d), "
"iter_end:%d (len:%d)
\n
"
,
bid
,
block_idx
,
tile_idx
,
iter_end
-
current_iter_length
,
tile_iter_start
,
iter_start
,
iter_end
,
current_iter_length
);
}
else
if
(
is_dp_block
)
{
printf
(
"[dp_block] bid:%3d, block_idx:%3d, tile_idx:%3d, iter_start:%d(%d | %d), "
"iter_end:%d (len:%d)
\n
"
,
bid
,
block_idx
,
tile_idx
,
iter_end
-
current_iter_length
,
tile_iter_start
,
iter_start
,
iter_end
,
current_iter_length
);
}
else
{
printf
(
"[other ] bid:%3d, block_idx:%3d
\n
"
,
bid
,
block_idx
);
}
// some validation check
for
(
auto
i
=
iter_end
-
current_iter_length
;
i
<
iter_end
;
i
++
)
{
if
(
i
>=
valid_tile_record
.
size
())
{
printf
(
"unexpected, current iter:%d larger than max:%d
\n
"
,
i
,
valid_tile_record
.
size
());
return
-
1
;
}
valid_tile_record
[
i
]
=
1
;
}
iter_end
-=
current_iter_length
;
if
(
iter_end
<=
iter_start
)
break
;
}
}
int
untouched
=
0
;
for
(
auto
i
=
0
;
i
<
valid_tile_record
.
size
();
i
++
)
{
if
(
valid_tile_record
[
i
]
!=
1
)
{
printf
(
"untouched at %d (%d)
\n
"
,
i
,
valid_tile_record
.
size
());
untouched
++
;
}
}
printf
(
"untouched %d/%d, %s
\n
"
,
untouched
,
valid_tile_record
.
size
(),
untouched
==
0
?
"valid"
:
"fail"
);
}
test/block_swizzle_test/rebuild.sh
0 → 100644
View file @
1a2a1679
CC
=
g++
$CC
-Wall
-std
=
c++17
-Iinclude
-O3
block_swizzle_test.cpp
-o
block_swizzle_test.exe
\ No newline at end of file
test/block_swizzle_test/simple_args.h
0 → 100644
View file @
1a2a1679
#pragma once
#include <iomanip>
#include <iostream>
#include <stdlib.h>
#include <string>
#include <unordered_map>
#include <vector>
#include <assert.h>
struct
arg_content_t
{
std
::
string
name
;
// key
std
::
string
value
;
std
::
string
help_text
;
};
class
simple_args_t
{
public:
simple_args_t
()
{}
simple_args_t
&
insert
(
const
std
::
string
&
name_
,
const
std
::
string
&
default_value_
,
const
std
::
string
&
help_text_
)
{
arg_content_t
arg
{
name_
,
default_value_
,
help_text_
};
if
(
arg_map
.
count
(
arg
.
name
)
!=
0
)
{
std
::
cout
<<
"arg:"
<<
arg
.
name
<<
"already exist"
<<
std
::
endl
;
}
else
{
arg_map
[
arg
.
name
]
=
arg
;
}
return
*
this
;
}
void
usage
()
{
for
(
auto
&
content
:
arg_map
)
{
std
::
vector
<
std
::
string
>
help_text_lines
;
size_t
pos
=
0
;
for
(
size_t
next_pos
=
content
.
second
.
help_text
.
find
(
'\n'
,
pos
);
next_pos
!=
std
::
string
::
npos
;)
{
help_text_lines
.
push_back
(
std
::
string
(
content
.
second
.
help_text
.
begin
()
+
pos
,
content
.
second
.
help_text
.
begin
()
+
next_pos
++
));
pos
=
next_pos
;
next_pos
=
content
.
second
.
help_text
.
find
(
'\n'
,
pos
);
}
help_text_lines
.
push_back
(
std
::
string
(
content
.
second
.
help_text
.
begin
()
+
pos
,
content
.
second
.
help_text
.
end
()));
int
arg_name_width
=
16
-
content
.
second
.
name
.
length
();
arg_name_width
=
arg_name_width
>
0
?
arg_name_width
:
2
;
std
::
cout
<<
std
::
setw
(
4
)
<<
"-"
<<
content
.
second
.
name
<<
std
::
setw
(
arg_name_width
)
<<
" "
<<
help_text_lines
[
0
]
<<
std
::
endl
;
for
(
auto
help_next_line
=
std
::
next
(
help_text_lines
.
begin
());
help_next_line
!=
help_text_lines
.
end
();
++
help_next_line
)
{
std
::
cout
<<
std
::
setw
(
28
)
<<
" "
<<
*
help_next_line
<<
std
::
endl
;
}
}
}
bool
parse
(
int
argc
,
char
*
argv
[],
int
start_index
=
1
)
{
if
(
argc
<=
start_index
)
{
// std::cout << "not enough args (" << argc << ") with starting index " << start_index
// << std::endl;
return
true
;
}
for
(
int
i
=
start_index
;
i
<
argc
;
i
++
)
{
std
::
string
cur_arg
=
std
::
string
(
argv
[
i
]);
if
(
cur_arg
[
0
]
!=
'-'
)
{
std
::
cout
<<
"illegal input"
<<
std
::
endl
;
usage
();
return
false
;
}
else
if
(
cur_arg
[
0
]
==
'-'
&&
cur_arg
[
1
]
==
'?'
)
{
usage
();
return
false
;
}
else
{
size_t
found_equal
=
cur_arg
.
find
(
'='
);
if
(
found_equal
==
std
::
string
::
npos
||
found_equal
==
(
cur_arg
.
length
()
-
1
))
{
std
::
cout
<<
"failed while parsing
\"
"
<<
cur_arg
<<
"
\"
, "
<<
"arg must be in the form
\"
-name=value
\"
"
<<
std
::
endl
;
return
false
;
}
std
::
string
arg_name
=
cur_arg
.
substr
(
1
,
found_equal
-
1
);
std
::
string
arg_value
=
cur_arg
.
substr
(
found_equal
+
1
);
if
(
arg_map
.
count
(
arg_name
)
==
0
)
{
std
::
cout
<<
"no such arg
\"
"
<<
arg_name
<<
"
\"
registered"
<<
std
::
endl
;
return
false
;
}
arg_map
[
arg_name
].
value
=
arg_value
;
}
}
return
true
;
}
std
::
string
get
(
const
std
::
string
&
name
)
const
{
return
get_str
(
name
);
}
std
::
string
get_str
(
const
std
::
string
&
name
)
const
{
assert
(
arg_map
.
count
(
name
)
!=
0
);
std
::
string
value
=
arg_map
.
at
(
name
).
value
;
return
value
;
}
int
get_int
(
const
std
::
string
&
name
)
const
{
assert
(
arg_map
.
count
(
name
)
!=
0
);
int
value
=
atoi
(
arg_map
.
at
(
name
).
value
.
c_str
());
return
value
;
}
uint32_t
get_uint32
(
const
std
::
string
&
name
)
const
{
assert
(
arg_map
.
count
(
name
)
!=
0
);
uint32_t
value
=
strtoul
(
arg_map
.
at
(
name
).
value
.
c_str
(),
nullptr
,
10
);
return
value
;
}
uint64_t
get_uint64
(
const
std
::
string
&
name
)
const
{
assert
(
arg_map
.
count
(
name
)
!=
0
);
uint64_t
value
=
strtoull
(
arg_map
.
at
(
name
).
value
.
c_str
(),
nullptr
,
10
);
return
value
;
}
double
get_double
(
const
std
::
string
&
name
)
const
{
assert
(
arg_map
.
count
(
name
)
!=
0
);
double
value
=
atof
(
arg_map
.
at
(
name
).
value
.
c_str
());
return
value
;
}
float
get_float
(
const
std
::
string
&
name
)
const
{
assert
(
arg_map
.
count
(
name
)
!=
0
);
float
value
=
atof
(
arg_map
.
at
(
name
).
value
.
c_str
());
return
value
;
}
private:
std
::
unordered_map
<
std
::
string
,
arg_content_t
>
arg_map
;
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment