Unverified Commit b5ca008d authored by Mirza Halilčević's avatar Mirza Halilčević Committed by GitHub
Browse files

Introduce gemm_softmax_gemm to codegen (#1542)



* Introduce ck_host library and gemm_softmax_gemm.

* Minor refactor.

* Add descriptor to gemm_softmax_gemm.

* Bugfix.

* Revert ck_host library.

* fix clang format

---------
Co-authored-by: default avatarIllia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: default avatarillsilin <Illia.Silin@amd.com>
parent c0adab48
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// defines all values need for an instance of fwd conv
struct Operation_Xdl_CShuffle
{
// returns a vector of instances, only given fusion operators: will use default problem spec
static std::vector<std::vector<Operation_Xdl_CShuffle>>
CreateOperations(const std::string& prologue, const std::string& epilogue);
// returns a vector of instances, given a problem spec and fusion operators
static std::vector<Operation_Xdl_CShuffle>
CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
TensorDesc A{};
TensorDesc B{};
TensorDesc B1{};
TensorDesc C{};
DataType acc = DataType::Float;
DataType cs_type = DataType::Half;
std::string a_elem_op = PassThrough;
std::string b_elem_op = PassThrough;
std::string b1_elem_op = PassThrough;
std::string c_elem_op = PassThrough;
std::string acc_elem_op = Scale;
std::string prologue = "";
std::string epilogue = "";
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
// tuning parameters
operation::TileDescGemmGemm tile_desc{};
operation::BlockTransferDesc a_block_transfer{};
operation::BlockTransferDesc b0_block_transfer{};
operation::BlockTransferDesc b1_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
bool mask_out_upper_triangle = false;
// functions to update fusion operators if provided
void update_prologue(const std::string& prologue);
void update_epilogue(const std::string& epilogue);
/**constexpr**/ bool
IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_, std::size_t Gemm1NRaw_);
// returns a templated instance
Solution ToSolution() const;
};
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// defines the problem specification for a GEMM operation
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
std::size_t O = 0;
bool TransA = false;
bool TransB = false;
bool TransB1 = false;
bool TransC = false;
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType B1DataType = DataType::Half;
DataType CDataType = DataType::Half;
std::string AElementOp = PassThrough;
std::string BElementOp = PassThrough;
std::string B1ElementOp = PassThrough;
std::string CElementOp = PassThrough;
std::string AccElementOp = Scale;
// returns the correct device op file for the operation
std::string GetIncludeHeader() const;
// returns a list of instances based on the problem spec and provided fusion operations
std::vector<Solution> GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const;
};
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
......@@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle
operation::BlockTransferDesc b_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
LoopScheduler loop_scheduler{};
PipelineVersion pipeline_version{};
// functions to update fusion operators if provided
void update_prologue(const std::string& prologue);
......
......@@ -23,6 +23,26 @@ struct TileDesc
int n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0;
};
struct TileDescGemmGemm
{
int block_size = 0;
int gemm01_m_per_block = 0;
int gemm0_n_per_block = 0;
int gemm0_k_per_block = 0;
int gemm1_n_per_block = 0;
int gemm1_k_per_block = 0;
int ak1 = 0;
int bk1 = 0;
int b1k1 = 0;
int m_per_XDL = 0;
int n_per_XDL = 0;
int gemm0_m_Xdl_per_wave = 0;
int gemm0_n_Xdl_per_wave = 0;
int gemm1_n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0;
};
struct BlockTransferDesc
{
std::string thread_cluster_length = "";
......
......@@ -66,6 +66,20 @@ enum class GemmType
};
std::string ToString(GemmType gt);
enum class LoopScheduler
{
Default,
Interwave,
};
std::string ToString(LoopScheduler ls);
enum class PipelineVersion
{
v1,
v2
};
std::string ToString(PipelineVersion pv);
struct TensorDesc
{
DataType element;
......@@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...});
constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
constexpr const char* Scale = "ck::tensor_operation::element_wise::Scale";
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// return the relevant device op file based on the operation
std::string Problem::GetIncludeHeader() const
{
return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp";
}
// returns templated instances when provided with a problem specification
std::vector<Solution> Problem::GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const
{
if(get_xdlop_archs().count(arch) == 0)
return {};
auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations(
*this, prologue, epilogue); // obtains vector of instances
std::vector<Solution> result;
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
return op.ToSolution(); // template instance with correct values
});
return result;
}
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
#include <cassert>
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// calculate appropriate Gemm Specification based on input tensor dimensions
std::string GetGemmSpec(const std::size_t m,
const std::size_t n,
const std::size_t k,
const std::size_t n1,
const std::size_t m_per_block,
const std::size_t n_per_block,
const std::size_t k_per_block,
const std::size_t n1_per_block)
{
std::string spec = "";
if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
spec += "M";
if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
spec += "N";
if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
spec += "K";
if(integer_divide_ceil(n1, n1_per_block) * n1_per_block - n1 != 0)
spec += "O";
if(spec == "")
return "ck::tensor_operation::device::GemmSpecialization::Default";
return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
}
// function to update prologue/epilogue with user provided operation
void Operation_Xdl_CShuffle::update_prologue(const std::string& pro)
{
if(!prologue.empty())
{
this->prologue = pro;
}
else
{
this->prologue = "";
}
}
void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
{
if(!epilogue.empty())
{
this->epilogue = epi;
}
else
{
this->epilogue = "";
}
}
// accounts for all possible combinations of Row/Col major
static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// instances
std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
const Problem& prob, const std::string& prologue, const std::string& epilogue)
{
std::vector<Operation_Xdl_CShuffle> result;
std::vector<operation::TileDescGemmGemm> tile_descriptions = {
// clang-format off
// Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| NumGemmK|
// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch|
// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage|
// | | | | | | | | | | | Wave| Wave| Wave| |
{ 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1},
{ 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1},
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1},
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1},
{ 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
{ 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
{ 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
{ 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
{ 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
{ 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
{ 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
// Padded fallback kernel
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
{ 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1},
// Irregular k
{ 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1},
{ 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1},
{ 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1},
{ 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1},
{ 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1},
{ 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1},
// clang-format on
};
const std::vector<operation::BlockTransferDesc> a_block_descriptions = {
// clang-format off
// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
// Padded fallback kernel
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
// Irregular k
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
// clang-format on
};
const std::vector<operation::BlockTransferDesc> b1_block_descriptions = {
// clang-format off
// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
// Padded fallback kernel
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
// Irregular k
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
// clang-format on
};
std::vector<operation::CShuffleDesc> cshuffle_descriptions = {
// clang-format off
// CShuffle| CShuffle|
// MXdlPerWave| NXdlPerWave|
// PerShuffle| PerShuffle|
// | |
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 8},
{ 1, 4},
{ 1, 8},
{ 1, 4},
// Padded fallback kernel
{ 1, 2},
{ 1, 2},
// Irregular k
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
// clang-format on
};
std::vector<operation::CBlockTransferDesc> c_block_descriptions = {
// clang-format off
// CBlockTransferClusterLengths| CBlockTransfer
// _MBlock_MWaveMPerXdl| ScalarPerVector
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
// |
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 16, 1,16>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 16, 1,16>, 8},
{ S<1, 32, 1, 8>, 8},
// Padded fallback kernel
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
// Irregular k
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
// clang-format on
};
assert(tile_descriptions.size() == a_block_descriptions.size());
assert(tile_descriptions.size() == b1_block_descriptions.size());
assert(tile_descriptions.size() == cshuffle_descriptions.size());
assert(tile_descriptions.size() == c_block_descriptions.size());
// Put all values together into a single operation > store into the result vector
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
{
Operation_Xdl_CShuffle x;
x.tile_desc = tile_descriptions[i];
x.a_block_transfer = a_block_descriptions[i];
x.b0_block_transfer = a_block_descriptions[i]; // b0 same as a
x.b1_block_transfer = b1_block_descriptions[i];
x.cshuffle = cshuffle_descriptions[i];
x.c_block_transfer = c_block_descriptions[i];
x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
x.B1 = TensorDesc{prob.B1DataType, ToLayout(prob.TransB1)};
x.C = TensorDesc{prob.CDataType, ToLayout(prob.TransC)};
x.a_elem_op = prob.AElementOp;
x.b_elem_op = prob.BElementOp;
x.b1_elem_op = prob.B1ElementOp;
x.c_elem_op = prob.CElementOp;
x.acc_elem_op = prob.AccElementOp;
x.gemm_specialization = GetGemmSpec(prob.M,
prob.N,
prob.K,
prob.O,
x.tile_desc.gemm01_m_per_block,
x.tile_desc.gemm0_n_per_block,
x.tile_desc.gemm0_k_per_block,
x.tile_desc.gemm1_n_per_block);
x.update_prologue(prologue);
x.update_epilogue(epilogue);
x.mask_out_upper_triangle = true;
result.push_back(x);
x.mask_out_upper_triangle = false;
result.push_back(x);
}
return result;
}
// set up instances when not provided with a problem specification, use default operation values and
// all possible layout combinations
std::vector<std::vector<Operation_Xdl_CShuffle>>
Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue)
{
Problem prob;
prob.TransA = false;
prob.TransB = true;
prob.TransB1 = false;
prob.TransC = false;
return {CreateOperations(prob, prologue, epilogue)};
}
static const char* const DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate =
"ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<${LayoutA}, "
"${LayoutB0}, ${LayoutB1}, ${LayoutC}, ${ADataType}, ${B0DataType}, ${B1DataType}, "
"${CDataType}, ${AccDataType}, ${CShuffleDataType}, ${AElementwiseOperation}, "
"${B0ElementwiseOperation}, ${Acc0ElementwiseOperation}, ${B1ElementwiseOperation}, "
"${CElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, "
"${Gemm01MPerBlock}, ${Gemm0NPerBlock}, ${Gemm0KPerBlock}, ${Gemm1NPerBlock}, "
"${Gemm1KPerBlock}, ${AK1}, ${BK1}, ${B1K1}, ${MPerXDL}, ${NPerXDL}, ${Gemm0MXdlPerWave}, "
"${Gemm0NXdlPerWave}, ${Gemm1NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, "
"${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, "
"${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, "
"${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, "
"${B0BlockTransferThreadClusterLengths_BK0_N_BK1}, "
"${B0BlockTransferThreadClusterArrangeOrder}, ${B0BlockTransferSrcAccessOrder}, "
"${B0BlockTransferSrcVectorDim}, ${B0BlockTransferSrcScalarPerVector}, "
"${B0BlockTransferDstScalarPerVector_BK1}, ${B0BlockLdsExtraN}, "
"${B1BlockTransferThreadClusterLengths_BK0_N_BK1}, "
"${B1BlockTransferThreadClusterArrangeOrder}, ${B1BlockTransferSrcAccessOrder}, "
"${B1BlockTransferSrcVectorDim}, ${B1BlockTransferSrcScalarPerVector}, "
"${B1BlockTransferDstScalarPerVector_BK1}, ${B1BlockLdsExtraN}, "
"${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
"${CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl}, "
"${CBlockTransferScalarPerVector_NWaveNPerXdl}, ${MaskOutUpperTriangle}>";
// use hardcoded instances from vector of operations to substitute values into instance template
Solution Operation_Xdl_CShuffle::ToSolution() const
{
std::unordered_map<std::string, std::string> values = {
{"name",
std::to_string(this->tile_desc.block_size) + "_" +
std::to_string(this->tile_desc.gemm01_m_per_block) + "_" +
std::to_string(this->tile_desc.gemm0_n_per_block) + "_" +
std::to_string(this->tile_desc.gemm0_k_per_block) + "_" +
std::to_string(this->tile_desc.gemm1_n_per_block) + "_" +
std::to_string(this->tile_desc.gemm1_k_per_block) + "_" +
std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" +
std::to_string(this->tile_desc.b1k1) + "_" +
std::to_string(this->tile_desc.m_per_XDL) + "_" +
std::to_string(this->tile_desc.n_per_XDL) + "_" +
std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" +
std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave) + "_" +
std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)},
{"LayoutA", ToString(this->A.layout)},
{"LayoutB0", ToString(this->B.layout)},
{"LayoutB1", ToString(this->B1.layout)},
{"LayoutC", ToString(this->C.layout)},
{"ADataType", ToString(this->A.element)},
{"B0DataType", ToString(this->B.element)},
{"B1DataType", ToString(this->B1.element)},
{"CDataType", ToString(this->C.element)},
{"AccDataType", ToString(this->acc)},
{"CShuffleDataType", ToString(this->cs_type)},
{"AElementwiseOperation", this->a_elem_op},
{"B0ElementwiseOperation", this->b_elem_op},
{"Acc0ElementwiseOperation", this->acc_elem_op},
{"B1ElementwiseOperation", this->b1_elem_op},
{"CElementwiseOperation", this->c_elem_op},
{"GemmSpecialization", this->gemm_specialization},
{"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)},
{"BlockSize", std::to_string(this->tile_desc.block_size)},
{"Gemm01MPerBlock", std::to_string(this->tile_desc.gemm01_m_per_block)},
{"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)},
{"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)},
{"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)},
{"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)},
{"AK1", std::to_string(this->tile_desc.ak1)},
{"BK1", std::to_string(this->tile_desc.bk1)},
{"B1K1", std::to_string(this->tile_desc.b1k1)},
{"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)},
{"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)},
{"Gemm0MXdlPerWave", std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave)},
{"Gemm0NXdlPerWave", std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave)},
{"Gemm1NXdlPerWave", std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)},
{"ABlockTransferThreadClusterLengths_AK0_M_AK1",
this->a_block_transfer.thread_cluster_length},
{"ABlockTransferThreadClusterArrangeOrder",
this->a_block_transfer.thread_cluster_arrange_order},
{"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order},
{"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)},
{"ABlockTransferSrcScalarPerVector",
std::to_string(this->a_block_transfer.src_scalar_per_vector)},
{"ABlockTransferDstScalarPerVector_AK1",
std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)},
{"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)},
{"B0BlockTransferThreadClusterLengths_BK0_N_BK1",
this->b0_block_transfer.thread_cluster_length},
{"B0BlockTransferThreadClusterArrangeOrder",
this->b0_block_transfer.thread_cluster_arrange_order},
{"B0BlockTransferSrcAccessOrder", this->b0_block_transfer.src_access_order},
{"B0BlockTransferSrcVectorDim", std::to_string(this->b0_block_transfer.src_vec_dim)},
{"B0BlockTransferSrcScalarPerVector",
std::to_string(this->b0_block_transfer.src_scalar_per_vector)},
{"B0BlockTransferDstScalarPerVector_BK1",
std::to_string(this->b0_block_transfer.dst_scalar_per_vector_k1)},
{"B0BlockLdsExtraN", std::to_string(this->b0_block_transfer.lds_add_extra_dim)},
{"B1BlockTransferThreadClusterLengths_BK0_N_BK1",
this->b1_block_transfer.thread_cluster_length},
{"B1BlockTransferThreadClusterArrangeOrder",
this->b1_block_transfer.thread_cluster_arrange_order},
{"B1BlockTransferSrcAccessOrder", this->b1_block_transfer.src_access_order},
{"B1BlockTransferSrcVectorDim", std::to_string(this->b1_block_transfer.src_vec_dim)},
{"B1BlockTransferSrcScalarPerVector",
std::to_string(this->b1_block_transfer.src_scalar_per_vector)},
{"B1BlockTransferDstScalarPerVector_BK1",
std::to_string(this->b1_block_transfer.dst_scalar_per_vector_k1)},
{"B1BlockLdsExtraN", std::to_string(this->b1_block_transfer.lds_add_extra_dim)},
{"CShuffleMXdlPerWavePerShuffle",
std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)},
{"CShuffleNXdlPerWavePerShuffle",
std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)},
{"CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl",
this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
{"CBlockTransferScalarPerVector_NWaveNPerXdl",
std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
{"MaskOutUpperTriangle", std::to_string(this->mask_out_upper_triangle)},
};
return Solution{InterpolateString(DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate, values),
std::move(values)};
}
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
......@@ -62,6 +62,12 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
// accounts for all possible combinations of Row/Col major
static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
// clang-format off
// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1,
// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
// clang-format on
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// instances
std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
......@@ -83,6 +89,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1},
{ 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1},
{ 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1},
// Irregular tile
{ 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1},
// clang-format on
};
......@@ -100,6 +108,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
// Irregular tile
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
// clang-format on
};
......@@ -109,15 +119,17 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
// Irregular tile
{ S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
// clang-format on
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
};
std::vector<operation::BlockTransferDesc> b_block_descriptions_rowmajor = {
......@@ -134,6 +146,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
// Irregular tile
{ S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
// clang-format on
};
......@@ -151,6 +165,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
// Irregular tile
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
// clang-format on
};
......@@ -167,6 +183,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ 1, 1},
{ 1, 1},
{ 1, 1},
{ 1, 1},
{ 1, 1},
// clang-format on
};
......@@ -185,6 +202,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ S<1, 16, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
// Irregular tile
{ S<1, 16, 1, 4>, 1},
// clang-format on
};
......@@ -199,33 +218,44 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
assert(tile_descriptions.size() == cshuffle_descriptions.size());
assert(tile_descriptions.size() == c_block_descriptions.size());
// Put all values together into a single operation > store into the result vector
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
const std::vector<std::tuple<LoopScheduler, PipelineVersion>> scheduler_pipeline_descriptions =
{
{LoopScheduler::Default, PipelineVersion::v1},
{LoopScheduler::Interwave, PipelineVersion::v1},
{LoopScheduler::Default, PipelineVersion::v2},
};
for(auto [loop_scheduler, pipeline_version] : scheduler_pipeline_descriptions)
{
Operation_Xdl_CShuffle x;
x.tile_desc = tile_descriptions[i];
x.a_block_transfer = a_block_descriptions[i];
x.b_block_transfer = b_block_descriptions[i];
x.cshuffle = cshuffle_descriptions[i];
x.c_block_transfer = c_block_descriptions[i];
x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
return TensorDesc{dt, ToLayout(trans)};
});
x.a_elem_op = prob.AElementOp;
x.b_elem_op = prob.BElementOp;
x.cde_elem_op = prob.CDEElementOp;
x.gemm_specialization = GetGemmSpec(prob.M,
prob.N,
prob.K,
x.tile_desc.m_per_block,
x.tile_desc.n_per_block,
x.tile_desc.k_per_block);
x.update_prologue(prologue);
x.update_epilogue(epilogue);
result.push_back(x);
// Put all values together into a single operation > store into the result vector
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
{
Operation_Xdl_CShuffle x;
x.tile_desc = tile_descriptions[i];
x.a_block_transfer = a_block_descriptions[i];
x.b_block_transfer = b_block_descriptions[i];
x.cshuffle = cshuffle_descriptions[i];
x.c_block_transfer = c_block_descriptions[i];
x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
return TensorDesc{dt, ToLayout(trans)};
});
x.a_elem_op = prob.AElementOp;
x.b_elem_op = prob.BElementOp;
x.cde_elem_op = prob.CDEElementOp;
x.gemm_specialization = GetGemmSpec(prob.M,
prob.N,
prob.K,
x.tile_desc.m_per_block,
x.tile_desc.n_per_block,
x.tile_desc.k_per_block);
x.loop_scheduler = loop_scheduler;
x.pipeline_version = pipeline_version;
x.update_prologue(prologue);
x.update_epilogue(epilogue);
result.push_back(x);
}
}
return result;
}
......@@ -263,7 +293,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
"${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, "
"${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
"${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
"${CDEBlockTransferScalarPerVector_NPerBlock}>";
"${CDEBlockTransferScalarPerVector_NPerBlock}, ${LoopScheduler}, ${PipelineVersion}>";
// use hardcoded instances from vector of operations to substitute values into instance template
Solution Operation_Xdl_CShuffle::ToSolution() const
......@@ -336,6 +366,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
{"CDEBlockTransferScalarPerVector_NPerBlock",
std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
{"LoopScheduler", ToString(this->loop_scheduler)},
{"PipelineVersion", ToString(this->pipeline_version)},
};
return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values),
......
......@@ -59,6 +59,26 @@ std::string ToString(GemmType gt)
throw std::runtime_error("Incorrect gemm type");
}
std::string ToString(LoopScheduler ls)
{
switch(ls)
{
case LoopScheduler::Default: return "ck::LoopScheduler::Default";
case LoopScheduler::Interwave: return "ck::LoopScheduler::Interwave";
}
throw std::runtime_error("Incorrect LoopScheduler type");
}
std::string ToString(PipelineVersion pv)
{
switch(pv)
{
case PipelineVersion::v1: return "ck::PipelineVersion::v1";
case PipelineVersion::v2: return "ck::PipelineVersion::v2";
}
throw std::runtime_error("Incorrect PipelineVersion type");
}
std::string SequenceStr(const std::vector<int>& v)
{
return "ck::Sequence<" +
......
......@@ -8,6 +8,7 @@
#include <memory>
#include <stdexcept>
#include <string>
#include <stdexcept>
namespace rtc {
......
......@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
typename ALayout, typename BLayout, typename CLayout>
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
......@@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_B = stride_B;
args.stride_C = stride_C;
float ave_time = gemm_calc<ADataType, BDataType, AccDataType, CDataType,
ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
float ave_time =
gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
......@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name
<< " C_Layout =" << CLayout::name
<< " A Type = " << DataTypeTraits<ADataType>::name
<< " B Type = " << DataTypeTraits<BDataType>::name
<< " C Type = " << DataTypeTraits<CDataType>::name
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
<< " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
<< " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
<< " B Type = " << DataTypeTraits<BDataType>::name
<< " C Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
......@@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
if(!result)
return -1;
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
......@@ -133,19 +135,19 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ADataType, BDataType, AccDataType, CDataType,
ALayout, BLayout, CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
invoke_gemm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
......@@ -160,9 +162,9 @@ int run_gemm_example_with_layouts(int argc,
a_m_k, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
......@@ -218,9 +220,9 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_gpu_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
......
......@@ -610,6 +610,96 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return true;
}
static constexpr bool
IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
{
// check vector load/store
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// check vector load of A
if constexpr(is_same_v<ALayout, Row>)
{
if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<ALayout, Col>)
{
if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of B
if constexpr(is_same_v<BLayout, Row>)
{
if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<BLayout, Col>)
{
if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of B1
if constexpr(is_same_v<B1Layout, Row>)
{
if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<B1Layout, Col>)
{
if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of C
if constexpr(is_same_v<CLayout, Row>)
{
if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else if constexpr(is_same_v<CLayout, Col>)
{
if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
return false;
}
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
......@@ -624,29 +714,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
// Check scalar per vector requirement
const auto a_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
const auto b_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
arg.block_2_ctile_map_) and
IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw);
}
// polymorphic
......@@ -764,6 +837,268 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return str.str();
}
template <class ADesc, class BDesc, class B1Desc, class CDesc>
struct Descriptor
{
template <class AGridDescriptor>
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
{
const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <class BGridDescriptor>
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
{
const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <class B1GridDescriptor>
static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
{
const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <class CGridDescriptor>
static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
{
return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
}
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
using B1GridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
AK1,
BK1,
B1K1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
true,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
true,
BBlockLdsExtraN,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
false,
B1BlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
matrix_padder.PadN,
MaskOutUpperTriangle>;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
CGridDesc_M_N c_grid_desc_m_n;
C0MatrixMask c0_matrix_mask;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_descriptor_mblock_mperblock_nblock_nperblock;
// element-wise op
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
B1ElementwiseOperation b1_element_op;
CElementwiseOperation c_element_op;
bool has_main_k_block_loop = true;
bool is_valid = false;
constexpr Descriptor(ADesc a,
BDesc b,
B1Desc b1,
CDesc c,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
B1ElementwiseOperation b1_element_op_,
CElementwiseOperation c_element_op_)
: a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
c_grid_descriptor_mblock_mperblock_nblock_nperblock{
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n)},
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
c0_matrix_mask{c.GetLength(I1)},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
b1_element_op{b1_element_op_},
c_element_op{c_element_op_},
is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_m_n,
block_2_ctile_map) and
IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
b_grid_desc_bk0_n_bk1.GetLength(I1),
a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_ak0_m_ak1.GetLength(I2),
b1_grid_desc_bk0_n_bk1.GetLength(I1))}
{
}
constexpr bool IsValid() const { return is_valid; }
};
template <class ADesc, class BDesc, class B1Desc, class CDesc>
static constexpr auto
make_descriptor(ADesc a,
BDesc b,
B1Desc b1,
CDesc c,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
CElementwiseOperation c_element_op = CElementwiseOperation{})
{
return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
}
template <class Desc>
__device__ static void Run(const Desc& desc,
const float scale,
const ADataType* __restrict__ p_a_grid,
const ADataType* __restrict__ p_b_grid,
const ADataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid)
{
#ifndef __HIPCC_RTC__
assert(desc.is_valid);
#endif
__shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
AccElementwiseOperation acc_element_op{scale};
if(desc.has_main_k_block_loop)
{
Desc::GridwiseGemm::template Run<true>(
p_a_grid,
p_b_grid,
p_b1_grid,
p_c_grid,
p_shared_block,
desc.a_element_op,
desc.b_element_op,
acc_element_op,
desc.b1_element_op,
desc.c_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.b1_grid_desc_bk0_n_bk1,
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
desc.block_2_ctile_map,
desc.c0_matrix_mask);
}
else
{
Desc::GridwiseGemm::template Run<false>(
p_a_grid,
p_b_grid,
p_b1_grid,
p_c_grid,
p_shared_block,
desc.a_element_op,
desc.b_element_op,
acc_element_op,
desc.b1_element_op,
desc.c_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.b1_grid_desc_bk0_n_bk1,
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
desc.block_2_ctile_map,
desc.c0_matrix_mask);
}
}
};
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment