"tests/pipelines/animatediff/__init__.py" did not exist on "3f1e95928e905daccb4ae8bce0049612c53f0737"
Commit 7d85d04a authored by ltqin's avatar ltqin
Browse files

change file name

parent 8dd89366
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
add_example_executable(example_gemm_xdl_skip_lds_fp16 gemm_xdl_skip_lds_fp16.cpp)
\ No newline at end of file
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
\ No newline at end of file
......@@ -11,7 +11,7 @@
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl_skip_lds.hpp"
#include "device_gemm_xdl_skip_b_lds.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
......@@ -43,7 +43,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off
#if USING_SKIP_LDS
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipBLds
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BThreadTransfer| CThreadTransfer| CThreadTransfer|
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| SrcScalar| SrcDstVectorDim| DstScalar|
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | PerVector|
......@@ -55,7 +55,9 @@ using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
#else
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
// < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
// < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 32, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 64, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
using ADataType = float;
using BDataType = float;
using CDataType = float;
......@@ -68,7 +70,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, 2>;
// < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, 2>;
// < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 32, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, 2>;
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, 2>;
using ADataType = float;
using BDataType = float;
using CDataType = float;
......@@ -104,20 +108,20 @@ int main(int argc, char* argv[])
// GEMM shape
#if 1
ck::index_t M = 16;
ck::index_t N = 1152;
ck::index_t K = 5120;
ck::index_t M = 64;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 5120;
ck::index_t StrideB = 5120;
ck::index_t StrideC = 1152;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
#else
ck::index_t M = 16;
ck::index_t N = 16;
ck::index_t K = 24;
ck::index_t K = 8;
ck::index_t StrideA = 24;
ck::index_t StrideB = 24;
ck::index_t StrideA = 8;
ck::index_t StrideB = 8;
ck::index_t StrideC = 16;
#endif
......
#ifndef DEVICE_GEMM_XDL_SKIP_LDS_HPP
#define DEVICE_GEMM_XDL_SKIP_LDS_HPP
#ifndef DEVICE_GEMM_XDL_SKIP_B_LDS_HPP
#define DEVICE_GEMM_XDL_SKIP_B_LDS_HPP
#include <iostream>
#include <sstream>
......@@ -10,7 +10,7 @@
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_skip_lds_v2r3.hpp"
#include "gridwise_gemm_xdlops_skip_b_lds_v1.hpp"
#include "gemm_specialization.hpp"
namespace ck {
......@@ -47,7 +47,7 @@ template <typename ADataType,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceGemmXdlSkipLds
struct DeviceGemmXdlSkipBLds
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
......@@ -174,7 +174,7 @@ struct DeviceGemmXdlSkipLds
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3<
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
......@@ -239,9 +239,11 @@ struct DeviceGemmXdlSkipLds
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
a_grid_desc_k0_m_k1_ = DeviceGemmXdlSkipLds::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmXdlSkipLds::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdlSkipLds::MakeCGridDescriptor_M_N(M, N, StrideC);
a_grid_desc_k0_m_k1_ =
DeviceGemmXdlSkipBLds::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ =
DeviceGemmXdlSkipBLds::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdlSkipBLds::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
......@@ -279,7 +281,7 @@ struct DeviceGemmXdlSkipLds
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceGemmXdlSkipLds::Argument;
using Argument = DeviceGemmXdlSkipBLds::Argument;
float Run(const Argument& arg, int nrepeat = 1)
{
......@@ -316,12 +318,12 @@ struct DeviceGemmXdlSkipLds
if(has_main_k0_block_loop)
{
const auto kernel = kernel_gemm_xdlops_skip_lds_v2r3<
const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSkipLds::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSkipLds::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
......@@ -348,12 +350,12 @@ struct DeviceGemmXdlSkipLds
}
else
{
const auto kernel = kernel_gemm_xdlops_skip_lds_v2r3<
const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSkipLds::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSkipLds::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
......@@ -484,7 +486,7 @@ struct DeviceGemmXdlSkipLds
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmXdlSkipLds"
str << "DeviceGemmXdlSkipBLds"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
#ifndef CK_GRIDWISE_GEMM_XDLOPS_SKIP_B_LDS_V2R3_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_SKIP_B_LDS_V2R3_HPP
#ifndef CK_GRIDWISE_GEMM_XDLOPS_SKIP_B_LDS_V1_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_SKIP_B_LDS_V1_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
......@@ -29,7 +29,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_skip_lds_v2r3(
kernel_gemm_xdlops_skip_b_lds_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
......@@ -101,7 +101,7 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment