Commit 986182fc authored by Umang Yadav

Merge branch 'migraphx' into migx-jit-lib-hiprtc

parents 3ca84d92 11cab2d5
......@@ -63,3 +63,6 @@ _templates/
_toc.yml
docBin/
_doxygen/
# pycache
__pycache__/
......@@ -65,8 +65,8 @@ else()
-Wuninitialized
-Wunreachable-code
-Wunused
-Werror
-Wno-reserved-identifier
-Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt
......
......@@ -611,6 +611,95 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return true;
}
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
{
// check vector load/store
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
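// Vectorized source loads read along the fastest-varying (contiguous)
// dimension of each layout, so that raw extent must be divisible by the
// ScalarPerVector width. Illustrative example: row-major A with KRaw_ = 64 and
// ABlockTransferSrcScalarPerVector = 8 passes (64 % 8 == 0), while KRaw_ = 60
// would be rejected below.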
// check vector load of A
if constexpr(is_same_v<ALayout, Row>)
{
if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<ALayout, Col>)
{
if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of B
if constexpr(is_same_v<BLayout, Row>)
{
if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<BLayout, Col>)
{
if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of B1
if constexpr(is_same_v<B1Layout, Row>)
{
if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<B1Layout, Col>)
{
if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of C
if constexpr(is_same_v<CLayout, Row>)
{
if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else if constexpr(is_same_v<CLayout, Col>)
{
if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
return false;
}
return true;
}
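// Note: IsSupported() only covers the raw-extent divisibility of the vector
// loads/stores; tile- and grid-level constraints are still verified by
// GridwiseGemm::CheckValidity() inside IsSupportedArgument() below.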
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
......@@ -625,29 +714,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
// Check scalar per vector requirement
const auto a_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
const auto b_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
arg.block_2_ctile_map_) and
IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw);
}
// polymorphic
......@@ -766,6 +838,266 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return str.str();
}
template <class ADesc, class BDesc, class B1Desc, class CDesc>
struct Descriptor
{
template<class AGridDescriptor>
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
{
const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
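// Illustrative example: with K = 64 and AK1 = 8 this yields AK0 = 8, and the
// (M, K) descriptor becomes an (AK0, M, AK1) = (8, M, 8) descriptor; K is
// unmerged into (AK0, AK1) while M passes through unchanged.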
template<class BGridDescriptor>
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
{
const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template<class B1GridDescriptor>
static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
{
const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template<class CGridDescriptor>
static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
{
return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
}
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
using B1GridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
using CGridDesc_M_N =
remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
AK1,
BK1,
B1K1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
true,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
true,
BBlockLdsExtraN,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
false,
B1BlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
matrix_padder.PadN,
MaskOutUpperTriangle>;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
CGridDesc_M_N c_grid_desc_m_n;
C0MatrixMask c0_matrix_mask;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock;
// element-wise op
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
B1ElementwiseOperation b1_element_op;
CElementwiseOperation c_element_op;
bool has_main_k_block_loop = true;
bool is_valid = false;
constexpr Descriptor(ADesc a,
BDesc b,
B1Desc b1,
CDesc c,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
B1ElementwiseOperation b1_element_op_,
CElementwiseOperation c_element_op_)
: a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(
c_grid_desc_m_n)},
c_grid_descriptor_mblock_mperblock_nblock_nperblock{
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n)},
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
c0_matrix_mask{c.GetLength(I1)},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
b1_element_op{b1_element_op_},
c_element_op{c_element_op_},
is_valid{GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_m_n,
block_2_ctile_map) and
IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
b_grid_desc_bk0_n_bk1.GetLength(I1),
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2),
b1_grid_desc_bk0_n_bk1.GetLength(I1))}
{
}
constexpr bool IsValid() const
{
return is_valid;
}
};
template <class ADesc, class BDesc, class B1Desc, class CDesc>
static constexpr auto
make_descriptor(ADesc a,
BDesc b,
B1Desc b1,
CDesc c,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
CElementwiseOperation c_element_op = CElementwiseOperation{})
{
return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
}
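// Minimal usage sketch (hypothetical host/JIT wrapper; the kernel wrapper and
// variable names are illustrative, not part of this header):
//   auto desc = DeviceOp::make_descriptor(a_desc, b_desc, b1_desc, c_desc);
//   if(desc.IsValid())
//   {
//       // launch a __global__ wrapper whose body calls
//       // DeviceOp::Run(desc, scale, p_a, p_b, p_b1, p_c);
//   }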
template <class Desc>
__device__ static void Run(const Desc& desc,
const float scale,
const ADataType* __restrict__ p_a_grid,
const ADataType* __restrict__ p_b_grid,
const ADataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid)
{
assert(desc.is_valid);
__shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
AccElementwiseOperation acc_element_op{scale};
if(desc.has_main_k_block_loop)
{
Desc::GridwiseGemm::template Run<true>(p_a_grid,
p_b_grid,
p_b1_grid,
p_c_grid,
p_shared_block,
desc.a_element_op,
desc.b_element_op,
acc_element_op,
desc.b1_element_op,
desc.c_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.b1_grid_desc_bk0_n_bk1,
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
desc.block_2_ctile_map,
desc.c0_matrix_mask);
}
else
{
Desc::GridwiseGemm::template Run<false>(p_a_grid,
p_b_grid,
p_b1_grid,
p_c_grid,
p_shared_block,
desc.a_element_op,
desc.b_element_op,
acc_element_op,
desc.b1_element_op,
desc.c_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.b1_grid_desc_bk0_n_bk1,
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
desc.block_2_ctile_map,
desc.c0_matrix_mask);
}
}
};
} // namespace device
......
......@@ -581,7 +581,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{
return false;
}
......
......@@ -53,7 +53,7 @@ struct MaskOutUpperTrianglePredicate
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
constexpr C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
......
......@@ -13,6 +13,7 @@ execute_process(
)
add_library(jit_library STATIC
src/device_batched_gemm_softmax_gemm.cpp
src/device_gemm_multiple_d.cpp
src/common.cpp
)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/common.hpp"
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
std::size_t O = 0;
bool TransA = false;
bool TransB = false;
bool TransB1 = false;
bool TransC = false;
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType B1DataType = DataType::Half;
DataType CDataType = DataType::Half;
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string B1ElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string CElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string AccElementOp = "ck::tensor_operation::element_wise::Scale";
std::string GetIncludeHeader() const;
std::vector<Solution> GetSolutions(const std::string& arch) const;
private:
std::vector<std::string> GetInstances(const std::string& arch) const;
Solution MakeSolution(std::size_t idx, const std::string& arch) const;
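// Position of each template argument in the whitespace-split instance strings
// produced by the instance parser; MakeSolution() uses these indices to
// substitute the element-wise operations and the GemmSpecialization.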
static const std::size_t DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle_idx = 0;
static const std::size_t ALayout_idx = 1;
static const std::size_t B0Layout_idx = 2;
static const std::size_t B1Layout_idx = 3;
static const std::size_t CLayout_idx = 4;
static const std::size_t ADataType_idx = 5;
static const std::size_t B0DataType_idx = 6;
static const std::size_t B1DataType_idx = 7;
static const std::size_t CDataType_idx = 8;
static const std::size_t AccDataType_idx = 9;
static const std::size_t CShuffleDataType_idx = 10;
static const std::size_t AElementwiseOperation_idx = 11;
static const std::size_t B0ElementwiseOperation_idx = 12;
static const std::size_t Acc0ElementwiseOperation_idx = 13;
static const std::size_t B1ElementwiseOperation_idx = 14;
static const std::size_t CElementwiseOperation_idx = 15;
static const std::size_t GEMMSpecialization_idx = 16;
static const std::size_t NumGemmKPrefetchStage_idx = 17;
static const std::size_t BlockSize_idx = 18;
static const std::size_t Gemm01MPerBlock_idx = 19;
static const std::size_t Gemm0NPerBlock_idx = 20;
static const std::size_t Gemm0KPerBlock_idx = 21;
static const std::size_t Gemm1NPerBlock_idx = 22;
static const std::size_t Gemm1KPerBlock_idx = 23;
static const std::size_t AK1_idx = 24;
static const std::size_t BK1_idx = 25;
static const std::size_t B1K1_idx = 26;
static const std::size_t MPerXDL_idx = 27;
static const std::size_t NPerXDL_idx = 28;
static const std::size_t Gemm0MXdlPerWave_idx = 29;
static const std::size_t Gemm0NXdlPerWave_idx = 30;
static const std::size_t Gemm1NXdlPerWave_idx = 31;
static const std::size_t ABlockTransferThreadClusterLengths_K0_M_K1_idx = 32;
static const std::size_t ABlockTransferThreadClusterArrangeOrder_idx = 33;
static const std::size_t ABlockTransferSrcAccessOrder_idx = 34;
static const std::size_t ABlockTransferSrcVectorDim_idx = 35;
static const std::size_t ABlockTransferSrcScalarPerVector_idx = 36;
static const std::size_t ABlockTransferDstScalarPerVector_K1_idx = 37;
static const std::size_t ABlockLdsAddExtraM_idx = 38;
static const std::size_t B0BlockTransferThreadClusterLengths_K0_N_K1_idx = 39;
static const std::size_t B0BlockTransferThreadClusterArrangeOrder_idx = 40;
static const std::size_t B0BlockTransferSrcAccessOrder_idx = 41;
static const std::size_t B0BlockTransferSrcVectorDim_idx = 42;
static const std::size_t B0BlockTransferSrcScalarPerVector_idx = 43;
static const std::size_t B0BlockTransferDstScalarPerVector_K1_idx = 44;
static const std::size_t B0BlockLdsAddExtraN_idx = 45;
static const std::size_t B1BlockTransferThreadClusterLengths_K0_N_K1_idx = 46;
static const std::size_t B1BlockTransferThreadClusterArrangeOrder_idx = 47;
static const std::size_t B1BlockTransferSrcAccessOrder_idx = 48;
static const std::size_t B1BlockTransferSrcVectorDim_idx = 49;
static const std::size_t B1BlockTransferSrcScalarPerVector_idx = 50;
static const std::size_t B1BlockTransferDstScalarPerVector_K1_idx = 51;
static const std::size_t B1BlockLdsAddExtraN_idx = 52;
static const std::size_t CShuffleMXdlPerWavePerShuffle_idx = 53;
static const std::size_t CShuffleNXdlPerWavePerShuffle_idx = 54;
static const std::size_t CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx = 55;
static const std::size_t CBlockTransferScalarPerVector_NWaveNPerXdl_idx = 56;
static const std::size_t MaskOutUpperTriangle_idx = 57;
};
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
#include "ck/host/common.hpp"
#include "batched_gemm_softmax_gemm_instances.hpp"
#include <algorithm>
#include <unordered_set>
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
std::string GetGemmSpec(const std::size_t m,
const std::size_t n,
const std::size_t k,
const std::size_t n1,
const std::size_t m_per_block,
const std::size_t n_per_block,
const std::size_t k_per_block,
const std::size_t n1_per_block)
{
std::string spec = "";
if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
spec += "M";
if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
spec += "N";
if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
spec += "K";
if(integer_divide_ceil(n1, n1_per_block) * n1_per_block - n1 != 0)
spec += "O";
if(spec == "")
return "ck::tensor_operation::device::GemmSpecialization::Default";
return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
}
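// Illustrative example: M = 1000, N = 256, K = 64, O = 256 with per-block
// tiles 128/256/32/128 -> only M needs padding (ceil(1000/128) * 128 = 1024),
// giving GemmSpecialization::MPadding.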
std::size_t GetGridSize(const std::size_t m,
const std::size_t n,
const std::size_t m_per_block,
const std::size_t n_per_block)
{
return integer_divide_ceil(m, m_per_block) * integer_divide_ceil(n, n_per_block);
}
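// Illustrative example: M = 1000, O = 256 with m_per_block = 128 and
// n1_per_block = 128 -> ceil(1000/128) * ceil(256/128) = 8 * 2 = 16 workgroups.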
const std::unordered_set<std::string>& get_xdlop_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx90a", "gfx908", "gfx940"};
return supported_archs;
}
std::vector<std::string> Problem::GetInstances(const std::string& arch) const
{
std::vector<std::string> instances;
if(get_xdlop_archs().find(arch) != get_xdlop_archs().end())
{
ck::host::instance::batched_gemm_softmax_gemm_instances all_instances{};
instances = all_instances.get_instances();
}
return instances;
}
Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
{
auto template_str = GetInstances(arch).at(idx);
std::istringstream iss(template_str);
std::vector<std::string> params(std::istream_iterator<std::string>{iss},
std::istream_iterator<std::string>());
params[AElementwiseOperation_idx] = AElementOp;
params[B0ElementwiseOperation_idx] = BElementOp;
params[B1ElementwiseOperation_idx] = B1ElementOp;
params[CElementwiseOperation_idx] = CElementOp;
params[Acc0ElementwiseOperation_idx] = AccElementOp;
auto block_size_str = params[BlockSize_idx];
auto m_per_block_str = params[Gemm01MPerBlock_idx];
auto n_per_block_str = params[Gemm0NPerBlock_idx];
auto k_per_block_str = params[Gemm0KPerBlock_idx];
auto n1_per_block_str = params[Gemm1NPerBlock_idx];
const std::size_t block_size = std::stoi(block_size_str);
const std::size_t m_per_block = std::stoi(m_per_block_str);
const std::size_t n_per_block = std::stoi(n_per_block_str);
const std::size_t k_per_block = std::stoi(k_per_block_str);
const std::size_t n1_per_block = std::stoi(n1_per_block_str);
const std::size_t grid_size = GetGridSize(M, O, m_per_block, n1_per_block);
params[GEMMSpecialization_idx] = GetGemmSpec(M, N, K, O, m_per_block, n_per_block, k_per_block, n1_per_block);
std::string str = std::accumulate(
params.begin() + 1,
params.end(),
std::string{},
[](const std::string& a, const std::string& b) { return a.empty() ? b : a + ", " + b; });
str = params.front() + "< " + str + ">";
return Solution{str, block_size, grid_size};
}
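// The returned Solution re-assembles the space-separated tokens into a fully
// qualified template instantiation, roughly
// "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< param0, param1, ... >",
// with the element-wise operations and GemmSpecialization substituted above.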
std::string Problem::GetIncludeHeader() const
{
return ck::host::instance::batched_gemm_softmax_gemm_instances{}.get_include_header();
}
std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
{
std::vector<Solution> solutions;
const std::size_t num_instances = GetInstances(arch).size();
for(std::size_t i = 0; i < num_instances; ++i)
{
solutions.push_back(MakeSolution(i, arch));
}
return solutions;
}
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
out_file_with_quant = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
namespace ck {{
namespace host {{
namespace instance {{
struct {op_name}_instances
{{
static inline std::vector<std::string> {col_row_name} =
{{
{col_row_instances}
}};
static inline std::vector<std::string> {col_col_name} =
{{
{col_col_instances}
}};
static inline std::vector<std::string> {row_row_name} =
{{
{row_row_instances}
}};
static inline std::vector<std::string> {row_col_name} =
{{
{row_col_instances}
}};
static inline std::vector<std::string> {int8_col_row_name} =
{{
{int8_col_row_instances}
}};
static inline std::vector<std::string> {int8_col_col_name} =
{{
{int8_col_col_instances}
}};
static inline std::vector<std::string> {int8_row_row_name} =
{{
{int8_row_row_instances}
}};
static inline std::vector<std::string> {int8_row_col_name} =
{{
{int8_row_col_instances}
}};
static auto get_col_row_instances(const bool quantize)
{{
return quantize ? {int8_col_row_name} :
{col_row_name};
}}
static auto get_col_col_instances(const bool quantize)
{{
return quantize ? {int8_col_col_name} :
{col_col_name};
}}
static auto get_row_row_instances(const bool quantize)
{{
return quantize ? {int8_row_row_name} :
{row_row_name};
}}
static auto get_row_col_instances(const bool quantize)
{{
return quantize ? {int8_row_col_name} :
{row_col_name};
}}
static auto get_include_header()
{{
return "{include_header}";
}}
}};
}} // namespace instance
}} // namespace host
}} // namespace ck
"""
out_file_no_quant = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
namespace ck {{
namespace host {{
namespace instance {{
struct {op_name}_instances
{{
static inline std::vector<std::string> {instances_name} =
{{
{instances}
}};
static auto get_instances()
{{
return {instances_name};
}}
static auto get_include_header()
{{
return "{include_header}";
}}
}};
}} // namespace instance
}} // namespace host
}} // namespace ck
"""
def get_device_gemm_multiple_d_file(op_name,
col_row_name,
col_row_instances,
col_col_name,
col_col_instances,
row_row_name,
row_row_instances,
row_col_name,
row_col_instances,
int8_col_row_name,
int8_col_row_instances,
int8_col_col_name,
int8_col_col_instances,
int8_row_row_name,
int8_row_row_instances,
int8_row_col_name,
int8_row_col_instances,
include_header):
return out_file_with_quant.format(
op_name=op_name,
col_row_name=col_row_name,
col_row_instances=col_row_instances,
col_col_name=col_col_name,
col_col_instances=col_col_instances,
row_row_name=row_row_name,
row_row_instances=row_row_instances,
row_col_name=row_col_name,
row_col_instances=row_col_instances,
int8_col_row_name=int8_col_row_name,
int8_col_row_instances=int8_col_row_instances,
int8_col_col_name=int8_col_col_name,
int8_col_col_instances=int8_col_col_instances,
int8_row_row_name=int8_row_row_name,
int8_row_row_instances=int8_row_row_instances,
int8_row_col_name=int8_row_col_name,
int8_row_col_instances=int8_row_col_instances,
include_header=include_header)
def get_device_gemm_softmax_gemm_file(op_name,
instances_name,
instances,
include_header):
return out_file_no_quant.format(
op_name=op_name,
instances_name=instances_name,
instances=instances,
include_header=include_header)
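# Hedged usage sketch (illustrative argument values, not taken from the repo):
#   header_text = get_device_gemm_softmax_gemm_file(
#       "batched_gemm_softmax_gemm",
#       "device_batched_gemm_softmax_gemm_instances",
#       '"ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< ... >"',
#       "path/to/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp")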
import argparse, re, json, os, sys
out_file = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
namespace ck {{
namespace host {{
namespace instance {{
struct {op_name}_instances
{{
static inline std::vector<std::string> {col_row_name} =
{{
{col_row_instances}
}};
static inline std::vector<std::string> {col_col_name} =
{{
{col_col_instances}
}};
static inline std::vector<std::string> {row_row_name} =
{{
{row_row_instances}
}};
static inline std::vector<std::string> {row_col_name} =
{{
{row_col_instances}
}};
static inline std::vector<std::string> {int8_col_row_name} =
{{
{int8_col_row_instances}
}};
static inline std::vector<std::string> {int8_col_col_name} =
{{
{int8_col_col_instances}
}};
static inline std::vector<std::string> {int8_row_row_name} =
{{
{int8_row_row_instances}
}};
static inline std::vector<std::string> {int8_row_col_name} =
{{
{int8_row_col_instances}
}};
static auto get_col_row_instances(const bool quantize)
{{
return quantize ? {int8_col_row_name} :
{col_row_name};
}}
static auto get_col_col_instances(const bool quantize)
{{
return quantize ? {int8_col_col_name} :
{col_col_name};
}}
static auto get_row_row_instances(const bool quantize)
{{
return quantize ? {int8_row_row_name} :
{row_row_name};
}}
static auto get_row_col_instances(const bool quantize)
{{
return quantize ? {int8_row_col_name} :
{row_col_name};
}}
static auto get_include_header()
{{
return "{include_header}";
}}
}};
}} // namespace instance
}} // namespace host
}} // namespace ck
"""
import argparse, re, json, os, sys, file_templates
def strip_sequences(str):
matches = re.findall(r'S<\d+(?:,\s*\d+)*>', str)
matches = re.findall(r'S<\s*\d+(?:,\s*\d+)*>', str)
for match in matches:
str = str.replace(match, match.replace(' ', ''))
str = str.replace('S<', "ck::Sequence<")
......@@ -251,27 +161,206 @@ def parse_instances(source, out_dir):
int8_file = "/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
int8_instances = get_int8_instances(source, int8_file, "DeviceGemmMultipleD_Xdl_CShuffle")
with open(os.path.join(out_dir, out_file_name), "w+") as f:
f.write(out_file.format(op_name=op_name,
col_row_name=col_row_name,
col_row_instances="\n".join(col_row_instances),
col_col_name=col_col_name,
col_col_instances="\n".join(col_col_instances),
row_row_name=row_row_name,
row_row_instances="\n".join(row_row_instances),
row_col_name=row_col_name,
row_col_instances="\n".join(row_col_instances),
int8_col_row_name=int8_instances["col_row_name"],
int8_col_row_instances="\n".join(int8_instances["col_row"]),
int8_col_col_name=int8_instances["col_col_name"],
int8_col_col_instances="\n".join(int8_instances["col_col"]),
int8_row_row_name=int8_instances["row_row_name"],
int8_row_row_instances="\n".join(int8_instances["row_row"]),
int8_row_col_name=int8_instances["row_col_name"],
int8_row_col_instances="\n".join(int8_instances["row_col"]),
include_header=include_header))
f.write(file_templates.get_device_gemm_multiple_d_file(
op_name,
col_row_name,
"\n".join(col_row_instances),
col_col_name,
"\n".join(col_col_instances),
row_row_name,
"\n".join(row_row_instances),
row_col_name,
"\n".join(row_col_instances),
int8_instances["col_row_name"],
"\n".join(int8_instances["col_row"]),
int8_instances["col_col_name"],
"\n".join(int8_instances["col_col"]),
int8_instances["row_row_name"],
"\n".join(int8_instances["row_row"]),
int8_instances["row_col_name"],
"\n".join(int8_instances["row_col"]),
include_header))
def parse_device_gemm_multiple_d_instances(source, out_dir):
aliases = {"F16_F16_Tuple": "ck::Tuple<F16,F16>",
"Row_Row_Tuple": "ck::Tuple<Row,Row>",
"Empty_Tuple": "ck::Tuple<>",
"LoopScheduler": "ck::LoopScheduler",
"PipelineVersion": "ck::PipelineVersion",
"Row": "ck::tensor_layout::gemm::RowMajor",
"Col": "ck::tensor_layout::gemm::ColumnMajor",
"F16": "ck::half_t",
"F32": "float",
"OutElementOp": "PassThrough"}
device_ops = {"gemm_add_add_fastgelu": "DeviceGemmMultipleD_Xdl_CShuffle",
#"batched_gemm_softmax_gemm": "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"
}
for root_, dirs_, files_ in os.walk(source):
for dir in dirs_:
op_name = os.path.split(dir)[-1]
if op_name not in device_ops:
continue
col_row_name = ""
col_col_name = ""
row_row_name = ""
row_col_name = ""
row_row_instances = []
col_row_instances = []
row_col_instances = []
col_col_instances = []
for root, dirs, files in os.walk(os.path.join(root_, dir)):
for file in files:
if not file.endswith(".cpp"):
continue
file_name = os.path.split(file)[-1]
is_row_row = bool(re.search(".*mk.*kn.*", file_name))
is_col_row = bool(re.search(".*km.*kn.*", file_name))
is_row_col = bool(re.search(".*mk.*nk.*", file_name))
is_col_col = bool(re.search(".*km.*nk.*", file_name))
if is_row_row:
row_row_name = file_name[:-4]
if is_col_row:
col_row_name = file_name[:-4]
if is_row_col:
row_col_name = file_name[:-4]
if is_col_col:
col_col_name = file_name[:-4]
instances_list = []
template_name = device_ops[op_name]
include_header = ""
with open(os.path.join(root, file)) as f:
for line in f:
if "impl" in line:
include_header = line.replace("#include \"", "").replace("\"", "").replace("\n", "")
elif template_name in line:
# Turn all whitespace into single spaces
new_line = " ".join(line.split())
# Remove whitespace from S<*>
new_line = strip_sequences(new_line)
new_line = remove_commas_and_brackets(new_line)
last_char = "\n"
if new_line[-1] == ",":
last_char = ",\n"
new_line = new_line[:-1]
new_line = ' "ck::tensor_operation::device::' + new_line + '",'
for key in aliases:
new_line = new_line.replace(key, aliases[key])
instances_list.append(new_line)
instances_list[-1] = instances_list[-1][:-1]
if is_row_row:
row_row_instances = instances_list
if is_col_row:
col_row_instances = instances_list
if is_row_col:
row_col_instances = instances_list
if is_col_col:
col_col_instances = instances_list
out_file_name = op_name + "_instances.hpp"
if not os.path.exists(out_dir):
os.mkdir(out_dir)
int8_file = "/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
int8_instances = get_int8_instances(source, int8_file, "DeviceGemmMultipleD_Xdl_CShuffle")
with open(os.path.join(out_dir, out_file_name), "w+") as f:
f.write(file_templates.get_device_gemm_multiple_d_file(
op_name,
col_row_name,
"\n".join(col_row_instances),
col_col_name,
"\n".join(col_col_instances),
row_row_name,
"\n".join(row_row_instances),
row_col_name,
"\n".join(row_col_instances),
int8_instances["col_row_name"],
"\n".join(int8_instances["col_row"]),
int8_instances["col_col_name"],
"\n".join(int8_instances["col_col"]),
int8_instances["row_row_name"],
"\n".join(int8_instances["row_row"]),
int8_instances["row_col_name"],
"\n".join(int8_instances["row_col"]),
include_header))
def parse_param_names(file):
param_names = []
for line in file:
if bool(re.search(r"\s*//#+", line)):
names = line.split('|')
names = [n.strip() for n in names]
if not param_names:
param_names = [""] * len(names)
param_names = [a + b for a, b in zip(param_names, names)]
elif param_names:
param_names[0] = line.split('<')[0].strip()
file.seek(0)
return param_names[:-1]
file.seek(0)
return param_names[:-1]
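# parse_param_names reads the '|'-separated header comment that CK instance
# files place above each instance table, e.g. (layout is illustrative):
#   //#############| ALayout| B0Layout| B1Layout| CLayout| ...
# Columns from successive '//#...' rows are concatenated, and the first
# non-comment template line supplies the operation name for column 0.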
def parse_device_batched_gemm_softmax_gemm_instances(source, out_dir):
aliases = {"Row": "ck::tensor_layout::gemm::RowMajor",
"Col": "ck::tensor_layout::gemm::ColumnMajor",
"F16": "ck::half_t",
"F32": "float",
"PassThrough": "ck::tensor_operation::element_wise::PassThrough",
"Scale": "ck::tensor_operation::element_wise::Scale",
"GemmPadded": "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
"GemmDefault": "ck::tensor_operation::device::GemmSpecialization::Default"}
device_ops = {"batched_gemm_softmax_gemm": "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"
}
for root_, dirs_, files_ in os.walk(source):
for dir in dirs_:
op_name = os.path.split(dir)[-1]
if "permute" in op_name or op_name not in device_ops:
continue
for root, dirs, files in os.walk(os.path.join(root_, dir)):
for file in files:
if not file.endswith(".cpp"):
continue
file_name = os.path.split(file)[-1]
instances_name = file_name[:-4]
instances_list = []
template_name = device_ops[op_name]
include_header = ""
with open(os.path.join(root, file)) as f:
param_names = parse_param_names(f)
# for i in range(len(param_names)):
# print(f"{i}: {param_names[i]}")
for line in f:
if "impl" in line:
include_header = line.replace("#include \"", "").replace("\"", "").replace("\n", "")
elif template_name in line:
# Turn all whitespace into single spaces
new_line = " ".join(line.split())
# Remove whitespace from S<*>
new_line = strip_sequences(new_line)
new_line = remove_commas_and_brackets(new_line)
last_char = "\n"
if new_line[-1] == ",":
last_char = ",\n"
new_line = new_line[:-1]
new_line = ' "ck::tensor_operation::device::' + new_line + '",'
for key in aliases:
new_line = new_line.replace(key, aliases[key])
masking = new_line.replace("Masking", "true")
no_masking = new_line.replace("Masking", "false")
instances_list.append(masking)
instances_list.append(no_masking)
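# Each parsed instance is emitted twice: once with the "Masking" placeholder
# set to true (upper-triangle masking) and once with it set to false.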
out_file_name = op_name + "_instances.hpp"
if not os.path.exists(out_dir):
os.mkdir(out_dir)
with open(os.path.join(out_dir, out_file_name), "w+") as f:
f.write(file_templates.get_device_gemm_softmax_gemm_file(
op_name,
instances_name,
"\n".join(instances_list),
include_header))
def run(args):
parse_instances(args[0], args[1])
parse_device_gemm_multiple_d_instances(args[0], args[1])
parse_device_batched_gemm_softmax_gemm_instances(args[0], args[1])
if __name__ == '__main__':
run(sys.argv[1:])