Commit ef326c73 authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/develop' into migraphx-update

parents b7775add e4dfe4d8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/types.hpp"
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
bool TransA = false;
bool TransB = false;
bool TransE = false;
std::vector<bool> DsTrans = {};
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType EDataType = DataType::Half;
std::vector<DataType> DsDataType = {};
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string CDEElementOp = "ck::Tuple<>";
std::string GetIncludeHeader() const;
std::vector<Solution> GetSolutions(const std::string& arch) const;
};
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_gemm_multiple_d/problem.hpp"
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
// defines all values need for an instance of fwd conv
struct Operation_Xdl_CShuffle
{
// returns a vector of instances, only given fusion operators: will use default problem spec
static std::vector<std::vector<Operation_Xdl_CShuffle>>
CreateOperations(const std::string& prologue, const std::string& epilogue);
// returns a vector of instances, given a problem spec and fusion operators
static std::vector<Operation_Xdl_CShuffle>
CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
TensorDesc A{};
TensorDesc B{};
DataType acc = DataType::Float;
DataType cs_type = DataType::Half;
std::vector<TensorDesc> Ds = {};
TensorDesc E{};
std::string a_elem_op = PassThrough;
std::string b_elem_op = PassThrough;
std::string cde_elem_op = Bilinear;
std::string prologue = "";
std::string epilogue = "";
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
// tuning parameters
operation::TileDesc tile_desc{};
operation::BlockTransferDesc a_block_transfer{};
operation::BlockTransferDesc b_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
// functions to update fusion operators if provided
void update_prologue(const std::string& prologue);
void update_epilogue(const std::string& epilogue);
/**constexpr**/ bool IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_);
// returns a templated instance
Solution ToSolution() const;
};
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
// defines the problem specification for a GEMM operation
struct Problem
{
// dimensions for GEMM operation
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
// layouts for tensors
bool TransA = false;
bool TransB = false;
bool TransE = false;
std::vector<bool> DsTrans = {};
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType EDataType = DataType::Half;
std::vector<DataType> DsDataType = {};
std::string AElementOp = PassThrough;
std::string BElementOp = PassThrough;
std::string CDEElementOp = PassThrough;
// returns the correct device op file for the operation
std::string GetIncludeHeader() const;
// returns a list of instances based on the problem spec and provided fusion operations
std::vector<Solution> GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const;
};
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
namespace ck {
namespace host {
namespace conv {
// defines the values needed for an instance of forward convolution and functions to return
// (templated) instances
struct Operation_Conv_Fwd_Xdl_Cshuffle
{
// returns a vector of instances given the fusion operations, uses default values for problem
// spec
static std::vector<Operation_Conv_Fwd_Xdl_Cshuffle>
CreateOperations(const std::string& prologue, const std::string& epilogue);
// returns a vector of instances, provided with a problem spec and fusion operations
static std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> CreateOperations(
const Problem_Conv_Fwd& prob, const std::string& prologue, const std::string& epilogue);
std::size_t NumDim;
TensorDesc A{};
TensorDesc B{};
DataType acc = DataType::Float;
DataType cs_type = DataType::Half;
std::vector<TensorDesc> Ds = {};
TensorDesc E{};
std::string a_elem_op = PassThrough;
std::string b_elem_op = PassThrough;
std::string cde_elem_op = PassThrough;
std::string prologue = "";
std::string epilogue = "";
std::string conv_specialization =
"ck::tensor_operation::device::ConvolutionForwardSpecialization::Default";
std::string gemm_specialization =
"ck::tensor_operation::device::GemmSpecialization::MNKPadding";
// tuning parameters
operation::TileDesc tile_desc{};
operation::BlockTransferDesc a_block_transfer{};
operation::BlockTransferDesc b_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
// functions to update fusion operations if they are provided
void update_prologue(const std::string& prologue);
void update_epilogue(const std::string& epilogue);
// returns a templated instance
Solution ToSolution() const;
};
} // namespace conv
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/types.hpp"
namespace ck {
namespace host {
namespace conv {
// defines the problem specification for a forward convolution operation
struct Problem_Conv_Fwd
{
std::size_t NumDim = 0;
// size of a forward convolution operation
std::size_t G = 0;
std::size_t N = 0;
std::size_t C = 0;
std::size_t Hi = 0;
std::size_t Wi = 0;
std::size_t Ho = 0;
std::size_t Wo = 0;
std::size_t K = 0;
std::size_t Y = 0;
std::size_t X = 0;
Layout ALayout = Layout::NHWGC;
Layout BLayout = Layout::GKYXC;
Layout ELayout = Layout::NHWGK;
std::vector<Layout> DsLayout = {};
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType EDataType = DataType::Half;
std::vector<DataType> DsDataType = {};
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string CDEElementOp = "ck::tensor_operation::element_wise::PassThrough";
// returns the correct device op file for the operation
std::string GetIncludeHeader() const;
// returns a list of instances based on the problem spec and provided fusion operations
std::vector<Solution> GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const;
};
} // namespace conv
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <utility>
#include <unordered_map>
#include <vector>
namespace ck {
namespace host {
std::unordered_map<std::string_view, std::string_view> GetHeaders();
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace ck {
namespace host {
namespace operation {
struct TileDesc
{
int block_size = 0;
int m_per_block = 0;
int n_per_block = 0;
int k_per_block = 0;
int ak1 = 0;
int bk1 = 0;
int m_per_XDL = 0;
int n_per_XDL = 0;
int m_Xdl_per_wave = 0;
int n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0;
};
struct BlockTransferDesc
{
std::string thread_cluster_length = "";
std::string thread_cluster_arrange_order = "";
std::string src_access_order = "";
int src_vec_dim = 0;
int src_scalar_per_vector = 0;
int dst_scalar_per_vector_k1 = 0;
int lds_add_extra_dim = 0;
};
struct CShuffleDesc
{
int m_Xdl_per_wave_per_shuffle = 0;
int n_Xdl_per_wave_per_shuffle = 0;
};
struct CBlockTransferDesc
{
std::string cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl = "";
int scalar_per_vector_n_wave_n_per_Xdl = 0;
};
} // namespace operation
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cassert>
#include <numeric>
#include <string>
#include <utility>
#include <unordered_map>
#include <vector>
namespace ck {
namespace host {
template <class F>
std::string trim(const std::string& s, F f)
{
auto start = std::find_if_not(s.begin(), s.end(), f);
auto last = std::find_if_not(s.rbegin(), std::string::const_reverse_iterator(start), f).base();
return {start, last};
}
inline std::string trim(const std::string& s)
{
return trim(s, [](unsigned char c) { return std::isspace(c); });
}
template <class Strings>
inline std::string JoinStrings(Strings strings, const std::string& delim)
{
auto it = strings.begin();
if(it == strings.end())
return "";
auto nit = std::next(it);
return std::accumulate(nit, strings.end(), *it, [&](std::string x, std::string y) {
return std::move(x) + delim + std::move(y);
});
}
template <class F>
inline std::string
InterpolateString(const std::string& input, F f, std::string start = "${", std::string end = "}")
{
std::string result = "";
result.reserve(input.size());
auto it = input.begin();
while(it != input.end())
{
auto next_start = std::search(it, input.end(), start.begin(), start.end());
auto next_end = std::search(next_start, input.end(), end.begin(), end.end());
result.append(it, next_start);
if(next_start == input.end())
break;
if(next_end == input.end())
{
throw std::runtime_error("Unbalanced brackets");
}
auto r = f(next_start + start.size(), next_end);
result.append(r.begin(), r.end());
it = next_end + end.size();
}
return result;
}
inline std::string InterpolateString(const std::string& input,
const std::unordered_map<std::string, std::string>& vars,
std::string start = "${",
std::string end = "}")
{
return InterpolateString(
input,
[&](auto start_it, auto last_it) {
auto key = trim({start_it, last_it});
auto it = vars.find(key);
if(it == vars.end())
throw std::runtime_error("Unknown key: " + key);
return it->second;
},
std::move(start),
std::move(end));
}
template <class Range, class F>
inline auto Transform(const Range& r, F f) -> std::vector<decltype(f(*r.begin()))>
{
std::vector<decltype(f(*r.begin()))> result;
std::transform(r.begin(), r.end(), std::back_inserter(result), f);
return result;
}
template <class Range1, class Range2, class F>
inline auto Transform(const Range1& r1, const Range2& r2, F f)
-> std::vector<decltype(f(*r1.begin(), *r2.begin()))>
{
std::vector<decltype(f(*r1.begin(), *r2.begin()))> result;
assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end()));
std::transform(r1.begin(), r1.end(), r2.begin(), std::back_inserter(result), f);
return result;
}
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include <utility>
#include <unordered_map>
#include <vector>
namespace ck {
namespace host {
// holds the templated instance, substitues values into template from instancess
struct Solution
{
Solution() = default;
Solution(std::string str, std::unordered_map<std::string, std::string> values);
std::string ToTemplateString() const;
std::string GetTemplateParameter(const std::string& name) const;
template <class T>
T GetTemplateParameter(const std::string& name) const
{
T result;
std::stringstream ss(GetTemplateParameter(name));
ss >> result;
return result;
}
private:
std::string template_str;
std::unordered_map<std::string, std::string> template_values;
};
// supported data types
enum class DataType
{
Half,
Float,
Int8,
Int32
};
std::string ToString(DataType dt);
// supported layouts: gemm and fwd conv
enum class Layout
{
Row,
Column,
GKYXC,
GKCYX,
GNHWK,
GNHWC,
NHWGC,
NHWGK
};
std::string ToString(Layout dl);
Layout ToLayout(bool Trans); // returns the layout for gemm
// supported GEMM types
enum class GemmType
{
Default
};
std::string ToString(GemmType gt);
struct TensorDesc
{
DataType element;
Layout layout;
};
std::string SequenceStr(const std::vector<int>& v);
std::string MakeTuple(const std::vector<std::string>& v);
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
template <int... xs>
const std::string S = SequenceStr({xs...});
#pragma clang diagnostic pop
constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdint>
#include <unordered_set>
#include <numeric>
#include <iterator>
namespace ck {
namespace host {
std::size_t integer_divide_ceil(std::size_t x, std::size_t y);
const std::unordered_set<std::string>& get_xdlop_archs();
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
// return the relevant device op file based on the operation
std::string Problem::GetIncludeHeader() const
{
return "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp";
}
// returns templated instances when provided with a problem specification
std::vector<Solution> Problem::GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const
{
if(get_xdlop_archs().count(arch) == 0)
return {};
auto ops = ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle::CreateOperations(
*this, prologue, epilogue); // obtains vector of instances
std::vector<Solution> result;
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
return op.ToSolution(); // template instance with correct values
});
return result;
}
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/types.hpp"
#include "ck/host/utils.hpp"
#include <cassert>
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
// calculate appropriate Gemm Specification based on input tensor dimensions
static std::string GetGemmSpec(const std::size_t m,
const std::size_t n,
const std::size_t k,
const std::size_t m_per_block,
const std::size_t n_per_block,
const std::size_t k_per_block)
{
std::string spec = "";
if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
spec += "M";
if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
spec += "N";
if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
spec += "K";
if(spec == "")
return "ck::tensor_operation::device::GemmSpecialization::Default";
return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
}
// function to update prologue/epilogue with user provided operation
void Operation_Xdl_CShuffle::update_prologue(const std::string& pro)
{
if(!pro.empty())
{
this->prologue = pro;
this->cde_elem_op = "CDEElementOp";
}
else
{
this->prologue = "";
}
}
void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
{
if(!epi.empty())
{
this->epilogue = epi;
this->cde_elem_op = "CDEElementOp";
}
else
{
this->epilogue = "";
}
}
// accounts for all possible combinations of Row/Col major
static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// instances
std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
const Problem& prob, const std::string& prologue, const std::string& epilogue)
{
std::vector<Operation_Xdl_CShuffle> result;
std::vector<operation::TileDesc> tile_descriptions = {
// clang-format off
// Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| NumGemmK|
// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch|
// | | | | | | | | Wave| Wave| Stage|
// | | | | | | | | | | |
{ 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1},
{ 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, 1},
{ 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1},
{ 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1},
{ 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, 1},
{ 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1},
{ 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1},
{ 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1},
// clang-format on
};
std::vector<operation::BlockTransferDesc> a_block_descriptions_rowmajor = {
// clang-format off
// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
// clang-format on
};
std::vector<operation::BlockTransferDesc> a_block_descriptions_colmajor = {
// clang-format off
// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
// clang-format on
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
};
std::vector<operation::BlockTransferDesc> b_block_descriptions_rowmajor = {
// clang-format off
// BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
// clang-format on
};
std::vector<operation::BlockTransferDesc> b_block_descriptions_colmajor = {
// clang-format off
// BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
// clang-format on
};
std::vector<operation::CShuffleDesc> cshuffle_descriptions = {
// clang-format off
// CShuffle| CShuffle|
// MXdlPerWave| NXdlPerWave|
// PerShuffle| PerShuffle|
// | |
{ 1, 1},
{ 1, 1},
{ 1, 1},
{ 1, 1},
{ 1, 1},
{ 1, 1},
{ 1, 1},
{ 1, 1},
// clang-format on
};
std::vector<operation::CBlockTransferDesc> c_block_descriptions = {
// clang-format off
// CBlockTransferClusterLengths| CBlockTransfer
// _MBlock_MWaveMPerXdl| ScalarPerVector
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
// |
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 16, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 4>, 8},
{ S<1, 16, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
// clang-format on
};
// choose correct arrangement of tuning parameters based on the layout of each tensor
const auto a_block_descriptions =
prob.TransA ? a_block_descriptions_colmajor : a_block_descriptions_rowmajor;
const auto b_block_descriptions =
prob.TransB ? b_block_descriptions_colmajor : b_block_descriptions_rowmajor;
assert(tile_descriptions.size() == a_block_descriptions.size());
assert(tile_descriptions.size() == b_block_descriptions.size());
assert(tile_descriptions.size() == cshuffle_descriptions.size());
assert(tile_descriptions.size() == c_block_descriptions.size());
// Put all values together into a single operation > store into the result vector
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
{
Operation_Xdl_CShuffle x;
x.tile_desc = tile_descriptions[i];
x.a_block_transfer = a_block_descriptions[i];
x.b_block_transfer = b_block_descriptions[i];
x.cshuffle = cshuffle_descriptions[i];
x.c_block_transfer = c_block_descriptions[i];
x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
return TensorDesc{dt, ToLayout(trans)};
});
x.a_elem_op = prob.AElementOp;
x.b_elem_op = prob.BElementOp;
x.cde_elem_op = prob.CDEElementOp;
x.gemm_specialization = GetGemmSpec(prob.M,
prob.N,
prob.K,
x.tile_desc.m_per_block,
x.tile_desc.n_per_block,
x.tile_desc.k_per_block);
x.update_prologue(prologue);
x.update_epilogue(epilogue);
result.push_back(x);
}
return result;
}
// set up instances when not provided with a problem specification, use default operation values and
// all possible layout combinations
std::vector<std::vector<Operation_Xdl_CShuffle>>
Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue)
{
std::vector<Problem> problems;
for(bool TransA : {true, false})
for(bool TransB : {true, false})
{
Problem prob;
prob.TransA = TransA;
prob.TransB = TransB;
problems.push_back(prob);
}
return Transform(problems,
[&](const Problem& p) { return CreateOperations(p, prologue, epilogue); });
}
static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
"ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<${LayoutA}, ${LayoutB}, "
"${LayoutDs}, ${LayoutE}, ${ADataType}, ${BDataType}, ${AccDataType}, ${CShuffleDataType}, "
"${DsDataType}, ${EDataType}, ${AElementwiseOperation}, ${BElementwiseOperation}, "
"${CDEElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, "
"${MPerBlock}, ${NPerBlock}, ${KPerBlock}, ${AK1}, ${BK1}, ${MPerXDL}, ${NPerXDL}, "
"${MXdlPerWave}, ${NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, "
"${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, "
"${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, "
"${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, "
"${BBlockTransferThreadClusterLengths_BK0_N_BK1}, ${BBlockTransferThreadClusterArrangeOrder}, "
"${BBlockTransferSrcAccessOrder}, ${BBlockTransferSrcVectorDim}, "
"${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, "
"${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
"${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
"${CDEBlockTransferScalarPerVector_NPerBlock}>";
// use hardcoded instances from vector of operations to substitute values into instance template
Solution Operation_Xdl_CShuffle::ToSolution() const
{
std::unordered_map<std::string, std::string> values = {
{"name",
std::to_string(this->tile_desc.block_size) + "_" +
std::to_string(this->tile_desc.m_per_block) + "_" +
std::to_string(this->tile_desc.n_per_block) + "_" +
std::to_string(this->tile_desc.k_per_block) + "_" +
std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" +
std::to_string(this->tile_desc.m_per_XDL) + "_" +
std::to_string(this->tile_desc.n_per_XDL) + "_" +
std::to_string(this->tile_desc.m_Xdl_per_wave) + "_" +
std::to_string(this->tile_desc.n_Xdl_per_wave)},
{"LayoutA", ToString(this->A.layout)},
{"LayoutB", ToString(this->B.layout)},
{"LayoutDs",
MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.layout); }))},
{"LayoutE", ToString(this->E.layout)},
{"ADataType", ToString(this->A.element)},
{"BDataType", ToString(this->B.element)},
{"AccDataType", ToString(this->acc)},
{"CShuffleDataType", ToString(this->cs_type)},
{"DsDataType",
MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.element); }))},
{"EDataType", ToString(this->E.element)},
{"AElementwiseOperation", this->a_elem_op},
{"BElementwiseOperation", this->b_elem_op},
{"CDEElementwiseOperation", this->cde_elem_op},
{"GemmSpecialization", this->gemm_specialization},
{"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)},
{"BlockSize", std::to_string(this->tile_desc.block_size)},
{"MPerBlock", std::to_string(this->tile_desc.m_per_block)},
{"NPerBlock", std::to_string(this->tile_desc.n_per_block)},
{"KPerBlock", std::to_string(this->tile_desc.k_per_block)},
{"AK1", std::to_string(this->tile_desc.ak1)},
{"BK1", std::to_string(this->tile_desc.bk1)},
{"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)},
{"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)},
{"MXdlPerWave", std::to_string(this->tile_desc.m_Xdl_per_wave)},
{"NXdlPerWave", std::to_string(this->tile_desc.n_Xdl_per_wave)},
{"ABlockTransferThreadClusterLengths_AK0_M_AK1",
this->a_block_transfer.thread_cluster_length},
{"ABlockTransferThreadClusterArrangeOrder",
this->a_block_transfer.thread_cluster_arrange_order},
{"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order},
{"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)},
{"ABlockTransferSrcScalarPerVector",
std::to_string(this->a_block_transfer.src_scalar_per_vector)},
{"ABlockTransferDstScalarPerVector_AK1",
std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)},
{"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)},
{"BBlockTransferThreadClusterLengths_BK0_N_BK1",
this->b_block_transfer.thread_cluster_length},
{"BBlockTransferThreadClusterArrangeOrder",
this->b_block_transfer.thread_cluster_arrange_order},
{"BBlockTransferSrcAccessOrder", this->b_block_transfer.src_access_order},
{"BBlockTransferSrcVectorDim", std::to_string(this->b_block_transfer.src_vec_dim)},
{"BBlockTransferSrcScalarPerVector",
std::to_string(this->b_block_transfer.src_scalar_per_vector)},
{"BBlockTransferDstScalarPerVector_BK1",
std::to_string(this->b_block_transfer.dst_scalar_per_vector_k1)},
{"BBlockLdsExtraN", std::to_string(this->b_block_transfer.lds_add_extra_dim)},
{"CShuffleMXdlPerWavePerShuffle",
std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)},
{"CShuffleNXdlPerWavePerShuffle",
std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)},
{"CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock",
this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
{"CDEBlockTransferScalarPerVector_NPerBlock",
std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
};
return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values),
std::move(values)};
}
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
#include <iostream>
namespace ck {
namespace host {
namespace conv {
// return the relevant device op file based on the operation
// NOTE: this is a modified version of the original CK file that calls the kernel from a device
// function and makes the Argument class accessible on the device
std::string Problem_Conv_Fwd::GetIncludeHeader() const
{
return "ck/tensor_operation/gpu/device/impl/"
"codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp";
}
// return vector of forward convolution instances when provided with a problem instance
std::vector<Solution> Problem_Conv_Fwd::GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const
{
if(get_xdlop_archs().count(arch) == 0)
return {};
auto ops = ck::host::conv::Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations(
*this, prologue, epilogue);
std::vector<Solution> result;
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
return op.ToSolution();
});
return result;
}
} // namespace conv
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include <iostream>
#include "ck/host/stringutils.hpp"
#include "ck/host/types.hpp"
#include "ck/host/utils.hpp"
#include <cassert>
namespace ck {
namespace host {
namespace conv {
// NOTE: in CK, MNKPadding is always used for forward convolution, so didn't
// add GemmSpec function here
// function to update prologue/epilogue with user provided operation
void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& pro)
{
if(!pro.empty())
{
this->prologue = pro;
this->cde_elem_op = "CDEElementOp";
}
else
{
this->prologue = "";
}
}
void Operation_Conv_Fwd_Xdl_Cshuffle::update_epilogue(const std::string& epi)
{
if(!epi.empty())
{
this->epilogue = epi;
this->cde_elem_op = "CDEElementOp";
}
else
{
this->epilogue = "";
}
}
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// instances
std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations(
const Problem_Conv_Fwd& prob, const std::string& prologue, const std::string& epilogue)
{
std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> result;
std::vector<operation::TileDesc> tile_descriptions = {
// clang-format off
// Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| NumGemmK|
// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch|
// | | | | | | | | Wave| Wave| Stage|
// | | | | | | | | | | |
{ 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, 1},
{ 256, 128, 256, 32, 8, 8, 32, 32, 4, 2, 1},
{ 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1},
{ 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, 1},
{ 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1},
{ 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1}
// clang-format on
};
std::vector<operation::BlockTransferDesc> a_block_descriptions = {
// clang-format off
// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}
// clang-format on
};
std::vector<operation::BlockTransferDesc> b_block_descriptions = {
// clang-format off
// BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}
// clang-format on
};
std::vector<operation::CShuffleDesc> cshuffle_descriptions = {
// clang-format off
// CShuffle| CShuffle|
// MXdlPerWave| NXdlPerWave|
// PerShuffle| PerShuffle|
// | |
{ 1, 1},
{ 1, 1},
{ 1, 1},
{ 1, 1},
{ 1, 1},
{ 1, 1}
// clang-format on
};
std::vector<operation::CBlockTransferDesc> c_block_descriptions = {
// clang-format off
// CBlockTransferClusterLengths| CBlockTransfer
// _MBlock_MWaveMPerXdl| ScalarPerVector
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
// |
{ S<1, 16, 1, 4>, 1},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 16, 1, 4>, 1},
{ S<1, 32, 1, 8>, 8},
{ S<1, 16, 1, 8>, 8}
// clang-format on
};
assert(tile_descriptions.size() == a_block_descriptions.size());
assert(tile_descriptions.size() == b_block_descriptions.size());
assert(tile_descriptions.size() == cshuffle_descriptions.size());
assert(tile_descriptions.size() == c_block_descriptions.size());
// Put all values together into a single operation > store into the result vector
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
{
Operation_Conv_Fwd_Xdl_Cshuffle x;
x.NumDim = prob.NumDim;
x.tile_desc = tile_descriptions[i];
x.a_block_transfer = a_block_descriptions[i];
x.b_block_transfer = b_block_descriptions[i];
x.cshuffle = cshuffle_descriptions[i];
x.c_block_transfer = c_block_descriptions[i];
x.A = TensorDesc{prob.ADataType, prob.ALayout};
x.B = TensorDesc{prob.BDataType, prob.BLayout};
x.E = TensorDesc{prob.EDataType, prob.ELayout};
x.Ds = Transform(prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) {
return TensorDesc{dt, lo};
});
x.a_elem_op = prob.AElementOp;
x.b_elem_op = prob.BElementOp;
x.cde_elem_op = prob.CDEElementOp;
x.update_prologue(prologue);
x.update_epilogue(epilogue);
result.push_back(x);
}
return result;
}
// set up instances when not provided with a problem specification, use default operation values
std::vector<Operation_Conv_Fwd_Xdl_Cshuffle>
Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations(const std::string& prologue,
const std::string& epilogue)
{
Problem_Conv_Fwd prob;
return CreateOperations(prob, prologue, epilogue);
}
static const char* const CopyDevice_ConvTemplate =
R"(
${Prologue}
${Epilogue}
using CDEElementOp = Epilogue;
using DeviceConv = ck::tensor_operation::device::CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<${NumDim}, ${LayoutA}, ${LayoutB}, ${LayoutDs}, ${LayoutE}, ${ADataType}, ${BDataType}, ${AccDataType}, ${CShuffleDataType}, ${DsDataType}, ${EDataType}, ${AElementwiseOperation}, ${BElementwiseOperation}, ${CDEElementwiseOperation}, ${ConvSpecialization}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, ${MPerBlock}, ${NPerBlock}, ${KPerBlock}, ${AK1}, ${BK1}, ${MPerXDL}, ${NPerXDL}, ${MXdlPerWave}, ${NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, ${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, ${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, ${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, ${BBlockTransferThreadClusterLengths_BK0_N_BK1}, ${BBlockTransferThreadClusterArrangeOrder}, ${BBlockTransferSrcAccessOrder}, ${BBlockTransferSrcVectorDim}, ${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, ${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, ${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, ${CDEBlockTransferScalarPerVector_NPerBlock}>;
constexpr ck::index_t NumATensor = ck::tensor_operation::device::GetNumABTensors<false, ${ADataType}>();
constexpr ck::index_t NumBTensor = ck::tensor_operation::device::GetNumABTensors<false, ${BDataType}>();
extern "C" __global__ void run_${name}(
const ${ADataType}* in_dev,
const ${BDataType}* wei_dev,
${EDataType}* __restrict__ out_dev,
ck::Array<ck::index_t, ${NumDim} + 3> in_lengths,
ck::Array<ck::index_t, ${NumDim} + 3> in_strides,
ck::Array<ck::index_t, ${NumDim} + 3> wei_lengths,
ck::Array<ck::index_t, ${NumDim} + 3> wei_strides,
ck::Array<ck::index_t, ${NumDim} + 3> out_lengths,
ck::Array<ck::index_t, ${NumDim} + 3> out_strides,
ck::Array<ck::index_t, ${NumDim}> conv_filter_strides,
ck::Array<ck::index_t, ${NumDim}> conv_filter_dilations,
ck::Array<ck::index_t, ${NumDim}> input_left_pads,
ck::Array<ck::index_t, ${NumDim}> input_right_pads,
const ${AElementwiseOperation} a_element_op,
const ${BElementwiseOperation} b_element_op,
const ${CDEElementwiseOperation} cde_element_op
){
auto arg = DeviceConv::Argument(in_dev,
wei_dev,
ck::Array<const void*, 0>{},
out_dev,
in_lengths,
in_strides,
wei_lengths,
wei_strides,
ck::Array<ck::Array<ck::index_t, ${NumDim} + 3>, 0>{},
ck::Array<ck::Array<ck::index_t, ${NumDim} + 3>, 0>{},
out_lengths,
out_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
${AElementwiseOperation}{},
${BElementwiseOperation}{},
${CDEElementwiseOperation}{1.0f, 1.0f});
if(!DeviceConv::IsSupportedArgument(arg))
{
printf("Arguement is not supported.\n");
return;
};
constexpr ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler();
// GridwiseGemm
using GridwiseGemm = DeviceConv::GridwiseGemm;
static constexpr auto I0 = ck::Number<0>{};
ck::tensor_operation::device::device_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
GridwiseGemm,
const ${ADataType}*,
const ${BDataType}*,
typename GridwiseGemm::DsGridPointer,
${EDataType},
${AElementwiseOperation},
${BElementwiseOperation},
${CDEElementwiseOperation},
DeviceConv::AGridDesc_AK0_M_AK1,
DeviceConv::BGridDesc_BK0_N_BK1,
DeviceConv::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceConv::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceConv::Block2ETileMap,
ck::tensor_operation::device::ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, 0>,
ck::integral_constant<bool, true>{},
false,
false>
(
arg.p_as_grid_.At(I0),
arg.p_bs_grid_.At(I0),
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_batch_
);
}
)";
// use hardcoded instances from vector of operations to substitute values into instance template
Solution Operation_Conv_Fwd_Xdl_Cshuffle::ToSolution() const
{
std::unordered_map<std::string, std::string> values = {
{"name",
std::to_string(this->tile_desc.block_size) + "_" +
std::to_string(this->tile_desc.m_per_block) + "_" +
std::to_string(this->tile_desc.n_per_block) + "_" +
std::to_string(this->tile_desc.k_per_block) + "_" +
std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" +
std::to_string(this->tile_desc.m_per_XDL) + "_" +
std::to_string(this->tile_desc.n_per_XDL) + "_" +
std::to_string(this->tile_desc.m_Xdl_per_wave) + "_" +
std::to_string(this->tile_desc.n_Xdl_per_wave)},
{"NumDim", std::to_string(this->NumDim)},
{"LayoutA", ToString(this->A.layout)},
{"LayoutB", ToString(this->B.layout)},
{"LayoutDs",
MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.layout); }))},
{"LayoutE", ToString(this->E.layout)},
{"ADataType", ToString(this->A.element)},
{"BDataType", ToString(this->B.element)},
{"AccDataType", ToString(this->acc)},
{"ComputeDataType", ToString(this->A.element)},
{"CShuffleDataType", ToString(this->cs_type)},
{"DsDataType",
MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.element); }))},
{"EDataType", ToString(this->E.element)},
{"AElementwiseOperation", this->a_elem_op},
{"BElementwiseOperation", this->b_elem_op},
{"CDEElementwiseOperation", this->cde_elem_op},
{"Prologue", this->prologue},
{"Epilogue", this->epilogue},
{"ConvSpecialization", this->conv_specialization},
{"GemmSpecialization", this->gemm_specialization},
{"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)},
{"BlockSize", std::to_string(this->tile_desc.block_size)},
{"MPerBlock", std::to_string(this->tile_desc.m_per_block)},
{"NPerBlock", std::to_string(this->tile_desc.n_per_block)},
{"KPerBlock", std::to_string(this->tile_desc.k_per_block)},
{"AK1", std::to_string(this->tile_desc.ak1)},
{"BK1", std::to_string(this->tile_desc.bk1)},
{"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)},
{"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)},
{"MXdlPerWave", std::to_string(this->tile_desc.m_Xdl_per_wave)},
{"NXdlPerWave", std::to_string(this->tile_desc.n_Xdl_per_wave)},
{"ABlockTransferThreadClusterLengths_AK0_M_AK1",
this->a_block_transfer.thread_cluster_length},
{"ABlockTransferThreadClusterArrangeOrder",
this->a_block_transfer.thread_cluster_arrange_order},
{"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order},
{"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)},
{"ABlockTransferSrcScalarPerVector",
std::to_string(this->a_block_transfer.src_scalar_per_vector)},
{"ABlockTransferDstScalarPerVector_AK1",
std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)},
{"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)},
{"BBlockTransferThreadClusterLengths_BK0_N_BK1",
this->b_block_transfer.thread_cluster_length},
{"BBlockTransferThreadClusterArrangeOrder",
this->b_block_transfer.thread_cluster_arrange_order},
{"BBlockTransferSrcAccessOrder", this->b_block_transfer.src_access_order},
{"BBlockTransferSrcVectorDim", std::to_string(this->b_block_transfer.src_vec_dim)},
{"BBlockTransferSrcScalarPerVector",
std::to_string(this->b_block_transfer.src_scalar_per_vector)},
{"BBlockTransferDstScalarPerVector_BK1",
std::to_string(this->b_block_transfer.dst_scalar_per_vector_k1)},
{"BBlockLdsExtraN", std::to_string(this->b_block_transfer.lds_add_extra_dim)},
{"CShuffleMXdlPerWavePerShuffle",
std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)},
{"CShuffleNXdlPerWavePerShuffle",
std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)},
{"CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock",
this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
{"CDEBlockTransferScalarPerVector_NPerBlock",
std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
};
return Solution{InterpolateString(CopyDevice_ConvTemplate, values), std::move(values)};
}
} // namespace conv
} // namespace host
} // namespace ck
#include "ck/host/headers.hpp"
#include "ck_headers.hpp"
namespace ck {
namespace host {
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
const std::string config_header = "";
#pragma clang diagnostic pop
std::unordered_map<std::string_view, std::string_view> GetHeaders()
{
auto headers = ck_headers();
headers.insert(std::make_pair("ck/config.h", config_header));
return headers;
}
} // namespace host
} // namespace ck
#include "ck/host/types.hpp"
#include "ck/host/stringutils.hpp"
#include <algorithm>
#include <stdexcept>
namespace ck {
namespace host {
Solution::Solution(std::string str, std::unordered_map<std::string, std::string> values)
: template_str(std::move(str)), template_values(std::move(values))
{
}
std::string Solution::ToTemplateString() const { return this->template_str; }
std::string Solution::GetTemplateParameter(const std::string& name) const
{
return this->template_values.at(name);
}
std::string ToString(DataType dt)
{
switch(dt)
{
case DataType::Float: return "float";
case DataType::Half: return "ck::half_t";
case DataType::Int8: return "int8_t";
case DataType::Int32: return "int32_t";
}
throw std::runtime_error("Incorrect data type");
}
Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
std::string ToString(Layout dl)
{
switch(dl)
{
case Layout::Row: return "ck::tensor_layout::gemm::RowMajor";
case Layout::Column: return "ck::tensor_layout::gemm::ColumnMajor";
case Layout::GKCYX: return "ck::tensor_layout::convolution::GKCYX";
case Layout::GKYXC: return "ck::tensor_layout::convolution::GKYXC";
case Layout::GNHWK: return "ck::tensor_layout::convolution::GNHWK";
case Layout::GNHWC: return "ck::tensor_layout::convolution::GNHWC";
case Layout::NHWGC: return "ck::tensor_layout::convolution::NHWGC";
case Layout::NHWGK: return "ck::tensor_layout::convolution::NHWGK";
}
throw std::runtime_error("Incorrect layout");
}
std::string ToString(GemmType gt)
{
switch(gt)
{
case GemmType::Default: return "ck::tensor_operation::device::GemmSpecialization::Default";
}
throw std::runtime_error("Incorrect gemm type");
}
std::string SequenceStr(const std::vector<int>& v)
{
return "ck::Sequence<" +
JoinStrings(Transform(v, [](int x) { return std::to_string(x); }), ", ") + ">";
}
std::string MakeTuple(const std::vector<std::string>& v)
{
return "ck::Tuple<" + JoinStrings(v, ", ") + ">";
}
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/utils.hpp"
namespace ck {
namespace host {
std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
{
return (x + y - std::size_t{1}) / y;
}
const std::unordered_set<std::string>& get_xdlop_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx90a", "gfx908", "gfx940", "gfx942"};
return supported_archs;
}
} // namespace host
} // namespace ck
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
add_subdirectory(rtc)
file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp)
# TODO: These tests need to be refactored to remove dependency on main ck
# headers and device compilation.
set(TESTS_REQUIRE_DEVICE_COMPILE
grouped_conv_fwd_multiple_d_v1
grouped_conv_fwd_multiple_d_v2
grouped_conv_fwd_multiple_d_v3
grouped_conv_fwd_multiple_d_v4
)
find_package(hip)
foreach(TEST_SRC ${TEST_SRCS})
get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE)
rocm_add_test_executable(codegen_test_${BASE_NAME} ${TEST_SRC})
target_link_libraries(codegen_test_${BASE_NAME} ck_rtc ck_host)
target_include_directories(codegen_test_${BASE_NAME} PUBLIC include)
if(BASE_NAME IN_LIST TESTS_REQUIRE_DEVICE_COMPILE)
target_link_libraries(codegen_test_${BASE_NAME} hip::device)
target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/include)
target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/library/include)
endif()
endforeach()
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
#include <cmath>
#include <iterator>
#include <random>
#include <test.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <fstream>
using half = _Float16;
// using half = __fp16;
std::vector<rtc::src_file> get_headers_for_test()
{
std::vector<rtc::src_file> result;
auto hs = ck::host::GetHeaders();
std::transform(
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
return {p.first, p.second};
});
return result;
}
template <class T>
rtc::buffer<T> generate_buffer(std::size_t n, std::size_t seed = 0)
{
rtc::buffer<T> result(n);
std::mt19937 gen(seed);
std::uniform_real_distribution<double> dis(-1.0);
std::generate(result.begin(), result.end(), [&] { return dis(gen); });
return result;
}
template <class T, class U>
bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01)
{
return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) {
return fabs(x - y) < atol + rtol * fabs(y);
});
}
std::string classify(double x)
{
switch(std::fpclassify(x))
{
case FP_INFINITE: return "inf";
case FP_NAN: return "nan";
case FP_NORMAL: return "normal";
case FP_SUBNORMAL: return "subnormal";
case FP_ZERO: return "zero";
default: return "unknown";
}
}
template <class Buffer>
void print_classification(const Buffer& x)
{
std::unordered_set<std::string> result;
for(const auto& i : x)
result.insert(classify(i));
for(const auto& c : result)
std::cout << c << ", ";
std::cout << std::endl;
}
template <class Buffer>
void print_statistics(const Buffer& x)
{
std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", ";
std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", ";
double num_elements = x.size();
auto mean =
std::accumulate(x.begin(), x.end(), double{0.0}, std::plus<double>{}) / num_elements;
auto stddev = std::sqrt(
std::accumulate(x.begin(),
x.end(),
double{0.0},
[&](double r, double v) { return r + std::pow((v - mean), 2.0); }) /
num_elements);
std::cout << "Mean: " << mean << ", ";
std::cout << "StdDev: " << stddev << "\n";
}
template <class Buffer>
void print_preview(const Buffer& x)
{
if(x.size() <= 10)
{
std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i << ", "; });
}
else
{
std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; });
std::cout << "..., ";
std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; });
}
std::cout << std::endl;
}
template <class T>
struct check_all
{
rtc::buffer<T> data{};
bool operator()(const rtc::buffer<T>& x)
{
if(data.empty())
{
data = x;
return true;
}
if(std::any_of(x.begin(), x.end(), [](double y) { return std::isnan(y); }))
return false;
return allclose(data, x);
}
};
template <class Solution>
auto report(const Solution& solution, bool pass)
{
return test::make_predicate(solution.ToTemplateString(), [=] { return pass; });
}
const std::string gemm_compile_check = R"__ck__(
#include <${include}>
extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) {
using G = ${template};
constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})),
ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})),
ck::make_tuple(),
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n})));
static_assert(desc.IsValid(), "Invalid ck gemm.");
if constexpr(desc.IsValid())
{
${template}::Run(desc,
a,
b,
ck::make_tuple(),
c);
}
}
)__ck__";
TEST_CASE(test_problem_kernel)
{
ck::host::device_gemm_multiple_d::Problem prob;
prob.M = 1024;
prob.N = 1024;
prob.K = 1024;
check_all<half> check;
auto a = to_gpu(generate_buffer<half>(1024 * 1024, 0));
auto b = to_gpu(generate_buffer<half>(1024 * 1024, 1));
auto c = to_gpu(generate_buffer<half>(1024 * 1024, 2));
std::string epilogue = "";
std::string prologue = "";
for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue))
{
auto src = ck::host::InterpolateString(gemm_compile_check,
{{"include", prob.GetIncludeHeader()},
{"template", solution.ToTemplateString()},
{"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)}});
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
rtc::compile_options options;
options.kernel_name = "f";
auto k = rtc::compile_kernel(srcs, options);
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
ck::host::integer_divide_ceil(prob.N, n_per_block);
k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data());
CHECK(report(solution, check(rtc::from_gpu(c))));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
#include "ck/tensor_operation/gpu/device/helper.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include <test.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include "common.hpp"
#include <fstream>
// Need this for verification
/**struct Epilogue
{
Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename D>
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
}
float alpha_;
float beta_;
};**/
const std::string conv_compile_check = R"__ck__(
#include <${include}>
${template};
)__ck__";
TEST_CASE(test_problem_kernel)
{
// set up problem specification
ck::host::conv::Problem_Conv_Fwd prob;
prob.NumDim = 2;
prob.G = 32;
prob.N = 256;
prob.C = 32;
prob.K = 64;
prob.Y = 3;
prob.X = 3;
prob.Hi = 28;
prob.Wi = 28;
prob.Ho = 28;
prob.Wo = 28;
check_all<ck::half_t> check;
// user provided fusion operations
std::string epilogue = R"(
struct Epilogue
{
__host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename D>
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
}
float alpha_;
float beta_;
};
)";
std::string prologue = "";
// length+stride arrays
ck::Array<ck::index_t, 5> in_lengths{static_cast<int>(prob.G),
static_cast<int>(prob.N),
static_cast<int>(prob.C),
static_cast<int>(prob.Hi),
static_cast<int>(prob.Wi)};
ck::Array<ck::index_t, 5> out_lengths{static_cast<int>(prob.G),
static_cast<int>(prob.N),
static_cast<int>(prob.K),
static_cast<int>(prob.Ho),
static_cast<int>(prob.Wo)};
ck::Array<ck::index_t, 5> wei_lengths{static_cast<int>(prob.G),
static_cast<int>(prob.K),
static_cast<int>(prob.C),
static_cast<int>(prob.Y),
static_cast<int>(prob.X)};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
1,
static_cast<int>(prob.Wi * prob.G * prob.C),
static_cast<int>(prob.G * prob.C)};
ck::Array<ck::index_t, 5> out_strides{static_cast<int>(prob.K),
static_cast<int>(prob.Ho * prob.Wo * prob.G * prob.K),
1,
static_cast<int>(prob.Wo * prob.G * prob.K),
static_cast<int>(prob.G * prob.K)};
ck::Array<ck::index_t, 5> wei_strides{static_cast<int>(prob.K * prob.Y * prob.X * prob.C),
static_cast<int>(prob.Y * prob.X * prob.C),
1,
static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)};
ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
ck::Array<ck::index_t, 2> input_left_pads = {1, 1};
ck::Array<ck::index_t, 2> input_right_pads = {1, 1};
// move the data onto the device
auto in_dev =
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(in_lengths, in_strides, 0));
auto wei_dev =
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(wei_lengths, wei_strides, 1));
auto out_dev =
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(out_lengths, out_strides, 2));
// CK Verficiation: Reference Kernel
/**bool pass = true;
Tensor<ck::half_t> in_host(in_lengths, in_strides);
in_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
Tensor<ck::half_t> wei_host(wei_lengths, wei_strides);
wei_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
Tensor<ck::half_t> out_host(out_lengths, out_strides);
std::vector<ck::index_t> conv_filter_strides_ = {2, 2};
std::vector<ck::index_t> conv_filter_dilations_ = {1, 1};
std::vector<ck::index_t> input_left_pads_ = {1, 1};
std::vector<ck::index_t> input_right_pads_ = {1, 1};
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
2,
ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Epilogue>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_host,
wei_host,
out_host,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
Epilogue{1.0f, 1.0f});
out_host.SetZero();
ref_invoker.Run(ref_argument);**/
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
{
// substitute instance values into the template
auto src = ck::host::InterpolateString(
conv_compile_check,
{{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}});
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
rtc::compile_options options;
auto name = solution.GetTemplateParameter<std::string>("name");
options.kernel_name = "run_" + name;
auto k = rtc::compile_kernel(srcs, options);
// Grid size calculation
auto block_size = solution.GetTemplateParameter<ck::index_t>("BlockSize");
auto tmp = get_launch_params(solution, out_lengths, out_strides);
auto grid_size = tmp * in_lengths[1];
// launch the kernel with arguments needed for the argument pointer
k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(),
wei_dev.data(),
out_dev.data(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
out_lengths,
out_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
// auto res = rtc::from_gpu(out_dev);
// pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
// assert(pass);
// Simple check: this checks that the output from each instance matches the output from the
// first instance
CHECK(report(solution, check(rtc::from_gpu(out_dev))));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment