Commit 9bd44685 authored by aska-0096

Correctness OK, waiting for optimization

parent 73959956
@@ -22,14 +22,21 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma
+// using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmWmma
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MWMMA|NWMMA| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 6, 1>;
+// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 6, 1>;
// clang-format on
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MWmma|NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 8>, 8>;
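A quick consistency check on the tile parameters above (a sketch; wave width 32 and WmmaK = 16 are assumptions based on the gfx11 16x16x16 WMMA instructions, not values stated in this hunk):

// Hedged sanity check of the DeviceGemmInstance tile above; MWaves = 2 and NWaves = 4
// are derived from MPerBlock/(MRepeat*MPerWmma) and NPerBlock/(NRepeat*NPerWmma).
static_assert(128 == 4 /*MRepeat*/ * 2 /*MWaves*/ * 16 /*MPerWmma*/, "MPerBlock tiling");
static_assert(128 == 2 /*NRepeat*/ * 4 /*NWaves*/ * 16 /*NPerWmma*/, "NPerBlock tiling");
static_assert(256 == (2 * 4) /*waves*/ * 32 /*lanes per wave, assuming wave32*/, "BlockSize");
static_assert(4 /*K0PerBlock*/ * 8 /*K1*/ == 2 * 16 /*WmmaK*/, "KPerBlock covers two WMMA K-steps");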
using ReferenceGemmInstance = ck::tensor_operation::host::
    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
...
@@ -32,8 +32,11 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
    {
    case 0: break;
    case 1:
-        ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k.begin(), a_m_k.end());
-        ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n.begin(), b_k_n.end());
+        // CONFIRMED
+        // ck::utils::FillMNID<ADataType>{}(a_m_k.begin(), a_m_k.end());
+        // ck::utils::FillMNID<BDataType>{}(b_k_n.begin(), b_k_n.end());
+        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k.begin(), a_m_k.end());
+        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n.begin(), b_k_n.end());
        break;
    default:
        ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k.begin(), a_m_k.end());
...
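The integer-valued fill above is what makes a strict output comparison meaningful: whole numbers in [-5, 5] keep fp16 products and small accumulations exactly representable. A minimal sketch of the idea (illustrative code and sizes, not CK's implementation):

#include <random>
#include <vector>

int main()
{
    std::vector<float> a_m_k(128 * 32); // illustrative M x K
    std::mt19937 gen(0);
    std::uniform_int_distribution<int> dist(-5, 5);
    for(auto& x : a_m_k)
        x = static_cast<float>(dist(gen)); // whole numbers, exactly representable in fp16
    return 0;
}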
@@ -137,7 +137,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
        static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0,
                      "wrong!");
    }
-    // Thread level, register descriptor.
+    // Thread level, register descriptor. Vector-write
    __host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
    {
        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
@@ -168,6 +168,51 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
        return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
    }
    // Thread level, register descriptor. Per-pixel write
__host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave |NThreadPerSubGroup
make_tuple(Number<MRepeat>{}, I1, MSubGroup, MAccVgprs, Number<NRepeat>{}, I1, NThreadPerSubGroup));
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
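To read the unmerge above with concrete numbers (illustrative, consistent with the 128x128 tile in the example but not stated in this hunk): with MWaves = 2 and MPerWMMA = 16, a row index m of C decomposes as

$$ m = m_{\mathrm{blkxrepeat}} \cdot (M_{\mathrm{waves}} M_{\mathrm{perWMMA}}) + m_{\mathrm{wave}} \cdot M_{\mathrm{perWMMA}} + m_{\mathrm{lane}}, \qquad M = 128 \Rightarrow M_{\mathrm{blkxrepeat}} = \tfrac{128}{2 \cdot 16} = 4, $$

and N splits the same way with NWaves and NPerWMMA; the wmma_gemm call then further unmerges MPerWMMA into (MSubGroup, MAccVgprs) for the per-pixel layout.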
// Provide dimension size
__host__ __device__ static constexpr auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
    __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
    {
        return transform_tensor_descriptor(
@@ -205,8 +250,28 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            b_thread_desc_.GetElementSpaceSize());
-        constexpr auto RepeatDiff = MRepeat - NRepeat;
+        // constexpr auto RepeatDiff = MRepeat - NRepeat;
// debug_hexprinter(0xffffffff, a_thread_buf[Number<a_thread_desc_.CalculateOffset( make_tuple(0, 0, 0, 0,0))>{}], "Avalue ");
/* First local prefetch, move out of blockwise operation.
static_for<0, NRepeat, 1>{}([&](auto iN){
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, Number<iN>{}, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, Number<iN>{}, I0, I0, I0),
b_thread_buf);
});
static_for<0, MRepeat, 1>{}([&](auto iN){
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, Number<iN>{}, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, Number<iN>{}, I0, I0, I0),
b_thread_buf);
});
*/
/*
        static_for<0, KPerBlock, WmmaK>{}([&](auto iWmmaK){
            // Cut to Repeat Rectangle to Square, assume MRepeat > NRepeat
            static_for<0, RepeatDiff, 1>{}([&](auto iCut){
@@ -297,16 +362,77 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
                              b_thread_buf);
            });
        });
*/
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k*WmmaK/A_K1>{}, m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k*WmmaK/B_K1>{}, n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
b_thread_buf);
vector_type<FloatAB, WmmaK> a_thread_vec;
vector_type<FloatAB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(i/A_K1, 0, 0, 0, i%A_K1))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(i/B_K1, 0, 0, 0, i%B_K1))>{}];
});
using wmma_input_type =
typename vector_type<FloatAB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// static_for<0, 16, 1>{}([&](auto i){
// char info[4];
// info[0] = 'A';
// info[1] = i/10 + '0';
// info[2] = i%10 + '0';
// info[3] = '\0';
// debug_hexprinter(0xffffffff, a_thread_buf[Number<i>{}], info);
// });
// static_for<0, 16, 1>{}([&](auto i){
// char info[4];
// info[0] = 'B';
// info[1] = i/10 + '0';
// info[2] = i%10 + '0';
// info[3] = '\0';
// debug_hexprinter(0xffffffff, b_thread_buf[Number<i>{}], info);
// });
    }
protected:
    // A[M0, M1, M2, K0 = WmmaK]
    static constexpr auto a_thread_desc_ =
-        make_naive_tensor_descriptor_packed(make_tuple(Number<WmmaK/A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}));
+        make_naive_tensor_descriptor_packed(make_tuple(Number<WmmaK/A_K1>{}, I1, I1, I1, Number<A_K1>{}));
    // B[N0, N1, N2, K0 = WmmaK]
    static constexpr auto b_thread_desc_ =
-        make_naive_tensor_descriptor_packed(make_tuple(Number<WmmaK/B_K1>{}, Number<MRepeat>{}, I1, I1, Number<B_K1>{}));
+        make_naive_tensor_descriptor_packed(make_tuple(Number<WmmaK/B_K1>{}, I1, I1, I1, Number<B_K1>{}));
    // C[M, N, NumRegWMMA]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
...
@@ -20,13 +20,14 @@ namespace ck {
namespace tensor_operation {
namespace device {
-template <typename ADataType,
+template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
-          typename ALayout,
-          typename BLayout,
-          typename CLayout,
+          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
@@ -54,20 +55,22 @@ template <typename ADataType,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BBlockLdsAddExtraN,
-          ck::index_t CThreadTransferSrcDstVectorDim,
-          ck::index_t CThreadTransferDstScalarPerVector,
+          index_t CShuffleMRepeatPerShuffle,
+          index_t CShuffleNRepeatPerShuffle,
+          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          ck::index_t NumPrefetch = 1,
          ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
          ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
-struct DeviceGemmWmma : public DeviceGemm<ALayout,
+struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                                           BLayout,
                                           CLayout,
                                           ADataType,
                                           BDataType,
                                           CDataType,
                                           AElementwiseOperation,
                                           BElementwiseOperation,
                                           CElementwiseOperation>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
@@ -200,6 +203,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
        BlockSize,
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
+        CShuffleDataType,
        CDataType,
        InMemoryDataOperationEnum::Set,
        AGridDesc_K0_M_K1,
@@ -232,9 +236,10 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsAddExtraN,
-        Sequence<0, 1, 2, 3, 4, 5, 6>, // CThreadTransferSrcDstAccessOrder,
-        CThreadTransferSrcDstVectorDim,
-        CThreadTransferDstScalarPerVector,
+        CShuffleMRepeatPerShuffle,
+        CShuffleNRepeatPerShuffle,
+        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+        CShuffleBlockTransferScalarPerVector_NPerBlock,
        NumPrefetch,
        LoopSched,
        PipelineVer>;
@@ -262,7 +267,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
              a_grid_desc_k0_m_k1_{},
              b_grid_desc_k0_n_k1_{},
              c_grid_desc_m_n_{},
-              c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_{},
+              c_grid_desc_mblock_mperblock_nblock_nperblock{},
              block_2_ctile_map_{},
              M01_{M01},
              N01_{N01},
@@ -270,9 +275,9 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
-            a_grid_desc_k0_m_k1_ = DeviceGemmWmma::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
-            b_grid_desc_k0_n_k1_ = DeviceGemmWmma::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
-            c_grid_desc_m_n_ = DeviceGemmWmma::MakeCGridDescriptor_M_N(M, N, StrideC);
+            a_grid_desc_k0_m_k1_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
+            b_grid_desc_k0_n_k1_ = DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
+            c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
            block_2_ctile_map_ =
                GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
@@ -282,8 +287,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
                c_grid_desc_m_n_,
                block_2_ctile_map_))
            {
-                c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_ =
-                    GridwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_m_n_);
+                c_grid_desc_mblock_mperblock_nblock_nperblock =
+                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
            }
        }
@@ -294,8 +299,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
        CGridDesc_M_N c_grid_desc_m_n_;
-        typename GridwiseGemm::CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
-            c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_;
+        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_desc_mblock_mperblock_nblock_nperblock;
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
        index_t M01_;
        index_t N01_;
@@ -307,7 +312,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
    // Invoker
    struct Invoker : public BaseInvoker
    {
-        using Argument = DeviceGemmWmma::Argument;
+        using Argument = DeviceGemmWmma_CShuffle::Argument;
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
@@ -350,9 +355,9 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
                    GridwiseGemm,
                    ADataType, // TODO: distinguish A/B datatype
                    CDataType,
-                    remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>,
-                    remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>,
-                    remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs>,
+                    remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
+                    remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CElementwiseOperation,
@@ -369,7 +374,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
                    arg.p_c_grid_,
                    arg.a_grid_desc_k0_m_k1_,
                    arg.b_grid_desc_k0_n_k1_,
-                    arg.c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_,
+                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.c_element_op_,
@@ -381,9 +386,9 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
                    GridwiseGemm,
                    ADataType, // TODO: distinguish A/B datatype
                    CDataType,
-                    remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>,
-                    remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>,
-                    remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs>,
+                    remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
+                    remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CElementwiseOperation,
@@ -400,7 +405,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
                    arg.p_c_grid_,
                    arg.a_grid_desc_k0_m_k1_,
                    arg.b_grid_desc_k0_n_k1_,
-                    arg.c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_,
+                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.c_element_op_,
@@ -428,8 +433,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
        {
            if(ck::get_device_name() == "gfx1100")
            {
-                if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
-                               is_same_v<AccDataType, int32_t>))
+                if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
                {
                    return false;
                }
@@ -530,7 +534,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
            {PipelineVersion::v2, "v2"}};
        // clang-format off
-        str << "DeviceGemmWmma"
+        str << "DeviceGemmWmma_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
...
@@ -119,7 +119,29 @@ struct ThreadwiseTensorSliceTransfer_v1r3
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;
// printf("SpaceFillingCurve access_lengths = (%d, %d, %d, %d, %d, %d, %d)\n", SpaceFillingCurve::access_lengths[Number<0>{}].value,
// SpaceFillingCurve::access_lengths[Number<1>{}].value,
// SpaceFillingCurve::access_lengths[Number<2>{}].value,
// SpaceFillingCurve::access_lengths[Number<3>{}].value,
// SpaceFillingCurve::access_lengths[Number<4>{}].value,
// SpaceFillingCurve::access_lengths[Number<5>{}].value,
// SpaceFillingCurve::access_lengths[Number<6>{}].value);
//
// // printf("SpaceFillingCurve dim_access_order = (%d, %d, %d, %d, %d, %d, %d)\n", SpaceFillingCurve::dim_access_order[Number<0>{}].value,
// SpaceFillingCurve::dim_access_order[Number<1>{}].value,
// SpaceFillingCurve::dim_access_order[Number<2>{}].value,
// SpaceFillingCurve::dim_access_order[Number<3>{}].value,
// SpaceFillingCurve::dim_access_order[Number<4>{}].value,
// SpaceFillingCurve::dim_access_order[Number<5>{}].value,
// SpaceFillingCurve::dim_access_order[Number<6>{}].value);
//
// // // printf("SpaceFillingCurve ordered_access_lengths = (%d, %d, %d, %d, %d, %d, %d)\n", SpaceFillingCurve::ordered_access_lengths[Number<0>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<1>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<2>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<3>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<4>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<5>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<6>{}].value);
        // TODO: Use SpaceFillingCurve::ScalarsPerAccess instead of DstScalarPerVector?
        static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
                      "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
@@ -136,7 +158,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
        static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
            constexpr index_t src_offset = src_desc.CalculateOffset(
                src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
// debug_hexprinter(0xffffffff, src_offset, "src_coord_iteration");
            SrcData v;
            // apply element-wise operation
@@ -154,11 +176,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3
            dst_coord_.GetOffset(),
            is_dst_valid,
            dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
// debug_hexprinter(0xffffffff, dst_coord_.GetOffset(), "dst_coord_iteration");
            if constexpr(idx_1d.value != num_access - 1)
            {
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
// printf("move forward = (%d, %d, %d, %d, %d, %d, %d)\n", forward_step[Number<0>{}], forward_step[Number<1>{}], forward_step[Number<2>{}], forward_step[Number<3>{}], forward_step[Number<4>{}], forward_step[Number<5>{}], forward_step[Number<6>{}]);
                move_tensor_coordinate(
                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
            }
...
@@ -96,7 +96,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
          src_element_op_(src_element_op),
          dst_element_op_(dst_element_op)
    {
-        printf("global desc: %s\n", __PRETTY_FUNCTION__);
+        // printf("global desc: %s\n", __PRETTY_FUNCTION__);
    }
    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -128,12 +128,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
-        printf("src_access_lengths: %d, %d, %d\n", (src_access_lengths[Number<0>{}])(), src_access_lengths[Number<1>{}](), src_access_lengths[Number<2>{}]());
+        // printf("src_access_lengths: %d, %d, %d\n", (src_access_lengths[Number<0>{}])(), src_access_lengths[Number<1>{}](), src_access_lengths[Number<2>{}]());
        constexpr auto src_dim_access_order = SrcDimAccessOrder{};
        constexpr auto ordered_src_access_lengths =
            container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
-        printf("ordered_src_access_lengths: %d, %d, %d\n", (ordered_src_access_lengths[Number<0>{}])(), ordered_src_access_lengths[Number<1>{}](), ordered_src_access_lengths[Number<2>{}]());
+        // printf("ordered_src_access_lengths: %d, %d, %d\n", (ordered_src_access_lengths[Number<0>{}])(), ordered_src_access_lengths[Number<1>{}](), ordered_src_access_lengths[Number<2>{}]());
        // make forward steps
        const auto src_forward_steps = generate_tuple(
@@ -147,9 +147,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                return make_tensor_coordinate_step(src_desc, forward_step_idx);
            },
            Number<nDim>{});
-        printf("src_forward_steps: %d, %d, %d\n", (src_forward_steps.GetIndexDiff()[Number<0>{}])(),
-               (src_forward_steps.GetIndexDiff()[Number<1>{}])(),
-               (src_forward_steps.GetIndexDiff()[Number<2>{}])() );
        // make backward steps
        const auto src_backward_steps = generate_tuple(
@@ -213,7 +210,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
            src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
        // apply SrcElementwiseOperation on src_vector_container
-        debug_hexprinter(0xffffffff, src_coord_.GetOffset());
+        // debug_hexprinter(0xffffffff, src_coord_.GetOffset());
        static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
            SrcData src_v;
...
@@ -205,6 +205,7 @@ struct WmmaGemm
    }
    // WMMA output supporting C = A * B
// Vector Write
    // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
    template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
    __host__ __device__ static constexpr auto
@@ -239,6 +240,40 @@ struct WmmaGemm
                       Sequence<5>{}));
    }
// Per-Pixel write
template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
__host__ __device__ static constexpr auto
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
{
const auto MBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
const auto NBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
const auto MWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
const auto NWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
return transform_tensor_descriptor(
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
make_tuple(make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(MWave),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
Number<wmma_instr.num_acc_vgprs_per_wave>{})),
make_pass_through_transform(NBlockxRepeat),
make_pass_through_transform(NWave),
make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{})),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2, 3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{}));
}
    __device__ static constexpr index_t GetRegSizePerWmma()
    {
        return wmma_instr.num_acc_vgprs_per_wave;
...
@@ -73,10 +73,26 @@ constexpr auto type_name() {
    return name;
}
+// Accept int, float, and Number<> as input
template <typename T>
-__device__
+__host__ __device__
-void debug_hexprinter(const uint32_t v_target, T v_val){
-    const uint32_t v_dbg = *(reinterpret_cast<uint32_t*>(&v_val));
-    if(v_dbg != v_target)
-        printf("@Thread: %d, Val: %08x != Target: %08x\n", ck::get_thread_local_1d_id(), v_dbg, v_target);
+void debug_hexprinter(const uint32_t v_target, const T v_val, const char* info){
+    if constexpr(std::is_same_v<T, int> || std::is_same_v<T, float>)
+    {
+        const uint32_t v_dbg = *(reinterpret_cast<const uint32_t*>(&v_val));
+        if(v_dbg != v_target)
+            printf("%s@Thread: %d, Val: %08x != Target: %08x\n", info, ck::get_thread_local_1d_id(), v_dbg, v_target);
+    }
+    else if constexpr(std::is_same_v<T, _Float16>)
+    {
+        const uint16_t v_dbg = *(reinterpret_cast<const uint16_t*>(&v_val));
+        if(v_dbg != v_target)
+            printf("%s@Thread: %d, Val: %04x != Target: %08x\n", info, ck::get_thread_local_1d_id(), v_dbg, v_target);
+    }
+    else
+    {
+        const uint32_t v_dbg = *(reinterpret_cast<const uint32_t*>(&(v_val.value)));
+        if(v_dbg != v_target)
+            printf("%s@Thread: %d, Val: %08x != Target: %08x\n", info, ck::get_thread_local_1d_id(), v_dbg, v_target);
+    }
}
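For reference, a hedged usage sketch of the extended printer (the bit patterns are standard IEEE-754 constants chosen for illustration, not values from this commit):

// Illustrative only: flag lanes whose fp32 accumulator or fp16 input is not exactly 1.0.
__device__ void check_lane(float acc, _Float16 a)
{
    debug_hexprinter(0x3f800000u, acc, "Cacc "); // 0x3f800000 is 1.0f in fp32
    debug_hexprinter(0x00003c00u, a, "Aval ");   // 0x3c00 is 1.0 in fp16; low 16 bits compared
}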
@@ -49,7 +49,7 @@ check_err(const std::vector<T>& out,
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
-            if(err_count < 5)
+            if(err_count < 16384)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] << std::endl;
@@ -59,6 +59,7 @@ check_err(const std::vector<T>& out,
    }
    if(!res)
    {
+        std::cerr << "err count: " << err_count << std::endl;
        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
    }
    return res;
@@ -93,7 +94,7 @@ check_err(const std::vector<T>& out,
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
-            if(err_count < 5)
+            if(err_count < 16384)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
@@ -103,6 +104,7 @@ check_err(const std::vector<T>& out,
    }
    if(!res)
    {
+        std::cerr << "err count: " << err_count << std::endl;
        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
    }
    return res;
@@ -136,7 +138,7 @@ check_err(span<const T> out,
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
-            if(err_count < 5)
+            if(err_count < 16384)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
@@ -146,6 +148,7 @@ check_err(span<const T> out,
    }
    if(!res)
    {
+        std::cerr << "err count: " << err_count << std::endl;
        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
    }
    return res;
@@ -196,7 +199,7 @@ check_err(const std::vector<T>& out,
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
-            if(err_count < 5)
+            if(err_count < 16384)
            {
                std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
                          << std::endl;
@@ -206,6 +209,7 @@ check_err(const std::vector<T>& out,
    }
    if(!res)
    {
+        std::cerr << "err count: " << err_count << std::endl;
        std::cerr << "max err: " << max_err << std::endl;
    }
    return res;
...
@@ -103,5 +103,23 @@ struct FillConstant
    }
};
template <typename T>
struct FillMNID
{
T step_{0.1};
int k_num_{32};
int mn_num_{128};
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::generate(first, last, [=, iter = 0]() mutable {
auto tmp = ((iter/k_num_) % mn_num_ ) * step_;
iter ++;
return tmp;
});
}
};
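A hedged sketch of what the new filler produces with its defaults (step_ = 0.1, k_num_ = 32, mn_num_ = 128): each run of k_num_ consecutive elements shares one value, the value advances by step_ per run and wraps after mn_num_ runs, so a row-major M x K matrix with K == k_num_ gets its row index encoded in the data:

#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    // Equivalent of ck::utils::FillMNID<float>{} with default members,
    // applied to a 64 x 32 row-major matrix (illustrative sizes).
    std::vector<float> a(64 * 32);
    int iter = 0;
    std::generate(a.begin(), a.end(), [&] { return ((iter++ / 32) % 128) * 0.1f; });
    std::printf("%g %g %g\n", a[0], a[32], a[64]); // prints 0 0.1 0.2, one value per row
    return 0;
}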
} // namespace utils
} // namespace ck