Commit 1a24ad25 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 7cd48ef1
......@@ -13,8 +13,7 @@
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
......@@ -54,47 +53,51 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle<
ADataType, // ADataType
BDataType, // BDataType
CDataType, // CDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
PassThrough, // AElementwiseOperation
PassThrough, // BElementwiseOperation
RequantReluRequant, // CElementwiseOperation
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
64, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 1, 64, 1, 1, 4>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
16>; // CBlockTransferScalarPerVector_NWaveNPerXdl
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
ALayout, // typename ALayout,
BLayout, // typename BLayout,
CLayout, // typename CLayout,
ADataType, // typename ADataType,
BDataType, // typename BDataType,
CDataType, // typename CDataType,
AccDataType, // typename GemmAccDataType,
CShuffleDataType, // typename CShuffleDataType,
PassThrough, // typename AElementwiseOperation,
PassThrough, // typename BElementwiseOperation,
RequantReluRequant, // typename CElementwiseOperation,
GemmDefault, // GemmSpecialization GemmSpec,
1, // index_t NumGemmKPrefetchStage,
256, // index_t BlockSize,
256, // index_t MPerBlock,
128, // index_t NPerBlock,
64, // index_t KPerBlock,
16, // index_t AK1,
16, // index_t BK1,
32, // index_t MPerXDL,
32, // index_t NPerXDL,
4, // index_t MXdlPerWave,
2, // index_t NXdlPerWave,
S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder,
S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder,
2, // index_t ABlockTransferSrcVectorDim,
16, // index_t ABlockTransferSrcScalarPerVector,
16, // index_t ABlockTransferDstScalarPerVector_AK1,
1, // bool ABlockLdsExtraM,
S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder,
S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder,
2, // index_t BBlockTransferSrcVectorDim,
8, // index_t BBlockTransferSrcScalarPerVector,
8, // index_t BBlockTransferDstScalarPerVector_BK1,
1, // bool BBlockLdsExtraN,
1, // index_t CShuffleMXdlPerWavePerShuffle,
1, // index_t CShuffleNXdlPerWavePerShuffle,
S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
......
#ifndef REFERENCE_GEMM_HPP
#define REFERENCE_GEMM_HPP
#pragma once
#include <iostream>
#include <sstream>
#include "device_base.hpp"
......@@ -129,4 +127,3 @@ struct ReferenceGemm : public device::BaseOperator
} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef CHECK_ERR_HPP
#define CHECK_ERR_HPP
#pragma once
#include <algorithm>
#include <cmath>
#include <cstdlib>
......@@ -194,5 +192,3 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
return os;
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment