Commit 250a89f3 authored by Mirza Halilcevic's avatar Mirza Halilcevic
Browse files

Replace gemm_gemm with gemm_multiple_d_gemm_multiple_d.

parent d1e9682a
......@@ -8,13 +8,13 @@
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_gemm_elementwise_gemm/problem.hpp"
#include "ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/problem.hpp"
namespace ck {
namespace host {
namespace device_gemm_elementwise_gemm {
namespace device_batched_gemm_multiple_d_gemm_multiple_d {
// defines all values need for an instance of fwd conv
// defines all values needed for an instance
struct Operation_Xdl_CShuffle
{
// returns a vector of instances, only given fusion operators: will use default problem spec
......@@ -23,36 +23,40 @@ struct Operation_Xdl_CShuffle
// returns a vector of instances, given a problem spec and fusion operators
static std::vector<Operation_Xdl_CShuffle>
CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
TensorDesc A{};
TensorDesc A0{};
TensorDesc B0{};
std::vector<TensorDesc> D0s = {};
TensorDesc B1{};
TensorDesc C{};
DataType acc = DataType::Float;
DataType cs_type = DataType::Half;
std::string a_elem_op = PassThrough;
std::vector<TensorDesc> D1s = {};
TensorDesc E1{};
DataType acc_type = DataType::Float;
DataType cshuffle_type = DataType::Float;
std::string a0_elem_op = PassThrough;
std::string b0_elem_op = PassThrough;
std::string acc0_elem_op = PassThrough;
std::string cde0_elem_op = PassThrough;
std::string b1_elem_op = PassThrough;
std::string c_elem_op = PassThrough;
std::string cde1_elem_op = PassThrough;
std::string prologue = "";
std::string epilogue = "";
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
// tuning parameters
operation::TileDescGemmElementwiseGemm tile_desc{};
operation::BlockTransferDesc a_block_transfer{};
operation::PaddingDesc padding_desc{};
operation::TileDescGemmGemm tile_desc{};
operation::BlockTransferDesc a0_block_transfer{};
operation::BlockTransferDesc b0_block_transfer{};
operation::BlockTransferDesc cde0_block_transfer{};
operation::BlockTransferDesc b1_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
operation::CBlockTransferDesc cde1_block_transfer{};
// functions to update fusion operators if provided
void update_prologue(const std::string& prologue);
void update_epilogue(const std::string& epilogue);
/**constexpr**/ bool IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_);
/**constexpr**/ bool
IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_, std::size_t Gemm1NRaw_);
// returns a templated instance
Solution ToSolution() const;
};
} // namespace device_gemm_elementwise_gemm
} // namespace device_batched_gemm_multiple_d_gemm_multiple_d
} // namespace host
} // namespace ck
......@@ -10,28 +10,32 @@
namespace ck {
namespace host {
namespace device_gemm_elementwise_gemm {
namespace device_batched_gemm_multiple_d_gemm_multiple_d {
// defines the problem specification for a GEMM operation
// defines the problem specification for a GEMM_ELEMENTWISE_GEMM operation
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
std::size_t O = 0;
bool TransA = false;
bool TransA0 = false;
bool TransB0 = false;
std::vector<bool> D0sTrans = {};
bool TransB1 = false;
bool TransC = false;
DataType ADataType = DataType::Half;
std::vector<bool> D1sTrans = {};
bool TransE1 = false;
DataType A0DataType = DataType::Half;
DataType B0DataType = DataType::Half;
std::vector<DataType> D0sDataType = {};
DataType B1DataType = DataType::Half;
DataType CDataType = DataType::Half;
std::string AElementOp = PassThrough;
std::vector<DataType> D1sDataType = {};
DataType E1DataType = DataType::Half;
std::string A0ElementOp = PassThrough;
std::string B0ElementOp = PassThrough;
std::string Acc0ElementOp = PassThrough;
std::string CDE0ElementOp = PassThrough;
std::string B1ElementOp = PassThrough;
std::string CElementOp = PassThrough;
std::string CDE1ElementOp = PassThrough;
// returns the correct device op file for the operation
std::string GetIncludeHeader() const;
......@@ -42,6 +46,6 @@ struct Problem
const std::string& epilogue) const;
};
} // namespace device_gemm_elementwise_gemm
} // namespace device_batched_gemm_multiple_d_gemm_multiple_d
} // namespace host
} // namespace ck
......@@ -9,6 +9,15 @@ namespace ck {
namespace host {
namespace operation {
// Describes which GEMM dimensions require padding for the fused
// gemm + gemm operation (gemm0 feeds gemm1), used to pick the
// matching GemmSpecialization of the device instance.
struct PaddingDesc
{
    // Use `false` (not `0`) to initialize bools — clearer intent, no implicit conversion.
    bool pad_gemm0_m = false; // pad M of the first GEMM
    bool pad_gemm0_n = false; // pad N of the first GEMM
    bool pad_gemm0_k = false; // pad K of the first GEMM
    bool pad_gemm1_n = false; // pad N of the second GEMM
    bool pad_gemm1_k = false; // pad K of the second GEMM
};
struct TileDesc
{
int block_size = 0;
......@@ -24,23 +33,23 @@ struct TileDesc
int num_gemmk_prefetch_stage = 0;
};
// Tuning parameters describing the per-block tiling of the fused
// gemm0 -> gemm1 operation. One field set per GEMM stage; the a0/b0/b1
// naming matches the batched_gemm_multiple_d_gemm_multiple_d device op.
//
// NOTE(review): this span contained interleaved old/new lines from a diff
// (two struct headers and duplicated fields, which cannot compile);
// resolved to the post-rename version introduced by this commit.
struct TileDescGemmGemm
{
    int block_size = 0;              // threads per workgroup
    int gemm0_m_per_block = 0;       // M tile of gemm0 handled by one block
    int gemm0_n_per_block = 0;       // N tile of gemm0
    int gemm0_k_per_block = 0;       // K tile of gemm0
    int gemm1_n_per_block = 0;       // N tile of gemm1
    int gemm1_k_per_block = 0;       // K tile of gemm1
    int a0k1 = 0;                    // K1 vector size for the A0 operand
    int b0k1 = 0;                    // K1 vector size for the B0 operand
    int b1k1 = 0;                    // K1 vector size for the B1 operand
    int m_per_XDL = 0;               // M extent of one XDL (MFMA) instruction
    int n_per_XDL = 0;               // N extent of one XDL (MFMA) instruction
    int gemm0_m_Xdl_per_wave = 0;    // XDL repeats per wave along M (gemm0)
    int gemm0_n_Xdl_per_wave = 0;    // XDL repeats per wave along N (gemm0)
    int gemm1_n_Xdl_per_wave = 0;    // XDL repeats per wave along N (gemm1)
    int num_gemm0k_prefetch_stage = 0; // software-pipeline prefetch depth over gemm0's K
};
struct BlockTransferDesc
......@@ -53,11 +62,13 @@ struct BlockTransferDesc
int dst_scalar_per_vector_k1 = 0;
int lds_add_extra_dim = 0;
};
// Tuning parameters for the C-shuffle epilogue stage: how many
// Xdl-per-wave output tiles are processed per shuffle iteration
// (presumably staged through LDS by the device op — confirm against
// the xdl_cshuffle kernel implementation).
struct CShuffleDesc
{
// M-direction Xdl-per-wave tiles moved per shuffle pass
int m_Xdl_per_wave_per_shuffle = 0;
// N-direction Xdl-per-wave tiles moved per shuffle pass
int n_Xdl_per_wave_per_shuffle = 0;
};
struct CBlockTransferDesc
{
std::string cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl = "";
......
......@@ -2,19 +2,20 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_elementwise_gemm/problem.hpp"
#include "ck/host/device_gemm_elementwise_gemm/operation.hpp"
#include "ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/problem.hpp"
#include "ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
namespace ck {
namespace host {
namespace device_gemm_elementwise_gemm {
namespace device_batched_gemm_multiple_d_gemm_multiple_d {
// return the relevant device op file based on the operation
// Returns the device-op implementation header matching this problem.
// NOTE(review): the span contained two return statements (old and new
// paths interleaved from a diff); the second was unreachable dead code.
// Kept the batched_gemm_multiple_d_gemm_multiple_d header, which is the
// path this commit's rename introduces.
std::string Problem::GetIncludeHeader() const
{
    return "ck/tensor_operation/gpu/device/impl/"
           "device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp";
}
// returns templated instances when provided with a problem specification
......@@ -24,8 +25,8 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch,
{
if(get_xdlop_archs().count(arch) == 0)
return {};
auto ops = ck::host::device_gemm_elementwise_gemm::Operation_Xdl_CShuffle::CreateOperations(
*this, prologue, epilogue); // obtains vector of instances
auto ops = ck::host::device_batched_gemm_multiple_d_gemm_multiple_d::Operation_Xdl_CShuffle::
CreateOperations(*this, prologue, epilogue); // obtains vector of instances
std::vector<Solution> result;
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
return op.ToSolution(); // template instance with correct values
......@@ -33,6 +34,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch,
return result;
}
} // namespace device_gemm_elementwise_gemm
} // namespace device_batched_gemm_multiple_d_gemm_multiple_d
} // namespace host
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment