Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9bd44685
Commit
9bd44685
authored
Dec 09, 2022
by
aska-0096
Browse files
Correctness OK, waiting for optimization
parent
73959956
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
637 additions
and
101 deletions
+637
-101
example/01_gemm/gemm_wmma_fp16.cpp
example/01_gemm/gemm_wmma_fp16.cpp
+9
-2
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+5
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+131
-5
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+42
-38
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+338
-34
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+26
-4
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+4
-7
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+35
-0
include/ck/utility/common_header.hpp
include/ck/utility/common_header.hpp
+21
-5
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+8
-4
library/include/ck/library/utility/fill.hpp
library/include/ck/library/utility/fill.hpp
+18
-0
No files found.
example/01_gemm/gemm_wmma_fp16.cpp
View file @
9bd44685
...
...
@@ -22,14 +22,21 @@ using CElementOp = PassThrough;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmWmma
//
using DeviceGemmInstance
0
= ck::tensor_operation::device::DeviceGemmWmma
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MWMMA|NMMMA| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
16
,
16
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
6
,
1
>
;
//
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 6, 1>;
// clang-format on
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmWmma_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MWmma|NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
16
,
16
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
...
...
example/01_gemm/run_gemm_example.inc
View file @
9bd44685
...
...
@@ -32,8 +32,11 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
case
0
:
break
;
case
1
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
1.
f
,
1.
f
}(
a_m_k
.
begin
(),
a_m_k
.
end
());
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
1.
f
,
1.
f
}(
b_k_n
.
begin
(),
b_k_n
.
end
());
// CONFIRMED
// ck::utils::FillMNID<ADataType>{}(a_m_k.begin(), a_m_k.end());
// ck::utils::FillMNID<BDataType>{}(b_k_n.begin(), b_k_n.end());
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
.
begin
(),
a_m_k
.
end
());
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
.
begin
(),
b_k_n
.
end
());
break
;
default
:
ck
::
utils
::
FillUniformDistribution
<
ADataType
>
{
-
1.
f
,
1.
f
}(
a_m_k
.
begin
(),
a_m_k
.
end
());
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
9bd44685
...
...
@@ -137,7 +137,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static_assert
(
MPerBlock
%
(
MPerWMMA
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerWMMA
*
NRepeat
)
==
0
,
"wrong!"
);
}
// Thread level, register decriptor.
// Thread level, register decriptor.
Vector-write
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
{
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
...
...
@@ -168,6 +168,51 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
);
}
// Thread level, register decriptor. Per-pixel write
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup
()
{
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
constexpr
auto
MSubGroup
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I0
];
constexpr
auto
NThreadPerSubGroup
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I1
];
constexpr
auto
MAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
return
make_naive_tensor_descriptor_packed
(
// |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave |NThreadPerSubGroup
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
MSubGroup
,
MAccVgprs
,
Number
<
NRepeat
>
{},
I1
,
NThreadPerSubGroup
));
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerWMMA
),
MWaves
,
MPerWMMA
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerWMMA
),
NWaves
,
NPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
);
}
// Provide dimension size
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
{
constexpr
auto
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{},
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
);
}
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_K0_M0_M1_M2_K1
()
{
return
transform_tensor_descriptor
(
...
...
@@ -205,8 +250,28 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
constexpr
auto
RepeatDiff
=
MRepeat
-
NRepeat
;
// constexpr auto RepeatDiff = MRepeat - NRepeat;
// debug_hexprinter(0xffffffff, a_thread_buf[Number<a_thread_desc_.CalculateOffset( make_tuple(0, 0, 0, 0,0))>{}], "Avalue ");
/* First local prefetch, move out of blockwise operation.
static_for<0, NRepeat, 1>{}([&](auto iN){
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, Number<iN>{}, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, Number<iN>{}, I0, I0, I0),
b_thread_buf);
});
static_for<0, MRepeat, 1>{}([&](auto iN){
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, Number<iN>{}, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, Number<iN>{}, I0, I0, I0),
b_thread_buf);
});
*/
/*
static_for<0, KPerBlock, WmmaK>{}([&](auto iWmmaK){
// Cut to Repeat Retangle to Square, assume MRepeat > NRepeat
static_for<0, RepeatDiff, 1>{}([&](auto iCut){
...
...
@@ -297,16 +362,77 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
b_thread_buf);
});
});
*/
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}([
&
](
auto
k
)
{
// k=0,1,2 instead of k=0,kpack*1, ...
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
A_K1
>
{},
m0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
B_K1
>
{},
n0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
vector_type
<
FloatAB
,
WmmaK
>
a_thread_vec
;
vector_type
<
FloatAB
,
WmmaK
>
b_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
A_K1
,
0
,
0
,
0
,
i
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
,
0
,
0
,
0
,
i
%
B_K1
))
>
{}];
});
using
wmma_input_type
=
typename
vector_type
<
FloatAB
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type
>()(
Number
<
0
>{}),
b_thread_vec
.
template
AsType
<
wmma_input_type
>()(
Number
<
0
>
{}),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
// static_for<0, 16, 1>{}([&](auto i){
// char info[4];
// info[0] = 'A';
// info[1] = i/10 + '0';
// info[2] = i%10 + '0';
// info[3] = '\0';
// debug_hexprinter(0xffffffff, a_thread_buf[Number<i>{}], info);
// });
// static_for<0, 16, 1>{}([&](auto i){
// char info[4];
// info[0] = 'B';
// info[1] = i/10 + '0';
// info[2] = i%10 + '0';
// info[3] = '\0';
// debug_hexprinter(0xffffffff, b_thread_buf[Number<i>{}], info);
// });
}
protected:
// A[M0, M1, M2, K0 = WmmaK]
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
WmmaK
/
A_K1
>
{},
Number
<
MRepeat
>
{}
,
I1
,
I1
,
Number
<
A_K1
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
WmmaK
/
A_K1
>
{},
I1
,
I1
,
I1
,
Number
<
A_K1
>
{}));
// B[N0, N1, N2, K0 = WmmaK]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
WmmaK
/
B_K1
>
{},
Number
<
MRepeat
>
{}
,
I1
,
I1
,
Number
<
B_K1
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
WmmaK
/
B_K1
>
{},
I1
,
I1
,
I1
,
Number
<
B_K1
>
{}));
// C[M, N, NumRegWMMA]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
9bd44685
...
...
@@ -20,13 +20,14 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ADataType
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
...
...
@@ -54,20 +55,22 @@ template <typename ADataType,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
ck
::
index_t
NumPrefetch
=
1
,
ck
::
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
v1
>
struct
DeviceGemmWmma
:
public
DeviceGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
struct
DeviceGemmWmma
_CShuffle
:
public
DeviceGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -200,6 +203,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CShuffleDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
...
...
@@ -232,9 +236,10 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
// CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
NumPrefetch
,
LoopSched
,
PipelineVer
>
;
...
...
@@ -262,7 +267,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
a_grid_desc_k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_mblock
xrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_
{},
c_grid_desc_mblock
_mperblock_nblock_nperblock
{},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
...
...
@@ -270,9 +275,9 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
{
a_grid_desc_k0_m_k1_
=
DeviceGemmWmma
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
=
DeviceGemmWmma
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
c_grid_desc_m_n_
=
DeviceGemmWmma
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
a_grid_desc_k0_m_k1_
=
DeviceGemmWmma
_CShuffle
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
=
DeviceGemmWmma
_CShuffle
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
c_grid_desc_m_n_
=
DeviceGemmWmma
_CShuffle
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
...
...
@@ -282,8 +287,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
c_grid_desc_m_n_
,
block_2_ctile_map_
))
{
c_grid_desc_mblock
xrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock
xRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_grid_desc_m_n_
);
c_grid_desc_mblock
_mperblock_nblock_nperblock
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock
_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
}
}
...
...
@@ -294,8 +299,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock
xRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
c_grid_desc_mblock
xrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock
_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock
_mperblock_nblock_nperblock
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
...
...
@@ -307,7 +312,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceGemmWmma
::
Argument
;
using
Argument
=
DeviceGemmWmma
_CShuffle
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
...
...
@@ -350,9 +355,9 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmWmma
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmWmma
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock
xRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
>
,
remove_reference_t
<
DeviceGemmWmma
_CShuffle
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmWmma
_CShuffle
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock
_MPerBlock_NBlock_NPerBlock
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
...
...
@@ -369,7 +374,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_mblock
xrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_
,
arg
.
c_grid_desc_mblock
_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
...
...
@@ -381,9 +386,9 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmWmma
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmWmma
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock
xRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
>
,
remove_reference_t
<
DeviceGemmWmma
_CShuffle
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmWmma
_CShuffle
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock
_MPerBlock_NBlock_NPerBlock
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
...
...
@@ -400,7 +405,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_mblock
xrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_
,
arg
.
c_grid_desc_mblock
_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
...
...
@@ -428,8 +433,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
{
if
(
ck
::
get_device_name
()
==
"gfx1100"
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
return
false
;
}
...
...
@@ -530,7 +534,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
{
PipelineVersion
::
v2
,
"v2"
}};
// clang-format off
str
<<
"DeviceGemmWmma"
str
<<
"DeviceGemmWmma
_CShuffle
"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
9bd44685
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
9bd44685
...
...
@@ -119,7 +119,29 @@ struct ThreadwiseTensorSliceTransfer_v1r3
using
SpaceFillingCurve
=
SpaceFillingCurve
<
SliceLengths
,
DimAccessOrder
,
remove_cv_t
<
decltype
(
dst_scalar_per_access
)
>>
;
// printf("SpaceFillingCurve access_lengths = (%d, %d, %d, %d, %d, %d, %d)\n", SpaceFillingCurve::access_lengths[Number<0>{}].value,
// SpaceFillingCurve::access_lengths[Number<1>{}].value,
// SpaceFillingCurve::access_lengths[Number<2>{}].value,
// SpaceFillingCurve::access_lengths[Number<3>{}].value,
// SpaceFillingCurve::access_lengths[Number<4>{}].value,
// SpaceFillingCurve::access_lengths[Number<5>{}].value,
// SpaceFillingCurve::access_lengths[Number<6>{}].value);
//
// // printf("SpaceFillingCurve dim_access_order = (%d, %d, %d, %d, %d, %d, %d)\n", SpaceFillingCurve::dim_access_order[Number<0>{}].value,
// SpaceFillingCurve::dim_access_order[Number<1>{}].value,
// SpaceFillingCurve::dim_access_order[Number<2>{}].value,
// SpaceFillingCurve::dim_access_order[Number<3>{}].value,
// SpaceFillingCurve::dim_access_order[Number<4>{}].value,
// SpaceFillingCurve::dim_access_order[Number<5>{}].value,
// SpaceFillingCurve::dim_access_order[Number<6>{}].value);
//
// // // printf("SpaceFillingCurve ordered_access_lengths = (%d, %d, %d, %d, %d, %d, %d)\n", SpaceFillingCurve::ordered_access_lengths[Number<0>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<1>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<2>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<3>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<4>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<5>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<6>{}].value);
// TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector?
static_assert
(
DstScalarPerVector
==
SpaceFillingCurve
::
ScalarPerVector
,
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"
);
...
...
@@ -136,7 +158,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static_for
<
0
,
DstScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
src_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
// debug_hexprinter(0xffffffff, src_offset, "src_coord_iteration");
SrcData
v
;
// apply element-wise operation
...
...
@@ -154,11 +176,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3
dst_coord_
.
GetOffset
(),
is_dst_valid
,
dst_vector
.
template
AsType
<
dst_vector_t
>()[
Number
<
0
>
{}]);
// debug_hexprinter(0xffffffff, dst_coord_.GetOffset(), "dst_coord_iteration");
if
constexpr
(
idx_1d
.
value
!=
num_access
-
1
)
{
constexpr
auto
forward_step
=
SpaceFillingCurve
::
GetForwardStep
(
idx_1d
);
// printf("move forward = (%d, %d, %d, %d, %d, %d, %d)\n", forward_step[Number<0>{}], forward_step[Number<1>{}], forward_step[Number<2>{}], forward_step[Number<3>{}], forward_step[Number<4>{}], forward_step[Number<5>{}], forward_step[Number<6>{}]);
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
make_tensor_coordinate_step
(
dst_desc
,
forward_step
));
}
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @
9bd44685
...
...
@@ -96,7 +96,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_element_op_
(
src_element_op
),
dst_element_op_
(
dst_element_op
)
{
printf
(
"global desc: %s
\n
"
,
__PRETTY_FUNCTION__
);
//
printf("global desc: %s\n", __PRETTY_FUNCTION__);
}
__device__
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_idx
)
...
...
@@ -128,12 +128,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
printf
(
"src_access_lengths: %d, %d, %d
\n
"
,
(
src_access_lengths
[
Number
<
0
>
{}])(),
src_access_lengths
[
Number
<
1
>
{}](),
src_access_lengths
[
Number
<
2
>
{}]());
//
printf("src_access_lengths: %d, %d, %d\n", (src_access_lengths[Number<0>{}])(), src_access_lengths[Number<1>{}](), src_access_lengths[Number<2>{}]());
constexpr
auto
src_dim_access_order
=
SrcDimAccessOrder
{};
constexpr
auto
ordered_src_access_lengths
=
container_reorder_given_new2old
(
src_access_lengths
,
src_dim_access_order
);
printf
(
"ordered_src_access_lengths: %d, %d, %d
\n
"
,
(
ordered_src_access_lengths
[
Number
<
0
>
{}])(),
ordered_src_access_lengths
[
Number
<
1
>
{}](),
ordered_src_access_lengths
[
Number
<
2
>
{}]());
//
printf("ordered_src_access_lengths: %d, %d, %d\n", (ordered_src_access_lengths[Number<0>{}])(), ordered_src_access_lengths[Number<1>{}](), ordered_src_access_lengths[Number<2>{}]());
// make forward steps
const
auto
src_forward_steps
=
generate_tuple
(
...
...
@@ -147,9 +147,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
return
make_tensor_coordinate_step
(
src_desc
,
forward_step_idx
);
},
Number
<
nDim
>
{});
printf
(
"src_forward_steps: %d, %d, %d
\n
"
,
(
src_forward_steps
.
GetIndexDiff
()[
Number
<
0
>
{}])(),
(
src_forward_steps
.
GetIndexDiff
()[
Number
<
1
>
{}])(),
(
src_forward_steps
.
GetIndexDiff
()[
Number
<
2
>
{}])()
);
// make backward steps
const
auto
src_backward_steps
=
generate_tuple
(
...
...
@@ -213,7 +210,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_buf
.
template
Get
<
src_vector_t
>(
src_coord_
.
GetOffset
(),
is_src_valid
)};
// apply SrcElementwiseOperation on src_vector_container
debug_hexprinter
(
0xffffffff
,
src_coord_
.
GetOffset
());
//
debug_hexprinter(0xffffffff, src_coord_.GetOffset());
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
SrcData
src_v
;
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
9bd44685
...
...
@@ -205,6 +205,7 @@ struct WmmaGemm
}
// WMMA output supporting C = A * B
// Vector Write
// MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
template
<
typename
CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA
>
__host__
__device__
static
constexpr
auto
...
...
@@ -239,6 +240,40 @@ struct WmmaGemm
Sequence
<
5
>
{}));
}
// Per-Pixel write
template
<
typename
CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA
>
__host__
__device__
static
constexpr
auto
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
(
const
CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA
&
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
)
{
const
auto
MBlockxRepeat
=
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
.
GetLength
(
I0
);
const
auto
NBlockxRepeat
=
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
.
GetLength
(
I3
);
const
auto
MWave
=
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
.
GetLength
(
I1
);
const
auto
NWave
=
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
.
GetLength
(
I4
);
return
transform_tensor_descriptor
(
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
,
make_tuple
(
make_pass_through_transform
(
MBlockxRepeat
),
make_pass_through_transform
(
MWave
),
make_unmerge_transform
(
make_tuple
(
Number
<
wmma_instr
.
num_subgroups
>
{},
Number
<
wmma_instr
.
num_acc_vgprs_per_wave
>
{})),
make_pass_through_transform
(
NBlockxRepeat
),
make_pass_through_transform
(
NWave
),
make_pass_through_transform
(
Number
<
wmma_instr
.
num_thread_per_subgroups
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{}));
}
__device__
static
constexpr
index_t
GetRegSizePerWmma
()
{
return
wmma_instr
.
num_acc_vgprs_per_wave
;
...
...
include/ck/utility/common_header.hpp
View file @
9bd44685
...
...
@@ -73,10 +73,26 @@ constexpr auto type_name() {
return
name
;
}
// Accepet int, float, and Number<> as input
template
<
typename
T
>
__device__
void
debug_hexprinter
(
const
uint32_t
v_target
,
T
v_val
){
const
uint32_t
v_dbg
=
*
(
reinterpret_cast
<
uint32_t
*>
(
&
v_val
));
if
(
v_dbg
!=
v_target
)
printf
(
"@Thread: %d, Val: %08x != Target: %08x
\n
"
,
ck
::
get_thread_local_1d_id
(),
v_dbg
,
v_target
);
__host__
__device__
void
debug_hexprinter
(
const
uint32_t
v_target
,
const
T
v_val
,
const
char
*
info
){
if
constexpr
(
std
::
is_same_v
<
T
,
int
>
||
std
::
is_same_v
<
T
,
float
>
)
{
const
uint32_t
v_dbg
=
*
(
reinterpret_cast
<
const
uint32_t
*>
(
&
v_val
));
if
(
v_dbg
!=
v_target
)
printf
(
"%s@Thread: %d, Val: %08x != Target: %08x
\n
"
,
info
,
ck
::
get_thread_local_1d_id
(),
v_dbg
,
v_target
);
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
_Float16
>
)
{
const
uint16_t
v_dbg
=
*
(
reinterpret_cast
<
const
uint16_t
*>
(
&
v_val
));
if
(
v_dbg
!=
v_target
)
printf
(
"%s@Thread: %d, Val: %04x != Target: %08x
\n
"
,
info
,
ck
::
get_thread_local_1d_id
(),
v_dbg
,
v_target
);
}
else
{
const
uint32_t
v_dbg
=
*
(
reinterpret_cast
<
const
uint32_t
*>
(
&
(
v_val
.
value
)));
if
(
v_dbg
!=
v_target
)
printf
(
"%s@Thread: %d, Val: %08x != Target: %08x
\n
"
,
info
,
ck
::
get_thread_local_1d_id
(),
v_dbg
,
v_target
);
}
}
library/include/ck/library/utility/check_err.hpp
View file @
9bd44685
...
...
@@ -49,7 +49,7 @@ check_err(const std::vector<T>& out,
{
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
16384
)
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
out
[
i
]
<<
" != "
<<
ref
[
i
]
<<
std
::
endl
;
...
...
@@ -59,6 +59,7 @@ check_err(const std::vector<T>& out,
}
if
(
!
res
)
{
std
::
cerr
<<
"err count: "
<<
err_count
<<
std
::
endl
;
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
}
return
res
;
...
...
@@ -93,7 +94,7 @@ check_err(const std::vector<T>& out,
{
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
16384
)
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
...
@@ -103,6 +104,7 @@ check_err(const std::vector<T>& out,
}
if
(
!
res
)
{
std
::
cerr
<<
"err count: "
<<
err_count
<<
std
::
endl
;
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
}
return
res
;
...
...
@@ -136,7 +138,7 @@ check_err(span<const T> out,
{
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
16384
)
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
...
@@ -146,6 +148,7 @@ check_err(span<const T> out,
}
if
(
!
res
)
{
std
::
cerr
<<
"err count: "
<<
err_count
<<
std
::
endl
;
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
}
return
res
;
...
...
@@ -196,7 +199,7 @@ check_err(const std::vector<T>& out,
{
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
16384
)
{
std
::
cerr
<<
msg
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
...
@@ -206,6 +209,7 @@ check_err(const std::vector<T>& out,
}
if
(
!
res
)
{
std
::
cerr
<<
"err count: "
<<
err_count
<<
std
::
endl
;
std
::
cerr
<<
"max err: "
<<
max_err
<<
std
::
endl
;
}
return
res
;
...
...
library/include/ck/library/utility/fill.hpp
View file @
9bd44685
...
...
@@ -103,5 +103,23 @@ struct FillConstant
}
};
template
<
typename
T
>
struct
FillMNID
{
T
step_
{
0.1
};
int
k_num_
{
32
};
int
mn_num_
{
128
};
template
<
typename
ForwardIter
>
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
{
std
::
generate
(
first
,
last
,
[
=
,
iter
=
0
]()
mutable
{
auto
tmp
=
((
iter
/
k_num_
)
%
mn_num_
)
*
step_
;
iter
++
;
return
tmp
;
});
}
};
}
// namespace utils
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment