"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "0537226c75d542225553ca2a7ae3b88fb75b0d7a"
Commit 66206c23 authored by Chao Liu's avatar Chao Liu
Browse files

rename

parent ad8c418d
#ifndef DEVICE_GEMM_SHUFFLE_XDL_HPP
#define DEVICE_GEMM_SHUFFLE_XDL_HPP
#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_HPP
#define DEVICE_GEMM_XDL_C_SHUFFLE_HPP
#include <iostream>
#include <sstream>
......@@ -55,7 +55,7 @@ template <
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceGemmShuffleXdl
struct DeviceGemmXdl_C_Shuffle
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
......@@ -207,9 +207,11 @@ struct DeviceGemmShuffleXdl
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
a_grid_desc_k0_m_k1_ = DeviceGemmShuffleXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmShuffleXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmShuffleXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
a_grid_desc_k0_m_k1_ =
DeviceGemmXdl_C_Shuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ =
DeviceGemmXdl_C_Shuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdl_C_Shuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
......@@ -244,7 +246,7 @@ struct DeviceGemmShuffleXdl
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceGemmShuffleXdl::Argument;
using Argument = DeviceGemmXdl_C_Shuffle::Argument;
float Run(const Argument& arg, int nrepeat = 1)
{
......@@ -285,8 +287,8 @@ struct DeviceGemmShuffleXdl
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmShuffleXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmShuffleXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl_C_Shuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl_C_Shuffle::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
......@@ -319,8 +321,8 @@ struct DeviceGemmShuffleXdl
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmShuffleXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmShuffleXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl_C_Shuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl_C_Shuffle::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
......
......@@ -12,7 +12,7 @@
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_base.hpp"
#include "device_gemm_shuffle_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp"
template <ck::index_t... Is>
......@@ -32,44 +32,44 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
using DeviceGemmShuffleInstance =
ck::tensor_operation::device::DeviceGemmShuffleXdl<ADataType, // ADataType
BDataType, // BDataType
CDataType, // CDataType
AccDataType, // AccDataType
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle<
ADataType, // ADataType
BDataType, // BDataType
CDataType, // CDataType
AccDataType, // AccDataType
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
template <typename AType,
......@@ -192,7 +192,7 @@ int main(int argc, char* argv[])
c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());
// do GEMM
auto gemm = DeviceGemmShuffleInstance{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment