gaoqiong / composable_kernel / Commits

Commit 1a2a1679, authored May 08, 2023 by carlushuang
Parent: 8bb2bb4a

    initial stream-k implementation with example

Showing 16 changed files with 2500 additions and 7 deletions (+2500 / -7)
Changed files:

    example/01_gemm/CMakeLists.txt                                                         +1    -0
    example/01_gemm/gemm_xdl_fp16.cpp                                                      +3    -1
    example/01_gemm/gemm_xdl_streamk.cpp                                                   +41   -0
    example/01_gemm/run_gemm_example.inc                                                   +25   -2
    include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp      +15   -0
    include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp    +134  -0
    include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp                +280  -0
    include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp                            +233  -0
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp                     +89   -0
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp                  +856  -0
    include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp     +213  -0
    include/ck/utility/magic_division.hpp                                                  +38   -0
    library/include/ck/library/utility/check_err.hpp                                       +4    -4
    test/block_swizzle_test/block_swizzle_test.cpp                                         +406  -0
    test/block_swizzle_test/rebuild.sh                                                     +3    -0
    test/block_swizzle_test/simple_args.h                                                  +159  -0
example/01_gemm/CMakeLists.txt

@@ -44,3 +44,4 @@ if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS
     add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
 endif()
+add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
example/01_gemm/gemm_xdl_fp16.cpp

@@ -39,7 +39,9 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl
 // ######|        |        |        |     Type|     Type|     Type|     Type|        DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch|  Size| Block| Block| Block|    |    |  XDL|  XDL|  Per|  Per| ThreadCluster|  ThreadCluster| SrcAccessOrder|   SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster|  ThreadCluster| SrcAccessOrder|  SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 // ######|        |        |        |         |         |         |         |                |   Operation|   Operation|   Operation|               |    Stage|      |      |      |      |    |    |     |     | Wave| Wave| Lengths_K0_M_K1|   ArrangeOrder|               |              | PerVector| PerVector_K1|         | Lengths_K0_N_K1|   ArrangeOrder|               |              | PerVector| PerVector_K1|         |  PerShuffle|  PerShuffle| _NBlock_NWaveNPerXdl|   _NWaveNPerXdl|
 // ######|        |        |        |         |         |         |         |                |            |            |            |               |         |      |      |      |      |    |    |     |     |     |     |                |               |               |              |          |             |         |                |               |               |              |          |             |         |            |            |                     |                |
-        < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+        < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>;
+// DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
 // clang-format on

 using DeviceGemmInstance = DeviceGemmInstance1;
example/01_gemm/gemm_xdl_streamk.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp"

using ADataType        = ck::half_t;
using BDataType        = ck::half_t;
using AccDataType      = float;
using CShuffleDataType = float;
using CDataType        = ck::half_t;
using F16              = ck::half_t;

using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

// clang-format off
using DeviceGemmStreamK = ck::tensor_operation::device::DeviceGemmXdlStreamK
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout|           A|           B|           C| Block|  MPer|  NPer| K0Per| K1| MPer| NPer| MXdl| NXdl|  ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds|  BBlockTransfer| BBlockTransfer| BBlockTransfer|  BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|    CShuffle|    CShuffle| CBlockTransferClusterLengths|  CBlockTransfer|
// ######|  Type|  Type|  Type|    Type|        |        |        | Elementwise| Elementwise| Elementwise|  Size| Block| Block| Block|   |  XDL|  XDL|  Per|  Per|   ThreadCluster|  ThreadCluster| SrcAccessOrder|   SrcVectorDim|      SrcScalar|      DstScalar| AddExtraM|   ThreadCluster|  ThreadCluster| SrcAccessOrder|   SrcVectorDim|      SrcScalar|      DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave|         _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######|      |      |      |        |        |        |        |   Operation|   Operation|   Operation|      |      |      |      |   |     |     | Wave| Wave| Lengths_K0_M_K1|   ArrangeOrder|               |               |      PerVector|   PerVector_K1|          | Lengths_K0_N_K1|   ArrangeOrder|               |               |      PerVector|   PerVector_K1|          |  PerShuffle|  PerShuffle|         _NBlock_NWaveNPerXdl|   _NWaveNPerXdl|
// ######|      |      |      |        |        |        |        |            |            |            |      |      |      |      |   |     |     |     |     |                |               |               |               |               |               |          |                |               |               |               |               |               |          |            |            |                             |                |
        < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// // clang-format on
// clang-format on

using DeviceGemmInstance = DeviceGemmStreamK;

using ReferenceGemmInstance = ck::tensor_operation::host::
    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

#include "run_gemm_example.inc"

int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
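The example file above only defines the instance and type aliases and delegates everything else to run_gemm_example.inc. For readers who want to call the stream-k device op directly, the following is a minimal hedged sketch of the usual composable_kernel driver flow (MakeArgument / MakeInvoker / Run, all defined in device_gemm_xdl_streamk.hpp later in this commit). It assumes the aliases and includes from the example file above; the device pointers and problem sizes are hypothetical placeholders, not values from the commit.

// Hedged sketch, not part of the commit: driving DeviceGemmInstance directly.
// p_a, p_b, p_c are assumed to be device buffers that were already allocated
// and filled; M, N, K and the packed strides are placeholder values.
void run_streamk_sketch(const ADataType* p_a, const BDataType* p_b, CDataType* p_c)
{
    const ck::index_t M = 3840, N = 4096, K = 4096;
    const ck::index_t StrideA = K, StrideB = K, StrideC = N; // Row x Col -> Row, packed

    auto gemm    = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();

    // MakeArgument queries kernel occupancy and the CU count internally and
    // builds the stream-k block mapping from them (see the device op below).
    auto argument = gemm.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
                                      AElementOp{}, BElementOp{}, CElementOp{});

    if(gemm.IsSupportedArgument(argument))
    {
        // Run() clears the C buffer and launches kernel_gemm_xdlops_streamk.
        float ave_time = invoker.Run(argument);
        (void)ave_time;
    }
}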
example/01_gemm/run_gemm_example.inc

@@ -11,7 +11,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     using namespace ck::literals;

-    auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
+    auto [M, N, K, StrideA, StrideB, StrideC] = problem_size;

     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
@@ -25,12 +25,35 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         }
     };

+    auto f_get_default_stride =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(stride == 0)
+            {
+                // give a chance if stride is zero, return a defalt packed stride
+                if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
+                {
+                    return col;
+                }
+                else
+                {
+                    return row;
+                }
+            }
+            else
+                return stride;
+        };
+
+    StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
+    StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
+    StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
+
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));

     switch(config.init_method)
     {
-    case 0: break;
+    case 0:
+        ck::utils::FillConstant<ADataType>{1.f}(a_m_k);
+        ck::utils::FillConstant<BDataType>{1.f}(b_k_n);
+        break;
     case 1:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
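The new f_get_default_stride lambda lets a stride of 0 stand for "use the packed stride of this layout". A small hedged restatement of that rule as standalone host code, with a couple of worked values (the matrix sizes are made up for illustration):

// Hedged restatement, not in the commit, of the packed-stride default used by
// f_get_default_stride: for a row-major row x col matrix the packed leading
// stride is the column count, for column-major it is the row count.
#include <cstddef>
#include <cstdio>

std::size_t default_packed_stride(std::size_t row, std::size_t col, std::size_t stride, bool row_major)
{
    if(stride != 0)
        return stride;            // a caller-provided non-zero stride wins
    return row_major ? col : row; // otherwise fall back to the packed layout
}

int main()
{
    // e.g. a 3840 x 4096 row-major A passed with stride 0 becomes stride 4096
    std::printf("%zu\n", default_packed_stride(3840, 4096, 0, true));     // 4096
    std::printf("%zu\n", default_packed_stride(3840, 4096, 0, false));    // 3840
    std::printf("%zu\n", default_packed_stride(3840, 4096, 5000, true));  // 5000
}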
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp

@@ -94,6 +94,21 @@ struct ThreadGroupTensorSliceTransfer_v4r1
         }
     }

+    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin)
+    {
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
+        {
+            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
+                make_multi_index(ThreadGroup::GetThreadId()));
+
+            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
+
+            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
+                                                   src_block_slice_origin + thread_data_idx_begin);
+        }
+    }
+
     template <typename SrcBuffer, index_t ThreadScratchId = 0>
     __device__ void RunRead(const SrcDesc& src_desc,
                             const SrcBuffer& src_buf,
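The new SetSrcSliceOrigin places each thread at block_origin + thread_cluster_idx * thread_slice_lengths, which is what lets the stream-k kernel re-aim an existing blockwise copy at a different K offset each loop iteration. A hedged host-side illustration of that index arithmetic follows; the cluster shape, thread id and block origin are made-up numbers, and the plain row-major decomposition below merely stands in for thread_cluster_desc_.CalculateBottomIndex(), which also honors the ThreadClusterArrangeOrder.

// Hedged illustration, not in the commit: per-thread slice origin for a
// hypothetical K0 x M x K1 = 4 x 128 x 8 block slice distributed over a
// S<4, 64, 1> thread cluster, so thread_slice_lengths = {1, 2, 8}.
#include <array>
#include <cstdio>

int main()
{
    const std::array<int, 3> cluster_lengths{4, 64, 1};
    const std::array<int, 3> slice_lengths{4, 128, 8};
    std::array<int, 3> thread_slice_lengths{};
    for(int d = 0; d < 3; ++d)
        thread_slice_lengths[d] = slice_lengths[d] / cluster_lengths[d]; // {1, 2, 8}

    const int tid = 130; // hypothetical thread id inside the 256-thread cluster

    // plain row-major decomposition of tid into a (K0, M, K1) cluster index
    const std::array<int, 3> cluster_idx{tid / (cluster_lengths[1] * cluster_lengths[2]),
                                         (tid / cluster_lengths[2]) % cluster_lengths[1],
                                         tid % cluster_lengths[2]}; // {2, 2, 0}

    const std::array<int, 3> block_origin{0, 256, 0}; // hypothetical src_block_slice_origin
    for(int d = 0; d < 3; ++d)
        std::printf("dim %d: thread origin %d\n",
                    d, block_origin[d] + cluster_idx[d] * thread_slice_lengths[d]); // {2, 260, 0}
}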
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp"

namespace ck {

// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <typename ThreadGroup,
          typename ElementwiseOperation,
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct ThreadGroupTensorSliceTransfer_v6r1r2
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1r2(const SrcDesc& src_desc,
                                                               const Index& src_block_slice_origin,
                                                               const DstDesc& dst_desc,
                                                               const Index& dst_block_slice_origin,
                                                               const ElementwiseOperation& element_op)
        : threadwise_transfer_(
              src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>(), element_op)
    {
        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename SrcBuffer, typename DstBuffer, InMemoryDataOperationEnum DstInMemOp>
    __device__ void
    Run(const SrcDesc& src_desc, const SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.template Run<SrcBuffer, DstBuffer, DstInMemOp>(
                src_desc, src_buf, dst_desc, dst_buf);
        }
    }

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v6r1r2<SrcData,
                                             DstData,
                                             SrcDesc,
                                             DstDesc,
                                             ElementwiseOperation,
                                             decltype(thread_slice_lengths),
                                             DimAccessOrder,
                                             VectorDim,
                                             ScalarPerVector,
                                             ThreadTransferSrcResetCoordinateAfterRun,
                                             ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t K1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          ck::index_t ABlockLdsAddExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          ck::index_t BBlockLdsAddExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CBlockTransferScalarPerVector_NWaveNPerXDL>
struct DeviceGemmXdlStreamK : public DeviceGemm<ALayout,
                                                BLayout,
                                                CLayout,
                                                ADataType,
                                                BDataType,
                                                CDataType,
                                                AElementwiseOperation,
                                                BElementwiseOperation,
                                                CElementwiseOperation>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk<
        BlockSize,
        BlockToCTileMap_GemmStreamK<MPerBlock, NPerBlock, K0PerBlock * K1>,
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CDataType,
        ALayout,
        BLayout,
        CLayout,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerXDL,
        NPerXDL,
        K1,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsAddExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CBlockTransferScalarPerVector_NWaveNPerXDL,
        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;

    using Argument = typename GridwiseGemm::Argument;

    // Invoker
    struct Invoker : public BaseInvoker
    {
        void Print(const Argument& karg) { karg.Print(); }

        float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                Print(karg);
            }

            if(!GridwiseGemm::CheckValidity(karg))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
                    "setting");
            }

            dim3 grid_dims = karg.block_mapping.get_grid_dims();

            float ave_time = 0;

            const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;

            // TODO: remove clear buffer for streamk kernels
            hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
            ave_time =
                launch_and_time_kernel(stream_config, kernel, grid_dims, dim3(BlockSize), 0, karg);

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& karg)
    {
        return GridwiseGemm::CheckValidity(karg);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             AElementwiseOperation,
                             BElementwiseOperation,
                             CElementwiseOperation)
    {
        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
        int occupancy, num_cu;
        hipError_t rtn;
        rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
        hip_check_error(rtn);

        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
        rtn = hipGetDevice(&dev);
        hip_check_error(rtn);
        rtn = hipGetDeviceProperties(&dev_prop, dev);
        hip_check_error(rtn);
        num_cu = dev_prop.multiProcessorCount;
        printf("XXX occupancy:%d, num_cu:%d, BLOCK_SIZE:%d, LDS:%d\n",
               occupancy,
               num_cu,
               BlockSize,
               GridwiseGemm::GetSharedMemoryNumberOfByte());

        return Argument{p_a,
                        p_b,
                        p_c,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideC,
                        static_cast<uint32_t>(num_cu),
                        static_cast<uint32_t>(occupancy)};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation) override
    {
        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
        int occupancy, num_cu;
        hipError_t rtn;
        rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
        hip_check_error(rtn);

        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
        rtn = hipGetDevice(&dev);
        hip_check_error(rtn);
        rtn = hipGetDeviceProperties(&dev_prop, dev);
        hip_check_error(rtn);
        num_cu = dev_prop.multiProcessorCount;

        return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
                                          reinterpret_cast<const BDataType*>(p_b),
                                          reinterpret_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          static_cast<uint32_t>(num_cu),
                                          static_cast<uint32_t>(occupancy));
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

@@ -7,6 +7,7 @@
 #include "ck/utility/number.hpp"
 #include "ck/tensor_description/tensor_adaptor.hpp"
 #include "ck/tensor_description/multi_index_transform_helper.hpp"
+#include <limits>

 namespace ck {

@@ -635,4 +636,236 @@ struct BlockToCTileMap_3DGrid_KSplit
     }
 };

The remainder of the hunk adds the new BlockToCTileMap_GemmStreamK struct:

template <uint32_t MPerBlock_, uint32_t NPerBlock_, uint32_t KPerBlock_>
struct BlockToCTileMap_GemmStreamK
{
    static constexpr uint32_t min_k_iters_per_sk_block = 2;
    static constexpr uint32_t MPerBlock                = MPerBlock_;
    static constexpr uint32_t NPerBlock                = NPerBlock_;
    static constexpr uint32_t KPerBlock                = KPerBlock_;

    //--------------------------------------
    // pass to device
    uint32_t sk_num_blocks;
    uint32_t sk_num_big_blocks;
    uint32_t sk_total_iters;
    uint32_t dp_start_block_idx;
    uint32_t dp_iters_per_block;
    uint32_t dp_num_blocks;
    uint32_t k_iters_per_big_block;
    MDiv k_iters_per_tile;
    MDiv n_tiles;
    //--------------------------------------

    // prefer construct on host
    BlockToCTileMap_GemmStreamK(uint32_t m, uint32_t n, uint32_t k, uint32_t num_cu, uint32_t occupancy)
    {
        uint32_t num_tiles =
            math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
        k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));

        // one cu can hold one wg at one time, from the whole chip's point of view
        // if number of wg is same as num_cu, we call it 1 dispatch
        // if number of wg is 2x num_cu, we call it 2 dispatches.
        // one dispatch can deliever wg same as num_cu (full dispatch), or less than num_cu (partial
        // dispatch)
        //
        uint32_t full_dispatches         = num_tiles / num_cu;
        uint32_t full_dispatch_tiles     = full_dispatches * num_cu;
        uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;

        uint32_t sk_occupancy = occupancy;
        uint32_t dp_tiles     = full_dispatch_tiles;
        uint32_t sk_tiles     = partial_dispatche_tiles;

        if(full_dispatches < occupancy)
        {
            // in this case, we allocate all blocks as sk blocks
            // sk_occupancy = occupancy - full_dispatches;
            sk_occupancy = 1; // TODO: single occ seems better
            dp_tiles     = full_dispatch_tiles;
            sk_tiles     = partial_dispatche_tiles;
        }
        else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
        {
            // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
            //      occupancy = 3, full_dispatches = 5, 8, 11 ...
            //      occupancy = 4, full_dispatches = 7, 11 ...
            sk_occupancy = 1; // left 1 slot for sk occupancy
            dp_tiles     = full_dispatch_tiles;
            sk_tiles     = partial_dispatche_tiles;
        }
        else
        {
            // others, we reduce 1 dispatch from dp, together with partial dispatch,
            // to construct sk dispatch
            sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
            dp_tiles     = full_dispatch_tiles - num_cu;
            sk_tiles     = partial_dispatche_tiles + num_cu;
        }

        // dp_num_blocks = dp_tiles;
        // dp_start_block_idx = num_cu * sk_occupancy;
        dp_iters_per_block = k_iters_per_tile.get();
        sk_total_iters     = k_iters_per_tile.get() * sk_tiles;

        // printf("num_tiles:%d, full_dispatches:%d, full_dispatch_tiles:%d,
        // partial_dispatche_tiles:%d\n",
        //     num_tiles, full_dispatches, full_dispatch_tiles, partial_dispatche_tiles);

        {
            uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
            uint32_t max_sk_tiles =
                (sk_tiles >= num_cu) ? num_cu * sk_occupancy
                                     : math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);

            // if use dp for sk-block, how many iters do we need
            uint32_t dp_for_sk_iters = k_iters_per_tile.get();

            uint32_t best_sk_score = std::numeric_limits<int>::max();
            // we need to find the smallest sk iters
            for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
                tentative_sk_blocks++)
            {
                uint32_t tentative_sk_iters_per_block =
                    (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
                uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
                uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;

                // TODO: carefully adjust this parameter
                //       the more sk_blocks_per_tile, the worse the overhead
                uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
                if(tentative_sk_blocks % sk_tiles != 0)
                {
                    // penalty for uneven divide
                    cross_sk_blocks_overhead +=
                        sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
                }

                uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;

                if(tentative_sk_score < best_sk_score)
                {
                    best_sk_score = tentative_sk_score;
                    sk_num_blocks = tentative_sk_blocks;
                }
            }

            if(best_sk_score >= dp_for_sk_iters)
            {
                sk_num_blocks = 0;
            }

            if(sk_num_blocks == 0)
            {
                sk_num_big_blocks     = 0;
                k_iters_per_big_block = 0;

                dp_num_blocks      = num_tiles; // all tile to be dp block
                dp_start_block_idx = 0;
                sk_total_iters     = 0; // clear this tiles
            }
            else
            {
                // k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
                // we need to decide how many iters for each sk block
                // let m = k_iters_per_sk_block
                // some of the sk block (little) will cover m iters, some (big) will cover m+1
                // we have
                //  1) l + b = sk_blocks
                //  2) l * m + b * (m + 1) = sk_total_iters
                //      => (l + b) * m + b = sk_total_iters
                //      => sk_blocks * m + b = sk_total_iters
                //      => b = sk_total_iters - m * sk_blocks
                // NOTE: big could be zero
                uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
                sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
                k_iters_per_big_block = k_iters_per_sk_block + 1;

                dp_num_blocks      = dp_tiles;
                dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
            }
        }

        n_tiles = MDiv(math::integer_divide_ceil(n, NPerBlock));

        printf("cu:%d, occupancy:%d, grids:%d, sk_num_big_blocks:%d, sk_num_blocks:%d, "
               "sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
               "k_iters_per_tile:%d, k_iters_per_big_block:%d\n",
               num_cu,
               occupancy,
               get_grid_dims().x,
               sk_num_big_blocks,
               sk_num_blocks,
               sk_total_iters,
               dp_start_block_idx,
               dp_iters_per_block,
               dp_num_blocks,
               k_iters_per_tile.get(),
               k_iters_per_big_block);
    }

    __host__ __device__ dim3 get_grid_dims() const
    {
        return dim3(dp_start_block_idx + dp_num_blocks, 1, 1);
    }

    __device__ uint32_t get_block_idx() const
    {
        // TODO: swizzle block index for better locality
        return __builtin_amdgcn_readfirstlane(blockIdx.x);
    }

    __device__ void get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
    {
        if(block_idx < sk_num_big_blocks)
        {
            iter_start = block_idx * k_iters_per_big_block;
            iter_end   = iter_start + k_iters_per_big_block;
        }
        else if(block_idx < sk_num_blocks)
        {
            iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
                         (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
            iter_end = iter_start + (k_iters_per_big_block - 1);
        }
        else if(block_idx >= dp_start_block_idx)
        {
            iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
            iter_end   = iter_start + dp_iters_per_block;
        }
    }

    __device__ uint32_t get_current_iter_length(uint32_t iter_start,
                                                uint32_t iter_end,
                                                uint32_t total_iter_length) const
    {
        uint32_t iter_length_mod, iter_length_quo /*unused*/;
        k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
        uint32_t current_iter_length = math::min(
            iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
        return current_iter_length;
    }

    __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }

    __device__ void
    get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
    {
        k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
    }

    __device__ auto tile_to_spatial(uint32_t tile_idx) const
    {
        // TODO:
        uint32_t m_tile_idx, n_tile_idx;
        n_tiles.divmod(tile_idx, m_tile_idx, n_tile_idx);
        return make_tuple(m_tile_idx, n_tile_idx);
    }
};

} // namespace ck
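The constructor comment above derives b = sk_total_iters - m * sk_blocks for the number of "big" stream-k blocks (each covering m + 1 K iterations, the "little" ones covering m). A small hedged host-side worked example of that split, mirroring the branch structure of get_block_itr(); the iteration counts are made up for illustration and are not taken from the commit:

// Hedged worked example, not part of the commit.
#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t sk_total_iters = 37; // e.g. k_iters_per_tile * sk_tiles
    const uint32_t sk_num_blocks  = 5;

    const uint32_t m                     = sk_total_iters / sk_num_blocks;     // 7 iters (floor)
    const uint32_t sk_num_big_blocks     = sk_total_iters - m * sk_num_blocks; // 37 - 35 = 2
    const uint32_t k_iters_per_big_block = m + 1;                              // 8

    // 2 big blocks * 8 iters + 3 little blocks * 7 iters = 37 = sk_total_iters
    uint32_t covered = 0;
    for(uint32_t block_idx = 0; block_idx < sk_num_blocks; ++block_idx)
    {
        uint32_t iter_start, iter_end;
        if(block_idx < sk_num_big_blocks) // same branch structure as get_block_itr()
        {
            iter_start = block_idx * k_iters_per_big_block;
            iter_end   = iter_start + k_iters_per_big_block;
        }
        else
        {
            iter_start = sk_num_big_blocks * k_iters_per_big_block +
                         (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
            iter_end = iter_start + (k_iters_per_big_block - 1);
        }
        std::printf("sk block %u covers iters [%u, %u)\n", block_idx, iter_start, iter_end);
        covered += iter_end - iter_start;
    }
    std::printf("total covered: %u (== sk_total_iters)\n", covered); // 37
}

The ranges come out as [0,8), [8,16), [16,23), [23,30), [30,37): every K iteration of the stream-k tiles is owned by exactly one block, which is what lets the partial tiles be combined later with atomic adds.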
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"

namespace ck {

struct GridwiseGemmPipeline_v3
{
    __host__ __device__ static constexpr bool IsSupported(index_t)
    {
        // TODO: improve applicability
        return true;
    }

    template <typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_blockwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        // global read 0
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        // LDS write 0
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        num_loop--;

        while(num_loop > 0)
        {
            block_sync_lds();

            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

            block_sync_lds();

            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

            num_loop--;
        }

        // tail
        {
            block_sync_lds();

            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
    }
};

} // namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp (new file, mode 100644, shown truncated)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp"

namespace ck {

template <typename GridwiseGemm>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_xdlops_streamk(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();

    __shared__ uint8_t p_shared[shared_size];

    GridwiseGemm::Run(karg, static_cast<void*>(p_shared));
#else
    ignore = karg;
    ignore = b2c_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}

template <index_t BlockSize,
          typename Block2CTileMap_,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t K1Value,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto K1        = Number<K1Value>{};
    static constexpr auto M01       = 1;
    static constexpr auto N01       = 1;
    static constexpr auto KPerBlock = K0PerBlock * K1;

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
    using Block2CTileMap  = Block2CTileMap_;

    struct Argument : public ck::tensor_operation::device::BaseArgument
    {
        const FloatAB* p_a_grid;
        const FloatAB* p_b_grid;
        FloatC* p_c_grid;
        index_t M;
        index_t N;
        index_t K;
        index_t StrideA;
        index_t StrideB;
        index_t StrideC;
        Block2CTileMap block_mapping;

        Argument(const FloatAB* p_a_grid_,
                 const FloatAB* p_b_grid_,
                 FloatC* p_c_grid_,
                 index_t M_,
                 index_t N_,
                 index_t K_,
                 index_t StrideA_,
                 index_t StrideB_,
                 index_t StrideC_,
                 uint32_t num_cu,
                 uint32_t occupancy)
            : p_a_grid(p_a_grid_),
              p_b_grid(p_b_grid_),
              p_c_grid(p_c_grid_),
              M(M_),
              N(N_),
              K(K_),
              StrideA(StrideA_),
              StrideB(StrideB_),
              StrideC(StrideC_),
              block_mapping(M, N, K, num_cu, occupancy)
        {
        }

        void Print() const
        {
            std::cout << "arg {"
                      << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
                      << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
                      << std::endl;
        }
    };

    __host__ __device__ static auto CalculateGridSize(const Argument& karg)
    {
        return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock),
                               math::integer_divide_ceil(karg.M, MPerBlock),
                               karg.k_batch);
    }

#if 0
    // prefer this to be called on host
    __host__ __device__ static auto CalculateMPadded(index_t M)
    {
        return (M + MPerBlock - 1) / MPerBlock * MPerBlock;
    }

    __host__ __device__ static auto CalculateNPadded(index_t N)
    {
        return (N + NPerBlock - 1) / NPerBlock * NPerBlock;
    }

    __host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
    {
        // k_batch * k0 * k0_per_block * k1
        auto K_t = K_Batch * K0PerBlock * K1;
        return (K + K_t - 1) / K_t * K0PerBlock;
    }

    __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
    {
        auto K0 = CalculateK0(K, K_Batch);
        return K_Batch * K0 * K1;
    }
#endif

    __host__ __device__ static auto CalculateK0(index_t KPad) { return KPad / K1; }

    __host__ __device__ static auto
    MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA)
    {
        const index_t K0 = CalculateK0(KPad);

        const auto a_grid_desc_m_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
        }();

        const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return transform_tensor_descriptor(
            a_grid_desc_m_kpad,
            make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                       make_right_pad_transform(M, MPad - M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    __host__ __device__ static auto
    MakeBGridDescriptor_K0_N_K1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB)
    {
        const index_t K0 = CalculateK0(KPad);

        const auto b_grid_desc_k_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
            }
        }();

        const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
            b_grid_desc_k_n,
            make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return transform_tensor_descriptor(
            b_grid_desc_kpad_n,
            make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                       make_right_pad_transform(N, NPad - N)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    __host__ __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_m_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
            }
        }();

        return transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_right_pad_transform(M, MPad - M), make_right_pad_transform(N, NPad - N)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }

    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * K1, K1, I1));
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * K1, K1, I1));
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = K1;

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        constexpr auto a_block_space_size_aligned =
            math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_space_size_aligned =
            math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);

        constexpr auto c_block_size =
            GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB),
                         c_block_size * sizeof(FloatC));
    }

    __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
    {
        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
                return false;
        }
        else
        {
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
                return false;
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
        {
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
                return false;
        }
        else
        {
            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
                return false;
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
        {
            if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
                return false;
        }
        else
        {
            if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
                return false;
        }

        return true;
    }

    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        const bool has_main_k0_block_loop = K0 > K0PerBlock;

        return has_main_k0_block_loop;
    }

    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;

        return transform_tensor_descriptor(
            c_m_n_grid_desc,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
        const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
    {
        return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc>(
            c_m_n_grid_desc, 8, KBatch);
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
        constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1,
                       Number<CShuffleMRepeatPerShuffle * MWave * MPerXDL>{},
                       I1,
                       Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
    }

    // return block_id to C matrix tile idx (m0, n0, k_split) mapping
    __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
    {
        return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
    }

    using CGridDesc_M_N         = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
    using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;

    __device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block)
    {
        const FloatAB* p_a_grid = karg.p_a_grid;
        const FloatAB* p_b_grid = karg.p_b_grid;
        FloatC* p_c_grid        = karg.p_c_grid;

        uint32_t m = karg.M;
        uint32_t n = karg.N;
        uint32_t k = karg.K;

        uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock;
        uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock;
        uint32_t pad_k = (k + KPerBlock - 1) / KPerBlock * KPerBlock;

        uint32_t stride_a = karg.StrideA;
        uint32_t stride_b = karg.StrideB;
        uint32_t stride_c = karg.StrideC;

        const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a);
        const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b);
        const auto c_grid_desc_m_n     = MakeCGridDescriptor_M_N(m, pad_m, n, pad_n, stride_c);

        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);

        const AElementwiseOperation a_element_op = AElementwiseOperation{};
        const BElementwiseOperation b_element_op = BElementwiseOperation{};
        const CElementwiseOperation c_element_op = CElementwiseOperation{};

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // lds max alignment
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_block_desc_k0_m_k1),
                                                                decltype(b_block_desc_k0_n_k1),
                                                                MPerXDL,
                                                                NPerXDL,
                                                                MRepeat,
                                                                NRepeat,
                                                                K1>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
            math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block);
        FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;

        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());

        // gridwise GEMM pipeline
        const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();

        auto& block_mapping = karg.block_mapping;

        uint32_t block_idx = block_mapping.get_block_idx();
        bool is_sk_block   = block_idx < block_mapping.sk_num_blocks;
        bool is_dp_block   = block_idx >= block_mapping.dp_start_block_idx;

        uint32_t iter_start, iter_end;
        block_mapping.get_block_itr(block_idx, iter_start, iter_end);
        uint32_t total_iter_length = iter_end - iter_start;

        // if(threadIdx.x == 0)
        //     printf("xxx bid:%d\n", static_cast<int>(blockIdx.x));

        if(!is_sk_block && !is_dp_block)
            return;

        while(true)
        {
            uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
                block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length));
            uint32_t tile_idx, iter_offset;
            block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset);
            iter_offset      = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
            auto spatial_idx = block_mapping.tile_to_spatial(tile_idx);

            const index_t m_block_data_idx_on_grid =
                __builtin_amdgcn_readfirstlane(spatial_idx[I0] * MPerBlock);
            const index_t n_block_data_idx_on_grid =
                __builtin_amdgcn_readfirstlane(spatial_idx[I1] * NPerBlock);
            const index_t k0_block_data_idx_on_grid =
                __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);

            // if(threadIdx.x == 0)
            //     printf("[%s], bid:%d, block_idx:%d, tile_idx:%d(%d, %d, %d), iter_start:%d(%d |
            //     %d), iter_end:%d, len:%d\n",
            //     is_sk_block ? "sk_block" : (is_dp_block ? "dp_block" : "other "),
            //     static_cast<int>(blockIdx.x), block_idx, tile_idx, m_block_data_idx_on_grid,
            //     n_block_data_idx_on_grid, k0_block_data_idx_on_grid, iter_end -
            //     current_iter_length, iter_offset, iter_start, iter_end, current_iter_length);

            // A matrix blockwise copy
            auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
                ThisThreadBlock,
                AElementwiseOperation,
                ck::tensor_operation::element_wise::PassThrough,
                InMemoryDataOperationEnum::Set,
                Sequence<K0PerBlock, MPerBlock, K1>,
                ABlockTransferThreadClusterLengths_K0_M_K1,
                ABlockTransferThreadClusterArrangeOrder,
                FloatAB,
                FloatAB,
                decltype(a_k0_m_k1_grid_desc),
                decltype(a_block_desc_k0_m_k1),
                ABlockTransferSrcAccessOrder,
                Sequence<1, 0, 2>,
                ABlockTransferSrcVectorDim,
                2,
                ABlockTransferSrcScalarPerVector,
                ABlockTransferDstScalarPerVector_K1,
                1,
                1,
                AThreadTransferSrcResetCoordinateAfterRun,
                true>(a_k0_m_k1_grid_desc,
                      make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
                      a_element_op,
                      a_block_desc_k0_m_k1,
                      make_multi_index(0, 0, 0),
                      ck::tensor_operation::element_wise::PassThrough{});

            // B matrix blockwise copy
            auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
                ThisThreadBlock,
                BElementwiseOperation,
                ck::tensor_operation::element_wise::PassThrough,
                InMemoryDataOperationEnum::Set,
                Sequence<K0PerBlock, NPerBlock, K1>,
                BBlockTransferThreadClusterLengths_K0_N_K1,
                BBlockTransferThreadClusterArrangeOrder,
                FloatAB,
                FloatAB,
                decltype(b_k0_n_k1_grid_desc),
                decltype(b_block_desc_k0_n_k1),
                BBlockTransferSrcAccessOrder,
                Sequence<1, 0, 2>,
                BBlockTransferSrcVectorDim,
                2,
                BBlockTransferSrcScalarPerVector,
                BBlockTransferDstScalarPerVector_K1,
                1,
                1,
                BThreadTransferSrcResetCoordinateAfterRun,
                true>(b_k0_n_k1_grid_desc,
                      make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
                      b_element_op,
                      b_block_desc_k0_n_k1,
                      make_multi_index(0, 0, 0),
                      ck::tensor_operation::element_wise::PassThrough{});

            const index_t num_k_block_main_loop = current_iter_length;

            gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc,
                                       a_block_desc_k0_m_k1,
                                       a_blockwise_copy,
                                       a_grid_buf,
                                       a_block_buf,
                                       a_block_slice_copy_step,
                                       b_k0_n_k1_grid_desc,
                                       b_block_desc_k0_n_k1,
                                       b_blockwise_copy,
                                       b_grid_buf,
                                       b_block_buf,
                                       b_block_slice_copy_step,
                                       blockwise_gemm,
                                       c_thread_buf,
                                       num_k_block_main_loop);

            // output: register to global memory
            {
                constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
                constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);

                constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
                    blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
                constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
                    blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

                constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
                constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
                constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
                constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
                constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
                constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
                constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
                constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);

                constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
                    GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

                auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    static_cast<FloatC*>(p_shared_block),
                    c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

                constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                    c_block_desc_mblock_mperblock_nblock_nperblock,
                    make_tuple(
                        make_freeze_transform(I0), // freeze mblock
                        make_unmerge_transform(make_tuple(
                            CShuffleMRepeatPerShuffle, M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL
                        make_freeze_transform(I0),                       // freeze nblock
                        make_unmerge_transform(make_tuple(
                            CShuffleNRepeatPerShuffle, N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                    make_tuple(Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

                // calculate origin of thread output tensor on global memory
                //     blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block =
                    blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

                const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
                const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

                const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                    make_single_stage_tensor_adaptor(
                        make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                        make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                        make_tuple(Sequence<0>{}));

                const auto m_thread_data_on_block_idx =
                    m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                        make_multi_index(m_thread_data_on_block));

                const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                    make_single_stage_tensor_adaptor(
                        make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                        make_tuple(Sequence<0, 1, 2>{}),
                        make_tuple(Sequence<0>{}));

                const auto n_thread_data_on_block_idx =
                    n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                        make_multi_index(n_thread_data_on_block));

                // VGPR to LDS
                auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                    FloatAcc,
                    FloatC,
                    decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
                    decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                    ck::tensor_operation::element_wise::PassThrough,
                    Sequence<CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, I1, I1, M2, I1, M4, I1>,
                    Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                    7,
                    1,
                    InMemoryDataOperationEnum::Set,
                    1,
                    true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                          make_multi_index(0,
                                           0,
                                           m_thread_data_on_block_idx[I1],
                                           n_thread_data_on_block_idx[I1],
                                           m_thread_data_on_block_idx[I2],
                                           m_thread_data_on_block_idx[I3],
                                           m_thread_data_on_block_idx[I4],
                                           n_thread_data_on_block_idx[I2]),
                          ck::tensor_operation::element_wise::PassThrough{}};

                // LDS to global
                auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2<
                    ThisThreadBlock,       // index_t BlockSize,
                    CElementwiseOperation, // ElementwiseOperation,
                                           // InMemoryDataOperationEnum::Set, // DstInMemOp,
                    Sequence<1,
                             CShuffleMRepeatPerShuffle * MWave * MPerXDL,
                             1,
                             CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
                    CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                    Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                    FloatC,               // typename SrcData,
                    FloatC,               // typename DstData,
                    decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
                    decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                    Sequence<0, 1, 2, 3>,                       // typename DimAccessOrder,
                    3,                                          // index_t VectorDim,
                    CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
                    true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                    false> // bool ThreadTransferDstResetCoordinateAfterRun
                    {c_block_desc_mblock_mperblock_nblock_nperblock,
                     make_multi_index(0, 0, 0, 0),
                     c_grid_desc_mblock_mperblock_nblock_nperblock,
                     make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
                                      0,
                                      __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
                                      0),
                     c_element_op};

                constexpr auto mxdlperwave_forward_step =
                    make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
                constexpr auto nxdlperwave_forward_step =
                    make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL);
                constexpr auto nxdlperwave_backward_step =
                    make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL);

                static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
                    constexpr auto mxdlperwave = mxdlperwave_iter;

                    static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
                        constexpr bool nxdlperwave_forward_sweep =
                            (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);

                        constexpr index_t nxdlperwave_value =
                            nxdlperwave_forward_sweep
                                ? nxdlperwave_iter
                                : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);

                        constexpr auto nxdlperwave = Number<nxdlperwave_value>{};

                        // make sure it's safe to do ds_write
                        block_sync_lds();

                        // VGPR to LDS
                        c_thread_copy_vgpr_to_lds.Run(
                            c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                            make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
                            c_thread_buf,
                            c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                            c_block_buf);

                        // make sure it's safe to do ds_read
                        block_sync_lds();

                        // LDS to global
                        if(is_dp_block)
                            c_block_copy_lds_to_global
                                .template Run<decltype(c_block_buf),
                                              decltype(c_grid_buf),
                                              InMemoryDataOperationEnum::Set>(
                                    c_block_desc_mblock_mperblock_nblock_nperblock,
                                    c_block_buf,
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                                    c_grid_buf);
                        else if(is_sk_block)
                            c_block_copy_lds_to_global
                                .template Run<decltype(c_block_buf),
                                              decltype(c_grid_buf),
                                              InMemoryDataOperationEnum::AtomicAdd>(
                                    c_block_desc_mblock_mperblock_nblock_nperblock,
                                    c_block_buf,
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                                    c_grid_buf);

                        // move on nxdlperwave dimension
                        if constexpr(nxdlperwave_forward_sweep &&
                                     (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
                        {
                            c_block_copy_lds_to_global.MoveDstSliceWindow(
                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                nxdlperwave_forward_step);
                        }
                        else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
                        {
                            c_block_copy_lds_to_global.MoveDstSliceWindow(
                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                nxdlperwave_backward_step);
                        }
                    });

                    // move on mxdlperwave dimension
                    if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
                    {
                        c_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
                    }
                });
            }

            // exit condition
            iter_end
-=
current_iter_length
;
if
(
iter_end
<=
iter_start
)
break
;
// make sure next loop LDS is ready for use
block_sync_lds
();
}
// if(threadIdx.x == 0)
// printf("xxx bid:%d, xx_total_iter_length:%d \n", static_cast<int>(blockIdx.x),
// xx_total_iter_length);
}
    template <typename Layout>
    struct LStr
    {
        static std::string Get() { return ""; }
    };

    template <>
    struct LStr<ck::tensor_layout::gemm::RowMajor>
    {
        static std::string Get() { return "R"; }
    };

    template <>
    struct LStr<ck::tensor_layout::gemm::ColumnMajor>
    {
        static std::string Get() { return "C"; }
    };

    static std::string GetTypeString()
    {
        auto str = std::stringstream();
        // clang-format off
        str << "GemmXdlStreamK_"
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            << "_"
            << "B" << BlockSize << "_"
            << "Vec" << ABlockTransferSrcScalarPerVector
            << "x" << BBlockTransferSrcScalarPerVector
            << "x" << CBlockTransferScalarPerVector_NWaveNPerXDL
            << "_"
            << MPerBlock << "x" << NPerBlock << "x" << K0PerBlock << "x" << K1;
        // clang-format on
        return str.str();
    }
};

} // namespace ck
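The epilogue above stores C with InMemoryDataOperationEnum::Set for data-parallel (dp) blocks, which own a complete output tile, and with AtomicAdd for stream-k (sk) blocks, because several sk blocks may each reduce only a slice of the K range of the same tile. Below is a minimal host-side sketch of that accumulation idea (illustrative only, not the kernel code; a plain += stands in for the GPU atomic add):

// stream-k accumulation sketch: two "blocks" each own half of the K range of one
// C element; because neither holds the full sum, both must add (atomically on the
// GPU) into the same destination instead of overwriting it.
#include <cstdio>
#include <vector>

int main()
{
    const int K = 8;
    std::vector<float> a(K, 1.0f), b(K, 2.0f);

    float c = 0.0f; // the shared C element in global memory

    // block 0 reduces k = [0, K/2), block 1 reduces k = [K/2, K)
    for(int blk = 0; blk < 2; ++blk)
    {
        float partial = 0.0f;
        for(int k = blk * K / 2; k < (blk + 1) * K / 2; ++k)
            partial += a[k] * b[k];

        c += partial; // the kernel uses InMemoryDataOperationEnum::AtomicAdd here
    }

    printf("c = %f (expected %f)\n", c, 2.0f * K);
    return 0;
}

Correctness relies on the dispatcher splitting the K iterations so that every iteration of every tile is covered exactly once, which is what test/block_swizzle_test below checks on the host.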
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp
0 → 100644
View file @
1a2a1679
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"

namespace ck {

// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
//   1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as
//      argument instead
//   2. Don't construct a new tensor coordinate everytime when using it, update and reuse the
//      same tensor coordinate instead
//   3. Don't use a pointer to VGPR buffer, use vector instead
//
// Assume:
//   1. src_desc and dst_desc are not known at compile-time
//   2. SrcBuffer and DstBuffer are DynamicBuffer
//   3. src_slice_origin and dst_slice_origin are not known at compile-time
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool SrcResetCoordinateAfterRun,
          bool DstResetCoordinateAfterRun>
struct ThreadwiseTensorSliceTransfer_v6r1r2
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

    static constexpr auto I0 = Number<0>{};

    __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1r2(
        const SrcDesc& src_desc,
        const Index& src_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_slice_origin,
        const ElementwiseOperation& element_op)
        : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
          dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
          element_op_(element_op)
    {
        static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
                      "wrong! cannot evenly divide");
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
        src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
    }

    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
    {
        dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }

    template <typename SrcBuffer, typename DstBuffer, InMemoryDataOperationEnum DstInMemOp>
    __device__ void
    Run(const SrcDesc& src_desc, const SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

        using SpaceFillingCurve =
            SpaceFillingCurve<SliceLengths, DimAccessOrder, remove_cv_t<decltype(scalar_per_access)>>;

        // loop over space-filling curve
        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

        static_for<0, num_access, 1>{}([&](auto idx_1d) {
            using src_vector_type = vector_type_maker_t<SrcData, ScalarPerVector>;
            using src_vector_t    = typename src_vector_type::type;

            using dst_vector_type = vector_type_maker_t<DstData, ScalarPerVector>;
            using dst_vector_t    = typename dst_vector_type::type;

            const bool is_src_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);

            // copy data from src_buf into src_vector_container
            auto src_vector_container = src_vector_type{
                src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};

            auto dst_vector_container = dst_vector_type{};

            // apply pointwise operation
            static_for<0, ScalarPerVector, 1>{}([&](auto i) {
                SrcData v;

                // apply element-wise operation
                element_op_(v, src_vector_container.template AsType<SrcData>()[i]);

                // apply type convert
                dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
            });

            const bool is_dst_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            // copy data from dst_vector into dst_buf
            dst_buf.template Update<DstInMemOp, dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_vector_container.template AsType<dst_vector_t>()[I0]);

            // move coordinate
            if constexpr(idx_1d.value != num_access - 1)
            {
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);

                move_tensor_coordinate(
                    src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
                move_tensor_coordinate(
                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
            }
        });

        // move coordinate back to slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
        {
            const auto src_reset_step =
                make_tensor_coordinate_step(src_desc, GetCoordinateResetStep());

            move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
        }

        if constexpr(DstResetCoordinateAfterRun)
        {
            const auto dst_reset_step =
                make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep());

            move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
        }
    }

    __device__ static constexpr auto GetCoordinateResetStep()
    {
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

        using SpaceFillingCurve =
            SpaceFillingCurve<SliceLengths, DimAccessOrder, remove_cv_t<decltype(scalar_per_access)>>;

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

        if constexpr(num_access == 0)
        {
            return typename SpaceFillingCurve::Index{};
        }
        else
        {
            constexpr auto reset_step =
                SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});

            return reset_step;
        }
    }

    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
                                       const Index& src_slice_origin_step_idx)
    {
        // if src coord was not reset by RunRead(), then need to adjust the step here
        const auto adjusted_step_idx = SrcResetCoordinateAfterRun
                                           ? src_slice_origin_step_idx
                                           : src_slice_origin_step_idx + GetCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);

        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                       const Index& dst_slice_origin_step_idx)
    {
        // if dst coord was not reset by Run(), then need to adjust the step here
        const auto adjusted_step_idx = DstResetCoordinateAfterRun
                                           ? dst_slice_origin_step_idx
                                           : dst_slice_origin_step_idx + GetCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);

        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }

    private:
    SrcCoord src_coord_;
    DstCoord dst_coord_;
    const ElementwiseOperation element_op_;
};

} // namespace ck
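The detail that matters for stream-k in this r2 variant is that the destination memory operation (DstInMemOp) is a template parameter of Run() rather than a fixed property of the transfer object, so the epilogue above can issue Set stores for dp blocks and AtomicAdd updates for sk blocks through the same copier. A stripped-down sketch of that per-call-policy pattern (hypothetical toy types, not the CK classes):

// per-call write policy: the same copier object can either overwrite the
// destination or accumulate into it, chosen at each Run() call site.
#include <cstdio>

enum class MemOp { Set, AtomicAdd };

struct tiny_copier
{
    template <MemOp Op>
    void Run(const float* src, float* dst, int n) const
    {
        for(int i = 0; i < n; ++i)
        {
            if constexpr(Op == MemOp::Set)
                dst[i] = src[i];  // data-parallel block: owns the tile
            else
                dst[i] += src[i]; // stream-k block: partial sum (atomic on the GPU)
        }
    }
};

int main()
{
    float src[4] = {1, 2, 3, 4};
    float dst[4] = {10, 10, 10, 10};

    tiny_copier copier;
    copier.Run<MemOp::Set>(src, dst, 4);       // dst = {1, 2, 3, 4}
    copier.Run<MemOp::AtomicAdd>(src, dst, 4); // dst = {2, 4, 6, 8}

    printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]);
    return 0;
}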
include/ck/utility/magic_division.hpp
View file @
1a2a1679
...
...
@@ -157,4 +157,42 @@ struct MagicDivision
}
};
struct MDiv
{
    // 1 dword -> 3 dword storage
    uint32_t divisor;
    uint32_t multiplier;
    uint32_t shift;

    // prefer construct on host
    __host__ __device__ MDiv(uint32_t divisor_) : divisor(divisor_)
    {
        ck::tie(multiplier, shift) = MagicDivision::CalculateMagicNumbers(divisor_);
    }

    __host__ __device__ MDiv() : divisor(0), multiplier(0), shift(0) {}

    __host__ __device__ void update(uint32_t divisor_)
    {
        divisor                    = divisor_;
        ck::tie(multiplier, shift) = MagicDivision::CalculateMagicNumbers(divisor_);
    }

    __host__ __device__ uint32_t div(uint32_t dividend) const
    {
        return MagicDivision::DoMagicDivision(dividend, multiplier, shift);
    }

    __host__ __device__ void
    divmod(uint32_t dividend, uint32_t& quotient, uint32_t& remainder) const
    {
        quotient  = div(dividend);
        remainder = dividend - (quotient * divisor);
    }

    __host__ __device__ uint32_t operator/(uint32_t dividend) const { return div(dividend); }

    __host__ __device__ uint32_t get() const { return divisor; }
};

} // namespace ck
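MDiv bundles a divisor with precomputed magic numbers so kernels can divide and take remainders (for example when mapping a linear iteration index back to a tile) without hardware integer division. The sketch below illustrates one common round-up magic-number scheme and checks it against the plain '/' operator; it is an illustration only, and the actual constants used here come from MagicDivision::CalculateMagicNumbers.

// magic division sketch: precompute (multiplier, shift) for a fixed divisor, then
// replace n / d by a multiply-high, an add, and a shift. The variant below is
// assumed for illustration and is verified against '/' for the tested range
// (dividends well below 2^31, divisors 1..64).
#include <cassert>
#include <cstdint>
#include <cstdio>

struct magic_div
{
    uint32_t multiplier;
    uint32_t shift;

    explicit magic_div(uint32_t d)
    {
        shift = 0;
        while((1u << shift) < d)
            ++shift;
        // round-up magic multiplier for divisor d
        multiplier = static_cast<uint32_t>((((uint64_t(1) << shift) - d) << 32) / d + 1);
    }

    uint32_t div(uint32_t n) const
    {
        // high 32 bits of n * multiplier, then add n and shift
        uint32_t hi = static_cast<uint32_t>((uint64_t(n) * multiplier) >> 32);
        return (hi + n) >> shift;
    }
};

int main()
{
    for(uint32_t d = 1; d <= 64; ++d)
    {
        magic_div md(d);
        for(uint32_t n = 0; n < 4096; ++n)
            assert(md.div(n) == n / d);
    }
    printf("magic division matches '/' for the tested range\n");
    return 0;
}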
library/include/ck/library/utility/check_err.hpp
View file @
1a2a1679
...
...
@@ -55,7 +55,7 @@ check_err(const Range& out,
         {
             max_err = err > max_err ? err : max_err;
             err_count++;
-            if(err_count < 5)
+            if(err_count > 0)
             {
                 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                           << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...
...
@@ -102,7 +102,7 @@ check_err(const Range& out,
         {
             max_err = err > max_err ? err : max_err;
             err_count++;
-            if(err_count < 5)
+            if(err_count > 0)
             {
                 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                           << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...
...
@@ -148,7 +148,7 @@ check_err(const Range& out,
         {
             max_err = err > max_err ? err : max_err;
             err_count++;
-            if(err_count < 5)
+            if(err_count > 0)
             {
                 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                           << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...
...
@@ -199,7 +199,7 @@ check_err(const Range& out,
         {
             max_err = err > max_err ? err : max_err;
             err_count++;
-            if(err_count < 5)
+            if(err_count > 0)
             {
                 std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
                           << std::endl;
...
...
test/block_swizzle_test/block_swizzle_test.cpp
0 → 100644
View file @
1a2a1679
#include <stdio.h>
#include <string>
#include <algorithm>
#include <vector>
#include <limits>
#include "simple_args.h"

simple_args_t create_arg(int argc, char** argv)
{
    simple_args_t args;
    args.insert("m", "1024", "matrix m")
        .insert("n", "1024", "matrix n")
        .insert("k", "1024", "matrix k")
        .insert("m_per_block", "128", "m_per_block")
        .insert("n_per_block", "128", "n_per_block")
        .insert("k_per_block", "32", "k_per_block")
        .insert("num_cu", "104", "num cu")
        .insert("occupancy", "2", "occupancy")
        .parse(argc, argv);
    return args;
}

namespace impl {

template <typename T>
T integer_divide_ceil(T n, T d)
{
    return (n + d - 1) / d;
}

template <typename T>
T min(T a, T b)
{
    return a > b ? b : a;
}

template <typename T>
T max(T a, T b)
{
    return a > b ? a : b;
}

} // namespace impl

struct block_dispatcher_t
{
    public:
    uint32_t m_per_block;
    uint32_t n_per_block;
    uint32_t k_per_block;
    uint32_t num_cu;
    uint32_t occupancy;
    uint32_t m;
    uint32_t n;
    uint32_t k;

    //--------------------------------------
    uint32_t sk_num_blocks;
    uint32_t sk_num_big_blocks;
    uint32_t sk_total_iters;
    // uint32_t sk_num_blocks_per_tile; // how many
    uint32_t dp_start_block_idx;
    uint32_t dp_iters_per_block;
    uint32_t dp_num_blocks;
    uint32_t k_iters_per_tile;
    uint32_t k_iters_per_big_block;
    //--------------------------------------

    static constexpr uint32_t min_k_iters_per_sk_block = 1;

    void dump()
    {
        printf("%dx%dx%d(%dx%dx%d), cu:%d, occ:%d, grids:%d, sk_num_big_blocks:%d, "
               "sk_num_blocks:%d, sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, "
               "dp_num_blocks:%d, k_iters_per_tile:%d, k_iters_per_big_block:%d\n",
               m, n, k, m_per_block, n_per_block, k_per_block,
               num_cu, occupancy, get_grid_dims_x(),
               sk_num_big_blocks, sk_num_blocks, sk_total_iters,
               dp_start_block_idx, dp_iters_per_block, dp_num_blocks,
               k_iters_per_tile, k_iters_per_big_block);
    }

    block_dispatcher_t(uint32_t m_per_block_,
                       uint32_t n_per_block_,
                       uint32_t k_per_block_,
                       uint32_t num_cu_,
                       uint32_t occupancy_,
                       uint32_t m_,
                       uint32_t n_,
                       uint32_t k_)
        : m_per_block(m_per_block_),
          n_per_block(n_per_block_),
          k_per_block(k_per_block_),
          num_cu(num_cu_),
          occupancy(occupancy_),
          m(m_),
          n(n_),
          k(k_)
    {
        init();
    }

    uint32_t get_grid_dims_x() { return dp_start_block_idx + dp_num_blocks; }

    uint32_t get_block_idx(uint32_t bid)
    {
        // block id is linearly allocated along sk blocks (dp blocks are fine)
        // this function will compute blockIdx.x and the linear sk block mapping
        // uint32_t block_idx = 0;
        // if(bid < sk_num_big_blocks) {
        //     uint32_t current_k_iter = bid * k_iters_per_big_block;
        //     tile_idx = current_k_iter / k_iters_per_tile;
        // }
        return bid;
    }

    uint32_t get_current_itr(uint32_t block_idx)
    {
        uint32_t current_itr = 0;
        if(block_idx < sk_num_big_blocks)
        {
            current_itr = block_idx * k_iters_per_big_block;
        }
        else if(block_idx < sk_num_blocks)
        {
            current_itr = (sk_num_big_blocks * k_iters_per_big_block) +
                          (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
        }
        else if(block_idx >= dp_start_block_idx)
        {
            current_itr = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
        }
        return current_itr;
    }

    void get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end)
    {
        if(block_idx < sk_num_big_blocks)
        {
            iter_start = block_idx * k_iters_per_big_block;
            iter_end   = iter_start + k_iters_per_big_block;
        }
        else if(block_idx < sk_num_blocks)
        {
            iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
                         (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
            iter_end = iter_start + (k_iters_per_big_block - 1);
        }
        else if(block_idx >= dp_start_block_idx)
        {
            iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
            iter_end   = iter_start + dp_iters_per_block;
        }
    }

    private:
    void init()
    {
        uint32_t num_tiles =
            impl::integer_divide_ceil(m, m_per_block) * impl::integer_divide_ceil(n, n_per_block);
        k_iters_per_tile = impl::integer_divide_ceil(k, k_per_block);

        // one cu can hold one wg at one time, from the whole chip's point of view
        // if number of wg is same as num_cu, we call it 1 dispatch
        // if number of wg is 2x num_cu, we call it 2 dispatches.
        // one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu
        // (partial dispatch)
        // (see the worked example after this struct)
        uint32_t full_dispatches         = num_tiles / num_cu;
        uint32_t full_dispatch_tiles     = full_dispatches * num_cu;
        uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;

        uint32_t sk_occupancy = occupancy;
        uint32_t dp_tiles     = full_dispatch_tiles;
        uint32_t sk_tiles     = partial_dispatche_tiles;

        if(full_dispatches < occupancy)
        {
            // in this case, we allocate all blocks as sk blocks
            // sk_occupancy = occupancy - full_dispatches;
            sk_occupancy = 1; // TODO: single occ seems better
            dp_tiles     = full_dispatch_tiles;
            sk_tiles     = partial_dispatche_tiles;
        }
        else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
        {
            // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
            //      occupancy = 3, full_dispatches = 5, 8, 11 ...
            //      occupancy = 4, full_dispatches = 7, 11 ...
            sk_occupancy = 1; // left 1 slot for sk occupancy
            dp_tiles     = full_dispatch_tiles;
            sk_tiles     = partial_dispatche_tiles;
        }
        else
        {
            // others, we reduce 1 dispatch from dp, together with partial dispatch,
            // to construct sk dispatch
            sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
            dp_tiles     = full_dispatch_tiles - num_cu;
            sk_tiles     = partial_dispatche_tiles + num_cu;
        }

        // dp_num_blocks = dp_tiles;
        // dp_start_block_idx = num_cu * sk_occupancy;

        dp_iters_per_block = k_iters_per_tile;
        sk_total_iters     = k_iters_per_tile * sk_tiles;

        // printf("num_tiles:%d, full_dispatches:%d, full_dispatch_tiles:%d,
        //     partial_dispatche_tiles:%d\n",
        //     num_tiles, full_dispatches, full_dispatch_tiles, partial_dispatche_tiles);

        {
            uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
            uint32_t max_sk_tiles =
                (sk_tiles >= num_cu) ? num_cu * sk_occupancy
                                     : impl::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);

            // if use dp for sk-block, how many iters do we need
            uint32_t dp_for_sk_iters = k_iters_per_tile;

            uint32_t best_sk_score = std::numeric_limits<int>::max();
            // we need to find the smallest sk iters
            for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
                tentative_sk_blocks++)
            {
                uint32_t tentative_sk_iters_per_block =
                    (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
                uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
                uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;

                // TODO: carefully adjust this parameter
                //       the more sk_blocks_per_tile, the worse the overhead
                uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
                if(tentative_sk_blocks % sk_tiles != 0)
                {
                    // penalty for uneven divide
                    cross_sk_blocks_overhead +=
                        sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
                }

                uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;

                if(tentative_sk_score < best_sk_score)
                {
                    best_sk_score = tentative_sk_score;
                    sk_num_blocks = tentative_sk_blocks;
                }
            }

            if(best_sk_score >= dp_for_sk_iters)
            {
                sk_num_blocks = 0;
            }

            if(sk_num_blocks == 0)
            {
                sk_num_big_blocks     = 0;
                k_iters_per_big_block = 0;

                dp_num_blocks      = num_tiles; // all tile to be dp block
                dp_start_block_idx = 0;
                sk_total_iters     = 0; // clear this tiles
            }
            else
            {
                uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
                sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
                k_iters_per_big_block = k_iters_per_sk_block + 1;

                dp_num_blocks      = dp_tiles;
                dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
            }
        }
    }
};
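// Worked example, hand-derived from init() above with the default arguments
// (m = n = k = 1024, 128x128x32 tiles, num_cu = 104, occupancy = 2):
//   num_tiles       = 8 * 8 = 64, k_iters_per_tile = 1024 / 32 = 32
//   full_dispatches = 64 / 104 = 0, so there are fewer tiles than CUs and every
//                     tile goes to the stream-k path (dp_tiles = 0, sk_tiles = 64)
//   sk_total_iters  = 64 * 32 = 2048
//   the scoring loop should settle on sk_num_blocks = 103, which gives
//   k_iters_per_big_block = 20 and sk_num_big_blocks = 91
//   (91 * 20 + 12 * 19 = 2048, i.e. every K iteration is covered exactly once);
//   dp_start_block_idx is then rounded up to 104, so the grid is one full wave of
//   104 workgroups and the last one (bid 103) is neither sk nor dp and shows up as
//   "[other]" in the simulator output below.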
struct tile_work_t
{
    uint32_t tile_idx;
    uint32_t iter_begin;
    uint32_t k_begin;
    uint32_t k_end;
    uint32_t k_iters_remaining;
};

int main(int argc, char** argv)
{
    simple_args_t arg = create_arg(argc, argv);

    block_dispatcher_t block_dispatcher{arg.get_uint32("m_per_block"),
                                        arg.get_uint32("n_per_block"),
                                        arg.get_uint32("k_per_block"),
                                        arg.get_uint32("num_cu"),
                                        arg.get_uint32("occupancy"),
                                        arg.get_uint32("m"),
                                        arg.get_uint32("n"),
                                        arg.get_uint32("k")};
    block_dispatcher.dump();

    // simulate actual kernel launch
    uint32_t dim_x = block_dispatcher.get_grid_dims_x();
    uint32_t total_k_iters =
        impl::integer_divide_ceil(arg.get_uint32("k"), arg.get_uint32("k_per_block"));
    uint32_t num_tiles =
        impl::integer_divide_ceil(arg.get_uint32("m"), arg.get_uint32("m_per_block")) *
        impl::integer_divide_ceil(arg.get_uint32("n"), arg.get_uint32("n_per_block"));

    std::vector<int> valid_tile_record(num_tiles * total_k_iters);

    for(uint32_t bid = 0; bid < dim_x; bid++)
    {
        uint32_t block_idx = block_dispatcher.get_block_idx(bid);
        bool is_sk_block   = block_idx < (block_dispatcher.sk_num_blocks);
        bool is_dp_block   = block_idx >= block_dispatcher.dp_start_block_idx;

        uint32_t iter_start = 0, iter_end = 0;
        block_dispatcher.get_block_itr(block_idx, iter_start, iter_end);
        uint32_t total_iter_length = iter_end - iter_start;

        while(true)
        {
            uint32_t iter_length_mod     = iter_end % block_dispatcher.k_iters_per_tile;
            uint32_t current_iter_length = impl::min(
                iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod,
                total_iter_length);
            uint32_t tile_idx = (iter_end - 1) / block_dispatcher.k_iters_per_tile;
            uint32_t tile_iter_start =
                ((iter_end - 1) % block_dispatcher.k_iters_per_tile) - current_iter_length + 1;

            if(is_sk_block)
            {
                printf("[sk_block] bid:%3d, block_idx:%3d, tile_idx:%3d, iter_start:%d(%d | %d), "
                       "iter_end:%d (len:%d)\n",
                       bid, block_idx, tile_idx, iter_end - current_iter_length, tile_iter_start,
                       iter_start, iter_end, current_iter_length);
            }
            else if(is_dp_block)
            {
                printf("[dp_block] bid:%3d, block_idx:%3d, tile_idx:%3d, iter_start:%d(%d | %d), "
                       "iter_end:%d (len:%d)\n",
                       bid, block_idx, tile_idx, iter_end - current_iter_length, tile_iter_start,
                       iter_start, iter_end, current_iter_length);
            }
            else
            {
                printf("[other   ] bid:%3d, block_idx:%3d\n", bid, block_idx);
            }

            // some validation check
            for(auto i = iter_end - current_iter_length; i < iter_end; i++)
            {
                if(i >= valid_tile_record.size())
                {
                    printf("unexpected, current iter:%u larger than max:%zu\n",
                           i, valid_tile_record.size());
                    return -1;
                }
                valid_tile_record[i] = 1;
            }

            iter_end -= current_iter_length;
            if(iter_end <= iter_start)
                break;
        }
    }

    int untouched = 0;
    for(size_t i = 0; i < valid_tile_record.size(); i++)
    {
        if(valid_tile_record[i] != 1)
        {
            printf("untouched at %zu (%zu)\n", i, valid_tile_record.size());
            untouched++;
        }
    }
    printf("untouched %d/%zu, %s\n",
           untouched, valid_tile_record.size(), untouched == 0 ? "valid" : "fail");
}
test/block_swizzle_test/rebuild.sh
0 → 100644
View file @
1a2a1679
CC=g++
$CC -Wall -std=c++17 -Iinclude -O3 block_swizzle_test.cpp -o block_swizzle_test.exe
\ No newline at end of file
test/block_swizzle_test/simple_args.h
0 → 100644
View file @
1a2a1679
#pragma once

#include <iomanip>
#include <iostream>
#include <stdint.h>
#include <stdlib.h>
#include <string>
#include <unordered_map>
#include <vector>
#include <assert.h>

struct arg_content_t
{
    std::string name; // key
    std::string value;
    std::string help_text;
};

class simple_args_t
{
    public:
    simple_args_t() {}

    simple_args_t& insert(const std::string& name_,
                          const std::string& default_value_,
                          const std::string& help_text_)
    {
        arg_content_t arg{name_, default_value_, help_text_};
        if(arg_map.count(arg.name) != 0)
        {
            std::cout << "arg:" << arg.name << " already exists" << std::endl;
        }
        else
        {
            arg_map[arg.name] = arg;
        }
        return *this;
    }

    void usage()
    {
        for(auto& content : arg_map)
        {
            std::vector<std::string> help_text_lines;
            size_t pos = 0;
            for(size_t next_pos = content.second.help_text.find('\n', pos);
                next_pos != std::string::npos;)
            {
                help_text_lines.push_back(
                    std::string(content.second.help_text.begin() + pos,
                                content.second.help_text.begin() + next_pos++));
                pos      = next_pos;
                next_pos = content.second.help_text.find('\n', pos);
            }
            help_text_lines.push_back(
                std::string(content.second.help_text.begin() + pos, content.second.help_text.end()));

            int arg_name_width = 16 - content.second.name.length();
            arg_name_width     = arg_name_width > 0 ? arg_name_width : 2;

            std::cout << std::setw(4) << "-" << content.second.name << std::setw(arg_name_width)
                      << " " << help_text_lines[0] << std::endl;
            for(auto help_next_line = std::next(help_text_lines.begin());
                help_next_line != help_text_lines.end();
                ++help_next_line)
            {
                std::cout << std::setw(28) << " " << *help_next_line << std::endl;
            }
        }
    }

    bool parse(int argc, char* argv[], int start_index = 1)
    {
        if(argc <= start_index)
        {
            // std::cout << "not enough args (" << argc << ") with starting index " << start_index
            //           << std::endl;
            return true;
        }
        for(int i = start_index; i < argc; i++)
        {
            std::string cur_arg = std::string(argv[i]);
            if(cur_arg[0] != '-')
            {
                std::cout << "illegal input" << std::endl;
                usage();
                return false;
            }
            else if(cur_arg[0] == '-' && cur_arg[1] == '?')
            {
                usage();
                return false;
            }
            else
            {
                size_t found_equal = cur_arg.find('=');
                if(found_equal == std::string::npos || found_equal == (cur_arg.length() - 1))
                {
                    std::cout << "failed while parsing \"" << cur_arg << "\", "
                              << "arg must be in the form \"-name=value\"" << std::endl;
                    return false;
                }
                std::string arg_name  = cur_arg.substr(1, found_equal - 1);
                std::string arg_value = cur_arg.substr(found_equal + 1);

                if(arg_map.count(arg_name) == 0)
                {
                    std::cout << "no such arg \"" << arg_name << "\" registered" << std::endl;
                    return false;
                }

                arg_map[arg_name].value = arg_value;
            }
        }
        return true;
    }

    std::string get(const std::string& name) const { return get_str(name); }

    std::string get_str(const std::string& name) const
    {
        assert(arg_map.count(name) != 0);
        std::string value = arg_map.at(name).value;
        return value;
    }

    int get_int(const std::string& name) const
    {
        assert(arg_map.count(name) != 0);
        int value = atoi(arg_map.at(name).value.c_str());
        return value;
    }

    uint32_t get_uint32(const std::string& name) const
    {
        assert(arg_map.count(name) != 0);
        uint32_t value = strtoul(arg_map.at(name).value.c_str(), nullptr, 10);
        return value;
    }

    uint64_t get_uint64(const std::string& name) const
    {
        assert(arg_map.count(name) != 0);
        uint64_t value = strtoull(arg_map.at(name).value.c_str(), nullptr, 10);
        return value;
    }

    double get_double(const std::string& name) const
    {
        assert(arg_map.count(name) != 0);
        double value = atof(arg_map.at(name).value.c_str());
        return value;
    }

    float get_float(const std::string& name) const
    {
        assert(arg_map.count(name) != 0);
        float value = atof(arg_map.at(name).value.c_str());
        return value;
    }

    private:
    std::unordered_map<std::string, arg_content_t> arg_map;
};
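For reference, a minimal hypothetical driver showing how simple_args_t is meant to be used, the same pattern create_arg() follows in block_swizzle_test.cpp: register options with defaults, parse "-name=value" style arguments (with -? printing the usage text), and read values back through the typed getters.

// demo.cpp (illustrative only, not part of the commit)
#include <cstdio>
#include "simple_args.h"

int main(int argc, char** argv)
{
    simple_args_t args;
    args.insert("m", "1024", "matrix m")
        .insert("occupancy", "2", "max workgroups per CU")
        .parse(argc, argv); // e.g. ./demo -m=3840 -occupancy=1

    printf("m=%u occupancy=%u\n", args.get_uint32("m"), args.get_uint32("occupancy"));
    return 0;
}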