Commit 66206c23 authored by Chao Liu
Browse files

rename

parent ad8c418d
#ifndef DEVICE_GEMM_SHUFFLE_XDL_HPP #ifndef DEVICE_GEMM_XDL_C_SHUFFLE_HPP
#define DEVICE_GEMM_SHUFFLE_XDL_HPP #define DEVICE_GEMM_XDL_C_SHUFFLE_HPP
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
...@@ -55,7 +55,7 @@ template < ...@@ -55,7 +55,7 @@ template <
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl> index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceGemmShuffleXdl struct DeviceGemmXdl_C_Shuffle
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -207,9 +207,11 @@ struct DeviceGemmShuffleXdl ...@@ -207,9 +207,11 @@ struct DeviceGemmShuffleXdl
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
a_grid_desc_k0_m_k1_ = DeviceGemmShuffleXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); a_grid_desc_k0_m_k1_ =
b_grid_desc_k0_n_k1_ = DeviceGemmShuffleXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); DeviceGemmXdl_C_Shuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
c_grid_desc_m_n_ = DeviceGemmShuffleXdl::MakeCGridDescriptor_M_N(M, N, StrideC); b_grid_desc_k0_n_k1_ =
DeviceGemmXdl_C_Shuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdl_C_Shuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity( if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
...@@ -244,7 +246,7 @@ struct DeviceGemmShuffleXdl ...@@ -244,7 +246,7 @@ struct DeviceGemmShuffleXdl
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
using Argument = DeviceGemmShuffleXdl::Argument; using Argument = DeviceGemmXdl_C_Shuffle::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, int nrepeat = 1)
{ {
...@@ -285,8 +287,8 @@ struct DeviceGemmShuffleXdl ...@@ -285,8 +287,8 @@ struct DeviceGemmShuffleXdl
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceGemmShuffleXdl::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmXdl_C_Shuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmShuffleXdl::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmXdl_C_Shuffle::BGridDesc_K0_N_K1>,
remove_reference_t< remove_reference_t<
typename GridwiseGemm:: typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
...@@ -319,8 +321,8 @@ struct DeviceGemmShuffleXdl ...@@ -319,8 +321,8 @@ struct DeviceGemmShuffleXdl
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceGemmShuffleXdl::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmXdl_C_Shuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmShuffleXdl::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmXdl_C_Shuffle::BGridDesc_K0_N_K1>,
remove_reference_t< remove_reference_t<
typename GridwiseGemm:: typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "host_gemm.hpp" #include "host_gemm.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_gemm_shuffle_xdl.hpp" #include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
...@@ -32,8 +32,8 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough; ...@@ -32,8 +32,8 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
// clang-format off // clang-format off
using DeviceGemmShuffleInstance = using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle<
ck::tensor_operation::device::DeviceGemmShuffleXdl<ADataType, // ADataType ADataType, // ADataType
BDataType, // BDataType BDataType, // BDataType
CDataType, // CDataType CDataType, // CDataType
AccDataType, // AccDataType AccDataType, // AccDataType
...@@ -192,7 +192,7 @@ int main(int argc, char* argv[]) ...@@ -192,7 +192,7 @@ int main(int argc, char* argv[])
c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());
// do GEMM // do GEMM
auto gemm = DeviceGemmShuffleInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()), auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment