Commit 45ff21e1 authored by Alan Turner

Add jit lib for batched_gemm_softmax_gemm

parent e8b54cb3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/common.hpp"
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
struct Problem
{
    // Problem sizes for the fused operation: gemm0 computes (M x K) * (K x N),
    // softmax is applied along N, and gemm1 computes (M x N) * (N x O).
    std::size_t M = 0;
    std::size_t N = 0;
    std::size_t K = 0;
    std::size_t O = 0;
    bool TransA  = false;
    bool TransB  = false;
    bool TransB1 = false;
    bool TransC  = false;
    DataType ADataType   = DataType::Half;
    DataType BDataType   = DataType::Half;
    DataType B1DataType  = DataType::Half;
    DataType CDataType   = DataType::Half;
    DataType AccDataType = DataType::Float;
    std::string AElementOp  = "ck::tensor_operation::element_wise::PassThrough";
    std::string BElementOp  = "ck::tensor_operation::element_wise::PassThrough";
    std::string B1ElementOp = "ck::tensor_operation::element_wise::PassThrough";
    std::string CElementOp  = "ck::tensor_operation::element_wise::PassThrough";
    float scale = 1.0f;

    // Positions of the template parameters inside an instance string;
    // MakeSolution uses these to substitute problem-specific values.
    static const std::size_t DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle_idx = 0;
    static const std::size_t ALayout_idx = 1;
    static const std::size_t B0Layout_idx = 2;
    static const std::size_t B1Layout_idx = 3;
    static const std::size_t CLayout_idx = 4;
    static const std::size_t ADataType_idx = 5;
    static const std::size_t B0DataType_idx = 6;
    static const std::size_t B1DataType_idx = 7;
    static const std::size_t CDataType_idx = 8;
    static const std::size_t AccDataType_idx = 9;
    static const std::size_t CShuffleDataType_idx = 10;
    static const std::size_t AElementwiseOperation_idx = 11;
    static const std::size_t B0ElementwiseOperation_idx = 12;
    static const std::size_t Acc0ElementwiseOperation_idx = 13;
    static const std::size_t B1ElementwiseOperation_idx = 14;
    static const std::size_t CElementwiseOperation_idx = 15;
    static const std::size_t GEMMSpecialization_idx = 16;
    static const std::size_t NumGemmKPrefetchStage_idx = 17;
    static const std::size_t BlockSize_idx = 18;
    static const std::size_t Gemm01MPerBlock_idx = 19;
    static const std::size_t Gemm0NPerBlock_idx = 20;
    static const std::size_t Gemm0KPerBlock_idx = 21;
    static const std::size_t Gemm1NPerBlock_idx = 22;
    static const std::size_t Gemm1KPerBlock_idx = 23;
    static const std::size_t AK1_idx = 24;
    static const std::size_t BK1_idx = 25;
    static const std::size_t B1K1_idx = 26;
    static const std::size_t MPerXDL_idx = 27;
    static const std::size_t NPerXDL_idx = 28;
    static const std::size_t Gemm0MXdlPerWave_idx = 29;
    static const std::size_t Gemm0NXdlPerWave_idx = 30;
    static const std::size_t Gemm1NXdlPerWave_idx = 31;
    static const std::size_t ABlockTransferThreadClusterLengths_K0_M_K1_idx = 32;
    static const std::size_t ABlockTransferThreadClusterArrangeOrder_idx = 33;
    static const std::size_t ABlockTransferSrcAccessOrder_idx = 34;
    static const std::size_t ABlockTransferSrcVectorDim_idx = 35;
    static const std::size_t ABlockTransferSrcScalarPerVector_idx = 36;
    static const std::size_t ABlockTransferDstScalarPerVector_K1_idx = 37;
    static const std::size_t ABlockLdsAddExtraM_idx = 38;
    static const std::size_t B0BlockTransferThreadClusterLengths_K0_N_K1_idx = 39;
    static const std::size_t B0BlockTransferThreadClusterArrangeOrder_idx = 40;
    static const std::size_t B0BlockTransferSrcAccessOrder_idx = 41;
    static const std::size_t B0BlockTransferSrcVectorDim_idx = 42;
    static const std::size_t B0BlockTransferSrcScalarPerVector_idx = 43;
    static const std::size_t B0BlockTransferDstScalarPerVector_K1_idx = 44;
    static const std::size_t B0BlockLdsAddExtraN_idx = 45;
    static const std::size_t B1BlockTransferThreadClusterLengths_K0_N_K1_idx = 46;
    static const std::size_t B1BlockTransferThreadClusterArrangeOrder_idx = 47;
    static const std::size_t B1BlockTransferSrcAccessOrder_idx = 48;
    static const std::size_t B1BlockTransferSrcVectorDim_idx = 49;
    static const std::size_t B1BlockTransferSrcScalarPerVector_idx = 50;
    static const std::size_t B1BlockTransferDstScalarPerVector_K1_idx = 51;
    static const std::size_t B1BlockLdsAddExtraN_idx = 52;
    static const std::size_t CShuffleMXdlPerWavePerShuffle_idx = 53;
    static const std::size_t CShuffleNXdlPerWavePerShuffle_idx = 54;
    static const std::size_t CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx = 55;
    static const std::size_t CBlockTransferScalarPerVector_NWaveNPerXdl_idx = 56;
    static const std::size_t MaskOutUpperTriangle_idx = 57;

    std::string GetIncludeHeader() const;
    std::vector<Solution> GetSolutions(const std::string& arch) const;

private:
    std::vector<std::string> GetInstances(const std::string& arch) const;
    Solution MakeSolution(std::size_t idx, const std::string& arch) const;
};
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
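
For context, a minimal usage sketch for the header above. It relies only on the public API declared here; the problem sizes and architecture string are illustrative assumptions, not part of this commit:

// Hypothetical usage (not part of this commit): enumerate JIT instances for a
// batched gemm + softmax + gemm problem on gfx90a.
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
#include <iostream>

int main()
{
    ck::host::device_batched_gemm_softmax_gemm::Problem prob;
    prob.M     = 1024;   // rows of gemm0 and of the final output
    prob.N     = 1024;   // cols of gemm0, i.e. the softmax length
    prob.K     = 64;     // reduction dim of gemm0
    prob.O     = 64;     // cols of gemm1, i.e. of the final output
    prob.scale = 0.125f; // 1/sqrt(K), the usual attention scaling

    // One Solution per instance supported on the target architecture;
    // unsupported architectures yield an empty vector.
    auto solutions = prob.GetSolutions("gfx90a");
    std::cout << solutions.size() << " candidate instances\n";
    return 0;
}
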
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/common.hpp"
namespace ck {
namespace host {
namespace device_gemm_softmax_gemm {
struct Problem
{
    std::size_t M = 0;
    std::size_t N = 0;
    std::size_t K = 0;
    std::size_t O = 0;
    bool TransA  = false;
    bool TransB  = false;
    bool TransB1 = false;
    bool TransC  = false;
    DataType ADataType   = DataType::Half;
    DataType BDataType   = DataType::Half;
    DataType B1DataType  = DataType::Half;
    DataType CDataType   = DataType::Half;
    DataType AccDataType = DataType::Float;
    std::string AElementOp   = "ck::tensor_operation::element_wise::PassThrough";
    std::string BElementOp   = "ck::tensor_operation::element_wise::PassThrough";
    std::string AccElementOp = "ck::tensor_operation::element_wise::Scale";
    std::string B1ElementOp  = "ck::tensor_operation::element_wise::PassThrough";

    // Positions of selected template parameters inside an instance string.
    static const std::size_t ds_layout_idx         = 3;
    static const std::size_t ds_data_type_idx      = 9;
    static const std::size_t e_data_type_idx       = 10;
    static const std::size_t a_elementwise_op_idx  = 11;
    static const std::size_t b_elementwise_op_idx  = 12;
    static const std::size_t ds_elementwise_op_idx = 13;
    static const std::size_t gemm_spec_idx         = 14;
    static const std::size_t block_size_idx        = 16;
    static const std::size_t m_per_block_idx       = 17;
    static const std::size_t n_per_block_idx       = 18;
    static const std::size_t k_per_block_idx       = 19;

    std::string GetIncludeHeader() const;
    std::vector<Solution> GetSolutions(const std::string& arch) const;

private:
    std::vector<std::string> GetInstances(const std::string& arch) const;
    Solution MakeSolution(std::size_t idx, const std::string& arch) const;
};
} // namespace device_gemm_softmax_gemm
} // namespace host
} // namespace ck
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
#include "ck/host/common.hpp"
#include "gemm_add_add_fastgelu_instances.hpp"
#include <algorithm>
#include <unordered_set>
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// Build the GemmSpecialization template argument: every dimension that is not
// an exact multiple of its per-block tile size needs padding, and contributes
// one letter (M, N, K, or O) to the specialization name.
std::string GetGemmSpec(const std::size_t m,
                        const std::size_t n,
                        const std::size_t k,
                        const std::size_t n1,
                        const std::size_t m_per_block,
                        const std::size_t n_per_block,
                        const std::size_t k_per_block,
                        const std::size_t n1_per_block)
{
    std::string spec = "";
    if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
        spec += "M";
    if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
        spec += "N";
    if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
        spec += "K";
    if(integer_divide_ceil(n1, n1_per_block) * n1_per_block - n1 != 0)
        spec += "O";
    if(spec == "")
        return "ck::tensor_operation::device::GemmSpecialization::Default";
    return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
}
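// Worked example with hypothetical sizes: m = 1000 and m_per_block = 128 gives
// integer_divide_ceil(1000, 128) * 128 = 8 * 128 = 1024 != 1000, so "M" is
// appended; if n, k, and n1 all tile exactly, the result is
// "ck::tensor_operation::device::GemmSpecialization::MPadding".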
// Number of workgroups needed to cover an (m x n) output tiled by
// (m_per_block x n_per_block) blocks.
std::size_t GetGridSize(const std::size_t m,
                        const std::size_t n,
                        const std::size_t m_per_block,
                        const std::size_t n_per_block)
{
    return integer_divide_ceil(m, m_per_block) * integer_divide_ceil(n, n_per_block);
}
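// Example with hypothetical values: m = 1024, n = 64 tiled by 128 x 64 blocks
// gives integer_divide_ceil(1024, 128) * integer_divide_ceil(64, 64) = 8 * 1 = 8.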
// Architectures with XDLOPs support; instances are only emitted for these.
const std::unordered_set<std::string>& get_xdlop_archs()
{
    static std::unordered_set<std::string> supported_archs{"gfx90a", "gfx908", "gfx940"};
    return supported_archs;
}

std::vector<std::string> Problem::GetInstances(const std::string& arch) const
{
    std::vector<std::string> instances;
    if(get_xdlop_archs().find(arch) != get_xdlop_archs().end())
    {
        ck::host::instance::batched_gemm_softmax_gemm_instances all_instances{};
        instances = all_instances.get_instances();
    }
    return instances;
}

// Render the Scale elementwise op with the problem's runtime scale baked in.
std::string GetElementwiseScaleString(const float s)
{
    return "ck::tensor_operation::element_wise::Scale{" + std::to_string(s) + "}";
}
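// e.g. scale = 0.125f yields "ck::tensor_operation::element_wise::Scale{0.125000}".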
Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
{
    // Tokenize the instance template string on whitespace so individual
    // template parameters can be overwritten by position.
    auto template_str = GetInstances(arch).at(idx);
    std::istringstream iss(template_str);
    std::vector<std::string> params(std::istream_iterator<std::string>{iss},
                                    std::istream_iterator<std::string>());
    params[AElementwiseOperation_idx]    = AElementOp;
    params[B0ElementwiseOperation_idx]   = BElementOp;
    params[B1ElementwiseOperation_idx]   = B1ElementOp;
    params[CElementwiseOperation_idx]    = CElementOp;
    params[Acc0ElementwiseOperation_idx] = GetElementwiseScaleString(scale);

    const std::size_t block_size   = std::stoi(params[BlockSize_idx]);
    const std::size_t m_per_block  = std::stoi(params[Gemm01MPerBlock_idx]);
    const std::size_t n_per_block  = std::stoi(params[Gemm0NPerBlock_idx]);
    const std::size_t k_per_block  = std::stoi(params[Gemm0KPerBlock_idx]);
    const std::size_t n1_per_block = std::stoi(params[Gemm1NPerBlock_idx]);

    // The grid tiles the final (M x O) output of the second gemm.
    const std::size_t grid_size = GetGridSize(M, O, m_per_block, n1_per_block);
    params[GEMMSpecialization_idx] =
        GetGemmSpec(M, N, K, O, m_per_block, n_per_block, k_per_block, n1_per_block);

    // Rejoin the tokens into "DeviceOp< arg0, arg1, ... >".
    std::string str = std::accumulate(
        params.begin() + 1,
        params.end(),
        std::string{},
        [](const std::string& a, const std::string& b) { return a.empty() ? b : a + ", " + b; });
    str = params.front() + "< " + str + ">";
    return Solution{str, block_size, grid_size};
}
std::string Problem::GetIncludeHeader() const
{
    return ck::host::instance::batched_gemm_softmax_gemm_instances{}.get_include_header();
}

std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
{
    std::vector<Solution> solutions;
    const std::size_t num_instances = GetInstances(arch).size();
    for(std::size_t i = 0; i < num_instances; ++i)
    {
        solutions.push_back(MakeSolution(i, arch));
    }
    return solutions;
}
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
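
For orientation, a hedged sketch of how a JIT driver might consume the API defined above. Everything beyond Problem, GetSolutions, and GetIncludeHeader (the qualified ck::host::Solution name, the driver structure, and the compile-and-profile step described in comments) is an assumption about the surrounding pipeline, not something this commit defines:

// Hypothetical driver fragment (not part of this commit).
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
#include <string>
#include <vector>

// Collect one candidate per instance for the target architecture. Each Solution
// carries the full template instantiation string plus block and grid sizes (it
// is constructed from {str, block_size, grid_size} in MakeSolution above); the
// exact accessors live in ck/host/common.hpp and are not shown here.
std::vector<ck::host::Solution>
enumerate_candidates(const ck::host::device_batched_gemm_softmax_gemm::Problem& prob,
                     const std::string& arch)
{
    // Unsupported (non-XDLOPs) architectures produce an empty list.
    auto candidates = prob.GetSolutions(arch);
    // A real driver would wrap each candidate's template string in a kernel
    // source that includes prob.GetIncludeHeader(), JIT-compile it, and
    // profile to pick the fastest instance.
    return candidates;
}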