"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "41bcd7d09f37be9cc3e28f7172b697f1ec677082"
Commit 66206c23 authored by Chao Liu's avatar Chao Liu
Browse files

rename

parent ad8c418d
#ifndef DEVICE_GEMM_SHUFFLE_XDL_HPP #ifndef DEVICE_GEMM_XDL_C_SHUFFLE_HPP
#define DEVICE_GEMM_SHUFFLE_XDL_HPP #define DEVICE_GEMM_XDL_C_SHUFFLE_HPP
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
...@@ -55,7 +55,7 @@ template < ...@@ -55,7 +55,7 @@ template <
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl> index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceGemmShuffleXdl struct DeviceGemmXdl_C_Shuffle
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -207,9 +207,11 @@ struct DeviceGemmShuffleXdl ...@@ -207,9 +207,11 @@ struct DeviceGemmShuffleXdl
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
a_grid_desc_k0_m_k1_ = DeviceGemmShuffleXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); a_grid_desc_k0_m_k1_ =
b_grid_desc_k0_n_k1_ = DeviceGemmShuffleXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); DeviceGemmXdl_C_Shuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
c_grid_desc_m_n_ = DeviceGemmShuffleXdl::MakeCGridDescriptor_M_N(M, N, StrideC); b_grid_desc_k0_n_k1_ =
DeviceGemmXdl_C_Shuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdl_C_Shuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity( if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
...@@ -244,7 +246,7 @@ struct DeviceGemmShuffleXdl ...@@ -244,7 +246,7 @@ struct DeviceGemmShuffleXdl
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
using Argument = DeviceGemmShuffleXdl::Argument; using Argument = DeviceGemmXdl_C_Shuffle::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, int nrepeat = 1)
{ {
...@@ -285,8 +287,8 @@ struct DeviceGemmShuffleXdl ...@@ -285,8 +287,8 @@ struct DeviceGemmShuffleXdl
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceGemmShuffleXdl::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmXdl_C_Shuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmShuffleXdl::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmXdl_C_Shuffle::BGridDesc_K0_N_K1>,
remove_reference_t< remove_reference_t<
typename GridwiseGemm:: typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
...@@ -319,8 +321,8 @@ struct DeviceGemmShuffleXdl ...@@ -319,8 +321,8 @@ struct DeviceGemmShuffleXdl
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceGemmShuffleXdl::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmXdl_C_Shuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmShuffleXdl::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmXdl_C_Shuffle::BGridDesc_K0_N_K1>,
remove_reference_t< remove_reference_t<
typename GridwiseGemm:: typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "host_gemm.hpp" #include "host_gemm.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_gemm_shuffle_xdl.hpp" #include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
...@@ -32,44 +32,44 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough; ...@@ -32,44 +32,44 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
// clang-format off // clang-format off
using DeviceGemmShuffleInstance = using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle<
ck::tensor_operation::device::DeviceGemmShuffleXdl<ADataType, // ADataType ADataType, // ADataType
BDataType, // BDataType BDataType, // BDataType
CDataType, // CDataType CDataType, // CDataType
AccDataType, // AccDataType AccDataType, // AccDataType
ALayout, // ALayout ALayout, // ALayout
BLayout, // BLayout BLayout, // BLayout
CLayout, // CLayout CLayout, // CLayout
AElementOp, // AElementwiseOperation AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation CElementOp, // CElementwiseOperation
256, // BlockSize 256, // BlockSize
256, // MPerBlock 256, // MPerBlock
128, // NPerBlock 128, // NPerBlock
4, // K0PerBlock 4, // K0PerBlock
8, // K1 8, // K1
32, // MPerXDL 32, // MPerXDL
32, // NPerXDL 32, // NPerXDL
4, // MXdlPerWave 4, // MXdlPerWave
2, // NXdlPerWave 2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim 2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1 8, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM true, // ABlockLdsAddExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim 2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_K1 8, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on // clang-format on
template <typename AType, template <typename AType,
...@@ -192,7 +192,7 @@ int main(int argc, char* argv[]) ...@@ -192,7 +192,7 @@ int main(int argc, char* argv[])
c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());
// do GEMM // do GEMM
auto gemm = DeviceGemmShuffleInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()), auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment