Commit 250a89f3 authored by Mirza Halilcevic
Browse files

Replace gemm_gemm with gemm_multiple_d_gemm_multiple_d.

parent d1e9682a
...@@ -8,13 +8,13 @@ ...@@ -8,13 +8,13 @@
#include <string> #include <string>
#include "ck/host/types.hpp" #include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp" #include "ck/host/operation/gemm.hpp"
#include "ck/host/device_gemm_elementwise_gemm/problem.hpp" #include "ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/problem.hpp"
namespace ck { namespace ck {
namespace host { namespace host {
namespace device_gemm_elementwise_gemm { namespace device_batched_gemm_multiple_d_gemm_multiple_d {
// defines all values need for an instance of fwd conv // defines all values needed for an instance
struct Operation_Xdl_CShuffle struct Operation_Xdl_CShuffle
{ {
// returns a vector of instances, only given fusion operators: will use default problem spec // returns a vector of instances, only given fusion operators: will use default problem spec
...@@ -23,36 +23,40 @@ struct Operation_Xdl_CShuffle ...@@ -23,36 +23,40 @@ struct Operation_Xdl_CShuffle
// returns a vector of instances, given a problem spec and fusion operators // returns a vector of instances, given a problem spec and fusion operators
static std::vector<Operation_Xdl_CShuffle> static std::vector<Operation_Xdl_CShuffle>
CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue); CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
TensorDesc A{}; TensorDesc A0{};
TensorDesc B0{}; TensorDesc B0{};
std::vector<TensorDesc> D0s = {};
TensorDesc B1{}; TensorDesc B1{};
TensorDesc C{}; std::vector<TensorDesc> D1s = {};
DataType acc = DataType::Float; TensorDesc E1{};
DataType cs_type = DataType::Half; DataType acc_type = DataType::Float;
std::string a_elem_op = PassThrough; DataType cshuffle_type = DataType::Float;
std::string b0_elem_op = PassThrough; std::string a0_elem_op = PassThrough;
std::string acc0_elem_op = PassThrough; std::string b0_elem_op = PassThrough;
std::string b1_elem_op = PassThrough; std::string cde0_elem_op = PassThrough;
std::string c_elem_op = PassThrough; std::string b1_elem_op = PassThrough;
std::string prologue = ""; std::string cde1_elem_op = PassThrough;
std::string epilogue = ""; std::string prologue = "";
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default"; std::string epilogue = "";
// tuning parameters // tuning parameters
operation::TileDescGemmElementwiseGemm tile_desc{}; operation::PaddingDesc padding_desc{};
operation::BlockTransferDesc a_block_transfer{}; operation::TileDescGemmGemm tile_desc{};
operation::BlockTransferDesc a0_block_transfer{};
operation::BlockTransferDesc b0_block_transfer{}; operation::BlockTransferDesc b0_block_transfer{};
operation::BlockTransferDesc cde0_block_transfer{};
operation::BlockTransferDesc b1_block_transfer{}; operation::BlockTransferDesc b1_block_transfer{};
operation::CShuffleDesc cshuffle{}; operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{}; operation::CBlockTransferDesc cde1_block_transfer{};
// functions to update fusion operators if provided // functions to update fusion operators if provided
void update_prologue(const std::string& prologue); void update_prologue(const std::string& prologue);
void update_epilogue(const std::string& epilogue); void update_epilogue(const std::string& epilogue);
/**constexpr**/ bool IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_); /**constexpr**/ bool
IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_, std::size_t Gemm1NRaw_);
// returns a templated instance // returns a templated instance
Solution ToSolution() const; Solution ToSolution() const;
}; };
} // namespace device_gemm_elementwise_gemm } // namespace device_batched_gemm_multiple_d_gemm_multiple_d
} // namespace host } // namespace host
} // namespace ck } // namespace ck
...@@ -10,28 +10,32 @@ ...@@ -10,28 +10,32 @@
namespace ck { namespace ck {
namespace host { namespace host {
namespace device_gemm_elementwise_gemm { namespace device_batched_gemm_multiple_d_gemm_multiple_d {
// defines the problem specification for a GEMM operation // defines the problem specification for a GEMM_ELEMENTWISE_GEMM operation
struct Problem struct Problem
{ {
std::size_t M = 0; std::size_t M = 0;
std::size_t N = 0; std::size_t N = 0;
std::size_t K = 0; std::size_t K = 0;
std::size_t O = 0; std::size_t O = 0;
bool TransA = false; bool TransA0 = false;
bool TransB0 = false; bool TransB0 = false;
bool TransB1 = false; std::vector<bool> D0sTrans = {};
bool TransC = false; bool TransB1 = false;
DataType ADataType = DataType::Half; std::vector<bool> D1sTrans = {};
DataType B0DataType = DataType::Half; bool TransE1 = false;
DataType B1DataType = DataType::Half; DataType A0DataType = DataType::Half;
DataType CDataType = DataType::Half; DataType B0DataType = DataType::Half;
std::string AElementOp = PassThrough; std::vector<DataType> D0sDataType = {};
std::string B0ElementOp = PassThrough; DataType B1DataType = DataType::Half;
std::string Acc0ElementOp = PassThrough; std::vector<DataType> D1sDataType = {};
std::string B1ElementOp = PassThrough; DataType E1DataType = DataType::Half;
std::string CElementOp = PassThrough; std::string A0ElementOp = PassThrough;
std::string B0ElementOp = PassThrough;
std::string CDE0ElementOp = PassThrough;
std::string B1ElementOp = PassThrough;
std::string CDE1ElementOp = PassThrough;
// returns the correct device op file for the operation // returns the correct device op file for the operation
std::string GetIncludeHeader() const; std::string GetIncludeHeader() const;
...@@ -42,6 +46,6 @@ struct Problem ...@@ -42,6 +46,6 @@ struct Problem
const std::string& epilogue) const; const std::string& epilogue) const;
}; };
} // namespace device_gemm_elementwise_gemm } // namespace device_batched_gemm_multiple_d_gemm_multiple_d
} // namespace host } // namespace host
} // namespace ck } // namespace ck
...@@ -9,6 +9,15 @@ namespace ck { ...@@ -9,6 +9,15 @@ namespace ck {
namespace host { namespace host {
namespace operation { namespace operation {
// Set of boolean padding flags, one per GEMM dimension named in the field
// (gemm0: M, N, K; gemm1: N, K). Every flag defaults to false.
// NOTE(review): presumably `true` requests padding of that dimension; confirm
// against the code that consumes this descriptor.
struct PaddingDesc
{
    bool pad_gemm0_m = false;
    bool pad_gemm0_n = false;
    bool pad_gemm0_k = false;
    bool pad_gemm1_n = false;
    bool pad_gemm1_k = false;
};
struct TileDesc struct TileDesc
{ {
int block_size = 0; int block_size = 0;
...@@ -24,23 +33,23 @@ struct TileDesc ...@@ -24,23 +33,23 @@ struct TileDesc
int num_gemmk_prefetch_stage = 0; int num_gemmk_prefetch_stage = 0;
}; };
struct TileDescGemmElementwiseGemm struct TileDescGemmGemm
{ {
int block_size = 0; int block_size = 0;
int gemm01_m_per_block = 0; int gemm0_m_per_block = 0;
int gemm0_n_per_block = 0; int gemm0_n_per_block = 0;
int gemm0_k_per_block = 0; int gemm0_k_per_block = 0;
int gemm1_n_per_block = 0; int gemm1_n_per_block = 0;
int gemm1_k_per_block = 0; int gemm1_k_per_block = 0;
int ak1 = 0; int a0k1 = 0;
int bk1 = 0; int b0k1 = 0;
int b1k1 = 0; int b1k1 = 0;
int m_per_XDL = 0; int m_per_XDL = 0;
int n_per_XDL = 0; int n_per_XDL = 0;
int gemm0_m_Xdl_per_wave = 0; int gemm0_m_Xdl_per_wave = 0;
int gemm0_n_Xdl_per_wave = 0; int gemm0_n_Xdl_per_wave = 0;
int gemm1_n_Xdl_per_wave = 0; int gemm1_n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0; int num_gemm0k_prefetch_stage = 0;
}; };
struct BlockTransferDesc struct BlockTransferDesc
...@@ -53,11 +62,13 @@ struct BlockTransferDesc ...@@ -53,11 +62,13 @@ struct BlockTransferDesc
int dst_scalar_per_vector_k1 = 0; int dst_scalar_per_vector_k1 = 0;
int lds_add_extra_dim = 0; int lds_add_extra_dim = 0;
}; };
struct CShuffleDesc struct CShuffleDesc
{ {
int m_Xdl_per_wave_per_shuffle = 0; int m_Xdl_per_wave_per_shuffle = 0;
int n_Xdl_per_wave_per_shuffle = 0; int n_Xdl_per_wave_per_shuffle = 0;
}; };
struct CBlockTransferDesc struct CBlockTransferDesc
{ {
std::string cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl = ""; std::string cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl = "";
......
...@@ -2,19 +2,20 @@ ...@@ -2,19 +2,20 @@
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_elementwise_gemm/problem.hpp" #include "ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_elementwise_gemm/operation.hpp" #include "ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/operation.hpp"
#include "ck/host/utils.hpp" #include "ck/host/utils.hpp"
#include <algorithm> #include <algorithm>
namespace ck { namespace ck {
namespace host { namespace host {
namespace device_gemm_elementwise_gemm { namespace device_batched_gemm_multiple_d_gemm_multiple_d {
// return the relevant device op file based on the operation // return the relevant device op file based on the operation
std::string Problem::GetIncludeHeader() const std::string Problem::GetIncludeHeader() const
{ {
return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"; return "ck/tensor_operation/gpu/device/impl/"
"device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp";
} }
// returns templated instances when provided with a problem specification // returns templated instances when provided with a problem specification
...@@ -24,8 +25,8 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch, ...@@ -24,8 +25,8 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch,
{ {
if(get_xdlop_archs().count(arch) == 0) if(get_xdlop_archs().count(arch) == 0)
return {}; return {};
auto ops = ck::host::device_gemm_elementwise_gemm::Operation_Xdl_CShuffle::CreateOperations( auto ops = ck::host::device_batched_gemm_multiple_d_gemm_multiple_d::Operation_Xdl_CShuffle::
*this, prologue, epilogue); // obtains vector of instances CreateOperations(*this, prologue, epilogue); // obtains vector of instances
std::vector<Solution> result; std::vector<Solution> result;
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
return op.ToSolution(); // template instance with correct values return op.ToSolution(); // template instance with correct values
...@@ -33,6 +34,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch, ...@@ -33,6 +34,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch,
return result; return result;
} }
} // namespace device_gemm_elementwise_gemm } // namespace device_batched_gemm_multiple_d_gemm_multiple_d
} // namespace host } // namespace host
} // namespace ck } // namespace ck
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment