Commit 4f7d9bbe authored by Alan Turner's avatar Alan Turner
Browse files

Add descriptor class and run method

parent 4100d1d8
...@@ -650,6 +650,29 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -650,6 +650,29 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
static constexpr bool IsSupported(index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw)
{
// Check scalar per vector requirement
const auto a_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
const auto b_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
if constexpr(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
return true;
}
// polymorphic // polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
{ {
...@@ -765,6 +788,205 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -765,6 +788,205 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return str.str(); return str.str();
} }
template <class ADesc, class BDesc, class B1Desc, class CDesc>
struct Descriptor
{
template<class AGridDescriptor>
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
{
const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template<class BGridDescriptor>
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
{
const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template<class B1GridDescriptor>
static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
{
const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template<class CGridDescriptor>
static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
{
return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
}
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
using B1GridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
MakeCGridDescriptor_M_N(CDesc{})))>;
using Block2CTileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(
MakeCGridDescriptor_M_N(CDesc{})))>;
using C0MatrixMask = conditional_t<MaskOutUpperTriangle,
C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
C0MatrixMask_impl<MaskDisabledPredicate>>;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock;
Block2CTileMap block_2_ctile_map;
C0MatrixMask c0_matrix_mask;
// element-wise op
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
AccElementwiseOperation acc_element_op;
B1ElementwiseOperation b1_element_op;
CElementwiseOperation c_element_op;
bool has_main_k_block_loop = true;
bool is_valid = false;
constexpr Descriptor(ADesc a,
BDesc b,
B1Desc b1,
CDesc c,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
AccElementwiseOperation acc_element_op_,
B1ElementwiseOperation b1_element_op_,
CElementwiseOperation c_element_op_)
: a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
c_grid_descriptor_mblock_mperblock_nblock_nperblock{GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
MakeCGridDescriptor_M_N(c))},
block_2_etile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(
MakeCGridDescriptor_M_N(c))},
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
c0_matrix_mask{c.GetLength(I1)}
a_element_op{a_element_op_},
b_element_op{b_element_op_},
acc_element_op{acc_element_op_},
b1_element_op{b1_element_op_},
c_element_op{c_element_op_},
is_valid{GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
MakeCGridDescriptor_M_N(c),
block_2_ctile_map) and
IsSupported(c.GetLength(I0), c.GetLength(I1), a.GetLength(I1), b1.GetLength(I1))}
{
}
constexpr bool IsValid() const
{
return is_valid;
}
};
template <class ADesc, class BDesc, class B1Desc, class CDesc>
static constexpr auto
make_descriptor(ADesc a,
BDesc b,
B1Desc b1desc,
CDesc c,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
AccElementwiseOperation acc_element_op = AccElementwiseOperation{},
B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
CElementwiseOperation c_element_op = CElementwiseOperation{})
{
return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
a, b, b1, c, a_element_op, b_element_op, acc_element_op, b1_element_op, c_element_op);
}
template <class Desc>
__device__ static void Run(const Desc& desc,
const ADataType* __restrict__ p_a_grid,
const ADataType* __restrict__ p_b_grid,
const ADataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid)
{
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
assert(desc.is_valid);
if(desc.has_main_k_block_loop)
{
GridwiseGemm::template Run<true>(p_a_grid,
p_b_grid,
p_b1_grid,
p_c_grid,
p_shared,
desc.a_element_op,
desc.b_element_op,
desc.acc_element_op,
desc.b1_element_op,
desc.c_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.b1_grid_desc_bk0_n_bk1,
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
desc.block_2_ctile_map,
desc.c0_matrix_mask);
}
else
{
GridwiseGemm::template Run<false>(p_a_grid,
p_b_grid,
p_b1_grid,
p_c_grid,
p_shared,
desc.a_element_op,
desc.b_element_op,
desc.acc_element_op,
desc.b1_element_op,
desc.c_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.b1_grid_desc_bk0_n_bk1,
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
desc.block_2_ctile_map,
desc.c0_matrix_mask);
}
}
}; };
} // namespace device } // namespace device
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/common.hpp"
namespace ck {
namespace host {
namespace device_gemm_softmax_gemm {
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
std::size_t O = 0;
bool TransA = false;
bool TransB = false;
bool TransB1 = false;
bool TransC = false;
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType B1DataType = DataType::Half;
DataType CDataType = DataType::Half;
DataType AccDataType = DataType::Float;
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string AccElementOp = "ck::tensor_operation::element_wise::Scale";
std::string B1ElementOp = "ck::tensor_operation::element_wise::PassThrough";
static const std::size_t ds_layout_idx = 3;
static const std::size_t ds_data_type_idx = 9;
static const std::size_t e_data_type_idx = 10;
static const std::size_t a_elementwise_op_idx = 11;
static const std::size_t b_elementwise_op_idx = 12;
static const std::size_t ds_elementwise_op_idx = 13;
static const std::size_t gemm_spec_idx = 14;
static const std::size_t block_size_idx = 16;
static const std::size_t m_per_block_idx = 17;
static const std::size_t n_per_block_idx = 18;
static const std::size_t k_per_block_idx = 19;
std::string GetIncludeHeader() const;
std::vector<Solution> GetSolutions(const std::string& arch) const;
private:
std::vector<std::string> GetInstances(const std::string& arch) const;
Solution MakeSolution(std::size_t idx, const std::string& arch) const;
};
} // namespace device_gemm_softmax_gemm
} // namespace host
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment