Commit d9676215 authored by Alan Turner

Add Descriptor and Run to device op

parent 611196d5
@@ -662,7 +662,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         const auto c_extent_lowest =
             is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;

-        if constexpr(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
+        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
              b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
              b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
              c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
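The vector-width check above can no longer be an if constexpr: the *_extent_lowest values are derived from the raw problem sizes at run time, while only the *ScalarPerVector divisors are compile-time template parameters, so the test has to produce an ordinary runtime boolean (it feeds the IsSupported check that appears later in this diff). A minimal standalone sketch of the same divisibility pattern; the names and values below are illustrative, not taken from this commit:

#include <cstdio>

// Compile-time tuning parameter (a template argument in the real device op).
constexpr int kScalarPerVector = 8;

// Runtime extent of the contiguous dimension; vectorized loads need it to be a
// multiple of the vector width, so the check must happen at run time.
bool IsVectorWidthSupported(int extent_lowest)
{
    return extent_lowest % kScalarPerVector == 0;
}

int main()
{
    std::printf("%d %d\n", IsVectorWidthSupported(120), IsVectorWidthSupported(100)); // prints: 1 0
}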
@@ -857,26 +857,83 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
             remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
         using B1GridDesc_BK0_N_BK1 =
             remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
-        using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
-            remove_cvref_t<decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                MakeCGridDescriptor_M_N(CDesc{})))>;
-        using Block2CTileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(
-            MakeCGridDescriptor_M_N(CDesc{})))>;
-        using C0MatrixMask = conditional_t<MaskOutUpperTriangle,
-                                           C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
-                                           C0MatrixMask_impl<MaskDisabledPredicate>>;
+        using CGridDesc_M_N =
+            remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
+
+        // GridwiseGemm
+        using GridwiseGemmSpec = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
+            ADataType, // TODO: distinguish A/B datatype
+            GemmAccDataType,
+            CShuffleDataType,
+            CDataType,
+            AElementwiseOperation,
+            BElementwiseOperation,
+            AccElementwiseOperation,
+            B1ElementwiseOperation,
+            CElementwiseOperation,
+            InMemoryDataOperationEnum::Set,
+            AGridDesc_AK0_M_AK1,
+            BGridDesc_BK0_N_BK1,
+            B1GridDesc_BK0_N_BK1,
+            CGridDesc_M_N,
+            NumGemmKPrefetchStage,
+            BlockSize,
+            MPerBlock,
+            NPerBlock,
+            KPerBlock,
+            Gemm1NPerBlock,
+            Gemm1KPerBlock,
+            AK1,
+            BK1,
+            B1K1,
+            MPerXDL,
+            NPerXDL,
+            MXdlPerWave,
+            NXdlPerWave,
+            Gemm1NXdlPerWave,
+            ABlockTransferThreadClusterLengths_AK0_M_AK1,
+            ABlockTransferThreadClusterArrangeOrder,
+            ABlockTransferSrcAccessOrder,
+            ABlockTransferSrcVectorDim,
+            ABlockTransferSrcScalarPerVector,
+            ABlockTransferDstScalarPerVector_AK1,
+            true,
+            ABlockLdsExtraM,
+            BBlockTransferThreadClusterLengths_BK0_N_BK1,
+            BBlockTransferThreadClusterArrangeOrder,
+            BBlockTransferSrcAccessOrder,
+            BBlockTransferSrcVectorDim,
+            BBlockTransferSrcScalarPerVector,
+            BBlockTransferDstScalarPerVector_BK1,
+            true,
+            BBlockLdsExtraN,
+            B1BlockTransferThreadClusterLengths_BK0_N_BK1,
+            B1BlockTransferThreadClusterArrangeOrder,
+            B1BlockTransferSrcAccessOrder,
+            B1BlockTransferSrcVectorDim,
+            B1BlockTransferSrcScalarPerVector,
+            B1BlockTransferDstScalarPerVector_BK1,
+            false,
+            B1BlockLdsExtraN,
+            CShuffleMXdlPerWavePerShuffle,
+            CShuffleNXdlPerWavePerShuffle,
+            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+            CShuffleBlockTransferScalarPerVector_NPerBlock,
+            LoopSched,
+            matrix_padder.PadN,
+            MaskOutUpperTriangle>;

         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
         B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
-        CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock;
-        Block2CTileMap block_2_ctile_map;
+        CGridDesc_M_N c_grid_desc_m_n;
         C0MatrixMask c0_matrix_mask;
+        typename GridwiseGemmSpec::DefaultBlock2CTileMap block_2_ctile_map;
+        typename GridwiseGemmSpec::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock;

         // element-wise op
         AElementwiseOperation a_element_op;
         BElementwiseOperation b_element_op;
-        AccElementwiseOperation acc_element_op;
         B1ElementwiseOperation b1_element_op;
         CElementwiseOperation c_element_op;
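The Descriptor now owns the gridwise specialization as a nested alias, so the block-to-C-tile map and the shuffled C descriptor are stored through typename GridwiseGemmSpec::... member types instead of being re-derived with decltype from a device-level GridwiseGemm. A small standalone sketch of this nested-alias pattern, with generic names that are not taken from this commit:

// A gridwise policy exposes the helper types its caller needs to store.
template <int MPerBlock, int NPerBlock>
struct GridwisePolicy
{
    struct DefaultBlock2CTileMap
    {
        int m_blocks;
        int n_blocks;
    };

    static constexpr DefaultBlock2CTileMap MakeDefaultBlock2CTileMap(int M, int N)
    {
        return {M / MPerBlock, N / NPerBlock};
    }
};

// The descriptor names its members through the nested aliases instead of
// re-deriving them with decltype at class scope.
template <int MPerBlock, int NPerBlock>
struct Descriptor
{
    using GridwiseGemmSpec = GridwisePolicy<MPerBlock, NPerBlock>;
    typename GridwiseGemmSpec::DefaultBlock2CTileMap block_2_ctile_map;
};

static_assert(Descriptor<128, 128>{GridwisePolicy<128, 128>::MakeDefaultBlock2CTileMap(1024, 2048)}
                  .block_2_ctile_map.n_blocks == 16,
              "block-to-tile map derived through the nested alias");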
@@ -889,31 +946,29 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                    CDesc c,
                    AElementwiseOperation a_element_op_,
                    BElementwiseOperation b_element_op_,
-                   AccElementwiseOperation acc_element_op_,
                    B1ElementwiseOperation b1_element_op_,
                    CElementwiseOperation c_element_op_)
            : a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
              b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
              b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
-             c_grid_descriptor_mblock_mperblock_nblock_nperblock{GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                 MakeCGridDescriptor_M_N(c))},
-             block_2_etile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(
-                 MakeCGridDescriptor_M_N(c))},
-             has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
+             c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
+             block_2_ctile_map{GridwiseGemmSpec::MakeDefaultBlock2CTileMap(
+                 c_grid_desc_m_n)},
+             c_grid_descriptor_mblock_mperblock_nblock_nperblock{
+                 GridwiseGemmSpec::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n)},
+             has_main_k_block_loop{GridwiseGemmSpec::CalculateHasMainKBlockLoop(
                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
-             c0_matrix_mask{c.GetLength(I1)}
+             c0_matrix_mask{c.GetLength(I1)},
              a_element_op{a_element_op_},
              b_element_op{b_element_op_},
-             acc_element_op{acc_element_op_},
              b1_element_op{b1_element_op_},
              c_element_op{c_element_op_},
-             is_valid{GridwiseGemm::CheckValidity(
+             is_valid{GridwiseGemmSpec::CheckValidity(
                  a_grid_desc_ak0_m_ak1,
                  b_grid_desc_bk0_n_bk1,
                  b1_grid_desc_bk0_n_bk1,
-                 MakeCGridDescriptor_M_N(c),
-                 block_2_ctile_map) and
-                 IsSupported(c.GetLength(I0), c.GetLength(I1), a.GetLength(I1), b1.GetLength(I1))}
+                 c_grid_desc_m_n,
+                 block_2_ctile_map)}
        {
        }
@@ -927,37 +982,43 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
    static constexpr auto
    make_descriptor(ADesc a,
                    BDesc b,
-                   B1Desc b1desc,
+                   B1Desc b1,
                    CDesc c,
                    AElementwiseOperation a_element_op = AElementwiseOperation{},
                    BElementwiseOperation b_element_op = BElementwiseOperation{},
-                   AccElementwiseOperation acc_element_op = AccElementwiseOperation{},
                    B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
                    CElementwiseOperation c_element_op = CElementwiseOperation{})
    {
        return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
-           a, b, b1, c, a_element_op, b_element_op, acc_element_op, b1_element_op, c_element_op);
+           a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
    }

    template <class Desc>
    __device__ static void Run(const Desc& desc,
+                              const float scale,
                               const ADataType* __restrict__ p_a_grid,
                               const ADataType* __restrict__ p_b_grid,
                               const ADataType* __restrict__ p_b1_grid,
                               CDataType* __restrict__ p_c_grid)
    {
-       __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-       assert(desc.is_valid);
+       assert(desc.is_valid and
+              IsSupported(desc.a_grid_desc_ak0_m_ak1.GetLength(I1),
+                          desc.b_grid_desc_bk0_n_bk1.GetLength(I1),
+                          desc.a_grid_desc_ak0_m_ak1.GetLength(I0) * desc.a_grid_desc_ak0_m_ak1.GetLength(I2),
+                          desc.b1_grid_desc_bk0_n_bk1.GetLength(I1)));
+
+       __shared__ char p_shared_block[Desc::GridwiseGemmSpec::GetSharedMemoryNumberOfByte()];
+
+       AccElementwiseOperation acc_element_op{scale};

        if(desc.has_main_k_block_loop)
        {
-           GridwiseGemm::template Run<true>(p_a_grid,
+           Desc::GridwiseGemmSpec::template Run<true>(p_a_grid,
                                             p_b_grid,
                                             p_b1_grid,
                                             p_c_grid,
-                                            p_shared,
+                                            p_shared_block,
                                             desc.a_element_op,
                                             desc.b_element_op,
-                                            desc.acc_element_op,
+                                            acc_element_op,
                                             desc.b1_element_op,
                                             desc.c_element_op,
                                             desc.a_grid_desc_ak0_m_ak1,
@@ -969,14 +1030,14 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        }
        else
        {
-           GridwiseGemm::template Run<false>(p_a_grid,
+           Desc::GridwiseGemmSpec::template Run<false>(p_a_grid,
                                              p_b_grid,
                                              p_b1_grid,
                                              p_c_grid,
-                                             p_shared,
+                                             p_shared_block,
                                              desc.a_element_op,
                                              desc.b_element_op,
-                                             desc.acc_element_op,
+                                             acc_element_op,
                                              desc.b1_element_op,
                                              desc.c_element_op,
                                              desc.a_grid_desc_ak0_m_ak1,
......
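Taken together, the new Descriptor and Run give the device op a constexpr-buildable problem description plus a device-side entry point that receives raw pointers and the softmax scale at launch time. A hedged usage sketch follows; the wrapper kernel, the DeviceOp alias, the tensor-descriptor arguments, and the host launch line are assumptions made for illustration and are not part of this commit:

#include <hip/hip_runtime.h>

// Hypothetical alias for a concrete instantiation of
// DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle (template arguments omitted):
// using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<...>;

// Thin wrapper kernel: the descriptor is passed by value, and the numeric scale is a
// plain kernel argument, so it no longer has to live inside the descriptor.
template <class DeviceOp, class Desc, class ADataType, class CDataType>
__global__ void kernel_batched_gemm_softmax_gemm(const Desc desc,
                                                 const float scale,
                                                 const ADataType* __restrict__ p_a,
                                                 const ADataType* __restrict__ p_b,
                                                 const ADataType* __restrict__ p_b1,
                                                 CDataType* __restrict__ p_c)
{
    // Run builds AccElementwiseOperation{scale} internally before calling the gridwise kernel.
    DeviceOp::Run(desc, scale, p_a, p_b, p_b1, p_c);
}

// Host side (illustrative): a_desc/b_desc/b1_desc/c_desc are the compile-time tensor
// descriptors the device op already consumes in make_descriptor.
// constexpr auto desc = DeviceOp::make_descriptor(a_desc, b_desc, b1_desc, c_desc);
// hipLaunchKernelGGL((kernel_batched_gemm_softmax_gemm<DeviceOp, decltype(desc), ADataType, CDataType>),
//                    grid, block, 0, nullptr, desc, scale, p_a, p_b, p_b1, p_c);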
@@ -53,7 +53,7 @@ struct MaskOutUpperTrianglePredicate
 template <typename MaskOutPredicate>
 struct C0MatrixMask_impl
 {
-    C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
+    constexpr C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}

     __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
     {
......
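The constexpr on the mask constructor lets C0MatrixMask_impl be constructed inside the constexpr make_descriptor path above when the descriptor is evaluated at compile time; every member of a class built in a constant expression needs a constexpr-usable constructor. A tiny standalone illustration of that C++ requirement, using illustrative types that are not taken from this commit:

#include <cstdint>

struct Mask
{
    // Without constexpr here, MakeDesc below could not be evaluated at compile time.
    constexpr Mask(std::int32_t n) : n_(n) {}
    std::int32_t n_;
};

struct Desc
{
    Mask mask;
};

constexpr Desc MakeDesc(std::int32_t n) { return Desc{Mask{n}}; }

static_assert(MakeDesc(64).mask.n_ == 64, "descriptor must remain usable in a constant expression");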
@@ -13,6 +13,7 @@ execute_process(
 )

 add_library(jit_library STATIC
+    src/device_batched_gemm_softmax_gemm.cpp
     src/device_gemm_multiple_d.cpp
     src/common.cpp
 )
......
@@ -33,7 +33,16 @@ struct Problem
     std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
     std::string B1ElementOp = "ck::tensor_operation::element_wise::PassThrough";
     std::string CElementOp = "ck::tensor_operation::element_wise::PassThrough";
-    float scale = 1.0;
+    std::string AccElementOp = "ck::tensor_operation::element_wise::Scale";
+
+    std::string GetIncludeHeader() const;
+    std::vector<Solution> GetSolutions(const std::string& arch) const;
+
+    private:
+    std::vector<std::string> GetInstances(const std::string& arch) const;
+    Solution MakeSolution(std::size_t idx, const std::string& arch) const;
+
     static const std::size_t DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle_idx = 0;
     static const std::size_t ALayout_idx = 1;
@@ -93,15 +102,6 @@ struct Problem
     static const std::size_t CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx = 55;
     static const std::size_t CBlockTransferScalarPerVector_NWaveNPerXdl_idx = 56;
     static const std::size_t MaskOutUpperTriangle_idx = 57;
-
-    std::string GetIncludeHeader() const;
-    std::vector<Solution> GetSolutions(const std::string& arch) const;
-
-    private:
-    std::vector<std::string> GetInstances(const std::string& arch) const;
-    Solution MakeSolution(std::size_t idx, const std::string& arch) const;
 };

 } // namespace device_batched_gemm_softmax_gemm
......
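With the accumulator element-wise op now expressed as a type-name string like the other ops, the Problem describes only which Scale operation to instantiate; the numeric factor is supplied later, at kernel launch, through Run. A hedged sketch of driving this Problem interface; the enclosing ck::host namespace qualification, the M/N/K/O problem-size fields, and the gfx90a target string are assumptions made for illustration:

#include <string>
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"

int main()
{
    // Assumed fully-qualified name; the header above only shows the innermost namespace.
    ck::host::device_batched_gemm_softmax_gemm::Problem prob;

    prob.M = 1024; // assumed problem-size fields, mirroring GetGridSize(M, O, ...) in MakeSolution
    prob.N = 1024;
    prob.K = 64;
    prob.O = 64;

    // Keep the default PassThrough ops for A/B/B1/C and select Scale for the accumulator.
    prob.AccElementOp = "ck::tensor_operation::element_wise::Scale";

    const auto solutions = prob.GetSolutions("gfx90a");
    return solutions.empty() ? 1 : 0;
}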
#include "ck/host/device_batched_gemm_softmax_gemm.hpp" #include "ck/host/device_batched_gemm_softmax_gemm.hpp"
#include "ck/host/common.hpp" #include "ck/host/common.hpp"
#include "gemm_add_add_fastgelu_instances.hpp" #include "batched_gemm_softmax_gemm_instances.hpp"
#include <algorithm> #include <algorithm>
#include <unordered_set> #include <unordered_set>
@@ -57,11 +57,6 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
     return instances;
 }

-std::string GetElementwiseScaleString(const float s)
-{
-    return "ck::tensor_operation::element_wise::Scale{" + std::to_string(s) + "}";
-}
-
 Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
 {
     auto template_str = GetInstances(arch).at(idx);
@@ -73,19 +68,17 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
     params[B0ElementwiseOperation_idx] = BElementOp;
     params[B1ElementwiseOperation_idx] = BElementOp;
     params[CElementwiseOperation_idx] = CElementOp;
-    params[Acc0ElementwiseOperation_idx] = GetElementwiseScaleString(scale);
+    params[Acc0ElementwiseOperation_idx] = AccElementOp;

     auto block_size_str = params[BlockSize_idx];
     auto m_per_block_str = params[Gemm01MPerBlock_idx];
     auto n_per_block_str = params[Gemm0NPerBlock_idx];
     auto k_per_block_str = params[Gemm0KPerBlock_idx];
     auto n1_per_block_str = params[Gemm1NPerBlock_idx];
-    auto k1_per_block_str = params[Gemm1KPerBlock_idx];

     const std::size_t block_size = std::stoi(block_size_str);
     const std::size_t m_per_block = std::stoi(m_per_block_str);
     const std::size_t n_per_block = std::stoi(n_per_block_str);
     const std::size_t k_per_block = std::stoi(k_per_block_str);
     const std::size_t n1_per_block = std::stoi(n1_per_block_str);
-    const std::size_t k1_per_block = std::stoi(k1_per_block_str);

     const std::size_t grid_size = GetGridSize(M, O, m_per_block, n1_per_block);
     params[GEMMSpecialization_idx] = GetGemmSpec(M, N, K, O, m_per_block, n_per_block, k_per_block, n1_per_block);
......
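The removed GetElementwiseScaleString helper used to bake the numeric factor into the generated instance string; now the instance only names the Scale type, and the factor arrives through the float scale argument of Run, which constructs AccElementwiseOperation acc_element_op{scale} on the device. A minimal sketch of how such a scale functor applies to a GEMM0 accumulator value; the functor below is a stand-in written for illustration, not the CK Scale implementation itself:

#include <cassert>

// Illustrative stand-in for ck::tensor_operation::element_wise::Scale:
// y = scale_ * x, with the factor fixed at construction time.
struct Scale
{
    explicit constexpr Scale(float scale) : scale_(scale) {}

    template <typename Y, typename X>
    constexpr void operator()(Y& y, const X& x) const
    {
        y = scale_ * x;
    }

    float scale_;
};

int main()
{
    const Scale acc_element_op{0.125f}; // e.g. 1/sqrt(head_dim) in attention, illustrative value
    float acc    = 2.0f;                // a GEMM0 accumulator value
    float scaled = 0.0f;
    acc_element_op(scaled, acc);
    assert(scaled == 0.25f);
    return 0;
}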