Commit 1a24ad25 authored by Chao Liu
Browse files

refactor

parent 7cd48ef1
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "host_gemm.hpp" #include "host_gemm.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl_cshuffle.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
...@@ -54,47 +53,51 @@ using ALayout = ck::tensor_layout::gemm::RowMajor; ...@@ -54,47 +53,51 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor; using CLayout = ck::tensor_layout::gemm::RowMajor;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle< using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
ADataType, // ADataType ALayout, // typename ALayout,
BDataType, // BDataType BLayout, // typename BLayout,
CDataType, // CDataType CLayout, // typename CLayout,
AccDataType, // AccDataType ADataType, // typename ADataType,
CShuffleDataType, // CShuffleDataType BDataType, // typename BDataType,
ALayout, // ALayout CDataType, // typename CDataType,
BLayout, // BLayout AccDataType, // typename GemmAccDataType,
CLayout, // CLayout CShuffleDataType, // typename CShuffleDataType,
PassThrough, // AElementwiseOperation PassThrough, // typename AElementwiseOperation,
PassThrough, // BElementwiseOperation PassThrough, // typename BElementwiseOperation,
RequantReluRequant, // CElementwiseOperation RequantReluRequant, // typename CElementwiseOperation,
256, // BlockSize GemmDefault, // GemmSpecialization GemmSpec,
256, // MPerBlock 1, // index_t NumGemmKPrefetchStage,
128, // NPerBlock 256, // index_t BlockSize,
64, // KPerBlock 256, // index_t MPerBlock,
16, // AK1 128, // index_t NPerBlock,
16, // BK1 64, // index_t KPerBlock,
32, // MPerXDL 16, // index_t AK1,
32, // NPerXDL 16, // index_t BK1,
4, // MXdlPerWave 32, // index_t MPerXDL,
2, // NXdlPerWave 32, // index_t NPerXDL,
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 4, // index_t MXdlPerWave,
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder 2, // index_t NXdlPerWave,
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
2, // ABlockTransferSrcVectorDim S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder,
16, // ABlockTransferSrcScalarPerVector S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder,
16, // ABlockTransferDstScalarPerVector_K1 2, // index_t ABlockTransferSrcVectorDim,
true, // ABlockLdsAddExtraM 16, // index_t ABlockTransferSrcScalarPerVector,
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 16, // index_t ABlockTransferDstScalarPerVector_AK1,
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder 1, // bool ABlockLdsExtraM,
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
2, // BBlockTransferSrcVectorDim S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder,
16, // BBlockTransferSrcScalarPerVector S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder,
16, // BBlockTransferDstScalarPerVector_K1 2, // index_t BBlockTransferSrcVectorDim,
true, // BBlockLdsAddExtraN 8, // index_t BBlockTransferSrcScalarPerVector,
1, // CShuffleMXdlPerWavePerShuffle 8, // index_t BBlockTransferDstScalarPerVector_BK1,
1, // CShuffleNXdlPerWavePerShuffle 1, // bool BBlockLdsExtraN,
S<1, 1, 64, 1, 1, 4>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl 1, // index_t CShuffleMXdlPerWavePerShuffle,
16>; // CBlockTransferScalarPerVector_NWaveNPerXdl 1, // index_t CShuffleNXdlPerWavePerShuffle,
S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
......
#ifndef REFERENCE_GEMM_HPP #pragma once
#define REFERENCE_GEMM_HPP
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "device_base.hpp" #include "device_base.hpp"
...@@ -129,4 +127,3 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -129,4 +127,3 @@ struct ReferenceGemm : public device::BaseOperator
} // namespace host } // namespace host
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
#ifndef CHECK_ERR_HPP #pragma once
#define CHECK_ERR_HPP
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <cstdlib> #include <cstdlib>
...@@ -194,5 +192,3 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) ...@@ -194,5 +192,3 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " ")); std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
return os; return os;
} }
#endif
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment