Commit 7d85d04a authored by ltqin's avatar ltqin
Browse files

change file name

parent 8dd89366
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
add_example_executable(example_gemm_xdl_skip_lds_fp16 gemm_xdl_skip_lds_fp16.cpp) add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
\ No newline at end of file \ No newline at end of file
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_gemm_xdl_skip_lds.hpp" #include "device_gemm_xdl_skip_b_lds.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp" #include "device_gemm_xdl_c_shuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp" #include "device_gemm_xdl_cshuffle.hpp"
...@@ -43,7 +43,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -43,7 +43,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off // clang-format off
#if USING_SKIP_LDS #if USING_SKIP_LDS
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipBLds
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BThreadTransfer| CThreadTransfer| CThreadTransfer| //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BThreadTransfer| CThreadTransfer| CThreadTransfer|
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| SrcScalar| SrcDstVectorDim| DstScalar| //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| SrcScalar| SrcDstVectorDim| DstScalar|
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | PerVector| //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | PerVector|
...@@ -55,7 +55,9 @@ using BDataType = ck::half_t; ...@@ -55,7 +55,9 @@ using BDataType = ck::half_t;
using CDataType = ck::half_t; using CDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
#else #else
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>; // < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
// < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 32, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 64, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
using ADataType = float; using ADataType = float;
using BDataType = float; using BDataType = float;
using CDataType = float; using CDataType = float;
...@@ -68,7 +70,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl ...@@ -68,7 +70,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, 2>; // < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, 2>;
// < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 32, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, 2>;
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, 2>;
using ADataType = float; using ADataType = float;
using BDataType = float; using BDataType = float;
using CDataType = float; using CDataType = float;
...@@ -104,20 +108,20 @@ int main(int argc, char* argv[]) ...@@ -104,20 +108,20 @@ int main(int argc, char* argv[])
// GEMM shape // GEMM shape
#if 1 #if 1
ck::index_t M = 16; ck::index_t M = 64;
ck::index_t N = 1152; ck::index_t N = 4096;
ck::index_t K = 5120; ck::index_t K = 4096;
ck::index_t StrideA = 5120; ck::index_t StrideA = 4096;
ck::index_t StrideB = 5120; ck::index_t StrideB = 4096;
ck::index_t StrideC = 1152; ck::index_t StrideC = 4096;
#else #else
ck::index_t M = 16; ck::index_t M = 16;
ck::index_t N = 16; ck::index_t N = 16;
ck::index_t K = 24; ck::index_t K = 8;
ck::index_t StrideA = 24; ck::index_t StrideA = 8;
ck::index_t StrideB = 24; ck::index_t StrideB = 8;
ck::index_t StrideC = 16; ck::index_t StrideC = 16;
#endif #endif
......
#ifndef DEVICE_GEMM_XDL_SKIP_LDS_HPP #ifndef DEVICE_GEMM_XDL_SKIP_B_LDS_HPP
#define DEVICE_GEMM_XDL_SKIP_LDS_HPP #define DEVICE_GEMM_XDL_SKIP_B_LDS_HPP
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_skip_lds_v2r3.hpp" #include "gridwise_gemm_xdlops_skip_b_lds_v1.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
namespace ck { namespace ck {
...@@ -47,7 +47,7 @@ template <typename ADataType, ...@@ -47,7 +47,7 @@ template <typename ADataType,
ck::index_t BBlockTransferSrcScalarPerVector, ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector> ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceGemmXdlSkipLds struct DeviceGemmXdlSkipBLds
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -174,7 +174,7 @@ struct DeviceGemmXdlSkipLds ...@@ -174,7 +174,7 @@ struct DeviceGemmXdlSkipLds
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3< using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1<
BlockSize, BlockSize,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
...@@ -239,9 +239,11 @@ struct DeviceGemmXdlSkipLds ...@@ -239,9 +239,11 @@ struct DeviceGemmXdlSkipLds
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
a_grid_desc_k0_m_k1_ = DeviceGemmXdlSkipLds::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); a_grid_desc_k0_m_k1_ =
b_grid_desc_k0_n_k1_ = DeviceGemmXdlSkipLds::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); DeviceGemmXdlSkipBLds::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
c_grid_desc_m_n_ = DeviceGemmXdlSkipLds::MakeCGridDescriptor_M_N(M, N, StrideC); b_grid_desc_k0_n_k1_ =
DeviceGemmXdlSkipBLds::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdlSkipBLds::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity( if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
...@@ -279,7 +281,7 @@ struct DeviceGemmXdlSkipLds ...@@ -279,7 +281,7 @@ struct DeviceGemmXdlSkipLds
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
using Argument = DeviceGemmXdlSkipLds::Argument; using Argument = DeviceGemmXdlSkipBLds::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, int nrepeat = 1)
{ {
...@@ -316,12 +318,12 @@ struct DeviceGemmXdlSkipLds ...@@ -316,12 +318,12 @@ struct DeviceGemmXdlSkipLds
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
{ {
const auto kernel = kernel_gemm_xdlops_skip_lds_v2r3< const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceGemmXdlSkipLds::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSkipLds::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>, remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation, AElementwiseOperation,
...@@ -348,12 +350,12 @@ struct DeviceGemmXdlSkipLds ...@@ -348,12 +350,12 @@ struct DeviceGemmXdlSkipLds
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_skip_lds_v2r3< const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceGemmXdlSkipLds::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSkipLds::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>, remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation, AElementwiseOperation,
...@@ -484,7 +486,7 @@ struct DeviceGemmXdlSkipLds ...@@ -484,7 +486,7 @@ struct DeviceGemmXdlSkipLds
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceGemmXdlSkipLds" str << "DeviceGemmXdlSkipBLds"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
#ifndef CK_GRIDWISE_GEMM_XDLOPS_SKIP_B_LDS_V2R3_HPP #ifndef CK_GRIDWISE_GEMM_XDLOPS_SKIP_B_LDS_V1_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_SKIP_B_LDS_V2R3_HPP #define CK_GRIDWISE_GEMM_XDLOPS_SKIP_B_LDS_V1_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
...@@ -29,7 +29,7 @@ __global__ void ...@@ -29,7 +29,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdlops_skip_lds_v2r3( kernel_gemm_xdlops_skip_b_lds_v1(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
...@@ -101,7 +101,7 @@ template <index_t BlockSize, ...@@ -101,7 +101,7 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector> index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment