Unverified Commit 2a30cfdd authored by arai713's avatar arai713 Committed by GitHub
Browse files

Merge branch 'develop' into codegen-enable-hiprtc

parents 9533a172 78195ccc
...@@ -56,6 +56,14 @@ if (GPU_TARGETS) ...@@ -56,6 +56,14 @@ if (GPU_TARGETS)
add_definitions(-DCK_USE_WMMA) add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON") set(CK_USE_WMMA "ON")
endif() endif()
if (GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx950")
add_definitions(-DCK_USE_OCP_FP8)
set(CK_USE_OCP_FP8 "ON")
endif()
if (GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx94")
add_definitions(-DCK_USE_FNUZ_FP8)
set(CK_USE_FNUZ_FP8 "ON")
endif()
else() else()
add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
set(CK_USE_XDL "ON") set(CK_USE_XDL "ON")
......
[Back to the main page](../README.md)
# Composable Kernel client examples
## ##
Client application links to CK library, and therefore CK library needs to be installed before building client applications. Client application links to CK library, and therefore CK library needs to be installed before building client applications.
......
...@@ -66,7 +66,7 @@ else() ...@@ -66,7 +66,7 @@ else()
-Wunreachable-code -Wunreachable-code
-Wunused -Wunused
-Wno-reserved-identifier -Wno-reserved-identifier
-Werror -Werror
-Wno-option-ignored -Wno-option-ignored
-Wsign-compare -Wsign-compare
-Wno-extra-semi-stmt -Wno-extra-semi-stmt
......
[Back to the main page](../README.md)
# Composable Kernel codegen
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <functional> #include <functional>
#include <iostream> #include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// defines all values need for an instance of fwd conv
struct Operation_Xdl_CShuffle
{
// returns a vector of instances, only given fusion operators: will use default problem spec
static std::vector<std::vector<Operation_Xdl_CShuffle>>
CreateOperations(const std::string& prologue, const std::string& epilogue);
// returns a vector of instances, given a problem spec and fusion operators
static std::vector<Operation_Xdl_CShuffle>
CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
TensorDesc A{};
TensorDesc B{};
TensorDesc B1{};
TensorDesc C{};
DataType acc = DataType::Float;
DataType cs_type = DataType::Half;
std::string a_elem_op = PassThrough;
std::string b_elem_op = PassThrough;
std::string b1_elem_op = PassThrough;
std::string c_elem_op = PassThrough;
std::string acc_elem_op = Scale;
std::string prologue = "";
std::string epilogue = "";
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
// tuning parameters
operation::TileDescGemmGemm tile_desc{};
operation::BlockTransferDesc a_block_transfer{};
operation::BlockTransferDesc b0_block_transfer{};
operation::BlockTransferDesc b1_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
bool mask_out_upper_triangle = false;
// functions to update fusion operators if provided
void update_prologue(const std::string& prologue);
void update_epilogue(const std::string& epilogue);
/**constexpr**/ bool
IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_, std::size_t Gemm1NRaw_);
// returns a templated instance
Solution ToSolution() const;
};
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// defines the problem specification for a GEMM operation
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
std::size_t O = 0;
bool TransA = false;
bool TransB = false;
bool TransB1 = false;
bool TransC = false;
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType B1DataType = DataType::Half;
DataType CDataType = DataType::Half;
std::string AElementOp = PassThrough;
std::string BElementOp = PassThrough;
std::string B1ElementOp = PassThrough;
std::string CElementOp = PassThrough;
std::string AccElementOp = Scale;
// returns the correct device op file for the operation
std::string GetIncludeHeader() const;
// returns a list of instances based on the problem spec and provided fusion operations
std::vector<Solution> GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const;
};
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
...@@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle ...@@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle
operation::BlockTransferDesc b_block_transfer{}; operation::BlockTransferDesc b_block_transfer{};
operation::CShuffleDesc cshuffle{}; operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{}; operation::CBlockTransferDesc c_block_transfer{};
LoopScheduler loop_scheduler{};
PipelineVersion pipeline_version{};
// functions to update fusion operators if provided // functions to update fusion operators if provided
void update_prologue(const std::string& prologue); void update_prologue(const std::string& prologue);
......
...@@ -23,6 +23,26 @@ struct TileDesc ...@@ -23,6 +23,26 @@ struct TileDesc
int n_Xdl_per_wave = 0; int n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0; int num_gemmk_prefetch_stage = 0;
}; };
struct TileDescGemmGemm
{
int block_size = 0;
int gemm01_m_per_block = 0;
int gemm0_n_per_block = 0;
int gemm0_k_per_block = 0;
int gemm1_n_per_block = 0;
int gemm1_k_per_block = 0;
int ak1 = 0;
int bk1 = 0;
int b1k1 = 0;
int m_per_XDL = 0;
int n_per_XDL = 0;
int gemm0_m_Xdl_per_wave = 0;
int gemm0_n_Xdl_per_wave = 0;
int gemm1_n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0;
};
struct BlockTransferDesc struct BlockTransferDesc
{ {
std::string thread_cluster_length = ""; std::string thread_cluster_length = "";
......
...@@ -66,6 +66,20 @@ enum class GemmType ...@@ -66,6 +66,20 @@ enum class GemmType
}; };
std::string ToString(GemmType gt); std::string ToString(GemmType gt);
enum class LoopScheduler
{
Default,
Interwave,
};
std::string ToString(LoopScheduler ls);
enum class PipelineVersion
{
v1,
v2
};
std::string ToString(PipelineVersion pv);
struct TensorDesc struct TensorDesc
{ {
DataType element; DataType element;
...@@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...}); ...@@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...});
constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough"; constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear"; constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
constexpr const char* Scale = "ck::tensor_operation::element_wise::Scale";
} // namespace host } // namespace host
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// return the relevant device op file based on the operation
std::string Problem::GetIncludeHeader() const
{
return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp";
}
// returns templated instances when provided with a problem specification
std::vector<Solution> Problem::GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const
{
if(get_xdlop_archs().count(arch) == 0)
return {};
auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations(
*this, prologue, epilogue); // obtains vector of instances
std::vector<Solution> result;
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
return op.ToSolution(); // template instance with correct values
});
return result;
}
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
...@@ -62,6 +62,12 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi) ...@@ -62,6 +62,12 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
// accounts for all possible combinations of Row/Col major // accounts for all possible combinations of Row/Col major
static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
// clang-format off
// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1,
// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
// clang-format on
// Hard-code tuning parameters in modularized fashion, string them together into a vector of // Hard-code tuning parameters in modularized fashion, string them together into a vector of
// instances // instances
std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...@@ -83,6 +89,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( ...@@ -83,6 +89,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1}, { 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1},
{ 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1}, { 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1},
{ 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1}, { 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1},
// Irregular tile
{ 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1},
// clang-format on // clang-format on
}; };
...@@ -100,6 +108,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( ...@@ -100,6 +108,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
// Irregular tile
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
// clang-format on // clang-format on
}; };
...@@ -109,15 +119,17 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( ...@@ -109,15 +119,17 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| // ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | // Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | | // | | | | | | |
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
// Irregular tile
{ S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
// clang-format on // clang-format on
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
}; };
std::vector<operation::BlockTransferDesc> b_block_descriptions_rowmajor = { std::vector<operation::BlockTransferDesc> b_block_descriptions_rowmajor = {
...@@ -134,6 +146,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( ...@@ -134,6 +146,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
// Irregular tile
{ S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
// clang-format on // clang-format on
}; };
...@@ -151,6 +165,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( ...@@ -151,6 +165,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
// Irregular tile
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
// clang-format on // clang-format on
}; };
...@@ -167,6 +183,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( ...@@ -167,6 +183,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ 1, 1}, { 1, 1},
{ 1, 1}, { 1, 1},
{ 1, 1}, { 1, 1},
{ 1, 1},
{ 1, 1}, { 1, 1},
// clang-format on // clang-format on
}; };
...@@ -185,6 +202,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( ...@@ -185,6 +202,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ S<1, 16, 1, 8>, 8}, { S<1, 16, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8}, { S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8}, { S<1, 32, 1, 8>, 8},
// Irregular tile
{ S<1, 16, 1, 4>, 1},
// clang-format on // clang-format on
}; };
...@@ -199,33 +218,44 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( ...@@ -199,33 +218,44 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
assert(tile_descriptions.size() == cshuffle_descriptions.size()); assert(tile_descriptions.size() == cshuffle_descriptions.size());
assert(tile_descriptions.size() == c_block_descriptions.size()); assert(tile_descriptions.size() == c_block_descriptions.size());
// Put all values together into a single operation > store into the result vector const std::vector<std::tuple<LoopScheduler, PipelineVersion>> scheduler_pipeline_descriptions =
for(std::size_t i = 0; i < tile_descriptions.size(); i++) {
{LoopScheduler::Default, PipelineVersion::v1},
{LoopScheduler::Interwave, PipelineVersion::v1},
{LoopScheduler::Default, PipelineVersion::v2},
};
for(auto [loop_scheduler, pipeline_version] : scheduler_pipeline_descriptions)
{ {
Operation_Xdl_CShuffle x; // Put all values together into a single operation > store into the result vector
x.tile_desc = tile_descriptions[i]; for(std::size_t i = 0; i < tile_descriptions.size(); i++)
x.a_block_transfer = a_block_descriptions[i]; {
x.b_block_transfer = b_block_descriptions[i]; Operation_Xdl_CShuffle x;
x.cshuffle = cshuffle_descriptions[i]; x.tile_desc = tile_descriptions[i];
x.c_block_transfer = c_block_descriptions[i]; x.a_block_transfer = a_block_descriptions[i];
x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; x.b_block_transfer = b_block_descriptions[i];
x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; x.cshuffle = cshuffle_descriptions[i];
x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)}; x.c_block_transfer = c_block_descriptions[i];
x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) { x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
return TensorDesc{dt, ToLayout(trans)}; x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
}); x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
x.a_elem_op = prob.AElementOp; x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
x.b_elem_op = prob.BElementOp; return TensorDesc{dt, ToLayout(trans)};
x.cde_elem_op = prob.CDEElementOp; });
x.gemm_specialization = GetGemmSpec(prob.M, x.a_elem_op = prob.AElementOp;
prob.N, x.b_elem_op = prob.BElementOp;
prob.K, x.cde_elem_op = prob.CDEElementOp;
x.tile_desc.m_per_block, x.gemm_specialization = GetGemmSpec(prob.M,
x.tile_desc.n_per_block, prob.N,
x.tile_desc.k_per_block); prob.K,
x.update_prologue(prologue); x.tile_desc.m_per_block,
x.update_epilogue(epilogue); x.tile_desc.n_per_block,
result.push_back(x); x.tile_desc.k_per_block);
x.loop_scheduler = loop_scheduler;
x.pipeline_version = pipeline_version;
x.update_prologue(prologue);
x.update_epilogue(epilogue);
result.push_back(x);
}
} }
return result; return result;
} }
...@@ -263,7 +293,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = ...@@ -263,7 +293,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
"${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, " "${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, "
"${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, " "${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
"${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, " "${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
"${CDEBlockTransferScalarPerVector_NPerBlock}>"; "${CDEBlockTransferScalarPerVector_NPerBlock}, ${LoopScheduler}, ${PipelineVersion}>";
// use hardcoded instances from vector of operations to substitute values into instance template // use hardcoded instances from vector of operations to substitute values into instance template
Solution Operation_Xdl_CShuffle::ToSolution() const Solution Operation_Xdl_CShuffle::ToSolution() const
...@@ -336,6 +366,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const ...@@ -336,6 +366,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
{"CDEBlockTransferScalarPerVector_NPerBlock", {"CDEBlockTransferScalarPerVector_NPerBlock",
std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
{"LoopScheduler", ToString(this->loop_scheduler)},
{"PipelineVersion", ToString(this->pipeline_version)},
}; };
return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values), return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values),
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/headers.hpp" #include "ck/host/headers.hpp"
#include "ck_headers.hpp" #include "ck_headers.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/types.hpp" #include "ck/host/types.hpp"
#include "ck/host/stringutils.hpp" #include "ck/host/stringutils.hpp"
#include <algorithm> #include <algorithm>
...@@ -56,6 +59,26 @@ std::string ToString(GemmType gt) ...@@ -56,6 +59,26 @@ std::string ToString(GemmType gt)
throw std::runtime_error("Incorrect gemm type"); throw std::runtime_error("Incorrect gemm type");
} }
std::string ToString(LoopScheduler ls)
{
switch(ls)
{
case LoopScheduler::Default: return "ck::LoopScheduler::Default";
case LoopScheduler::Interwave: return "ck::LoopScheduler::Interwave";
}
throw std::runtime_error("Incorrect LoopScheduler type");
}
std::string ToString(PipelineVersion pv)
{
switch(pv)
{
case PipelineVersion::v1: return "ck::PipelineVersion::v1";
case PipelineVersion::v2: return "ck::PipelineVersion::v2";
}
throw std::runtime_error("Incorrect PipelineVersion type");
}
std::string SequenceStr(const std::vector<int>& v) std::string SequenceStr(const std::vector<int>& v)
{ {
return "ck::Sequence<" + return "ck::Sequence<" +
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_multiple_d/problem.hpp" #include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp" #include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/headers.hpp" #include "ck/host/headers.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/headers.hpp" #include "ck/host/headers.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/headers.hpp" #include "ck/host/headers.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/headers.hpp" #include "ck/host/headers.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/headers.hpp" #include "ck/host/headers.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment