Commit 84189dd5 authored by Alan Turner's avatar Alan Turner
Browse files

Update to use new gemm instances and new gsg

parent 6ac41ed8
This diff is collapsed.
......@@ -153,7 +153,8 @@ template <typename ALayout,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct CK_DeviceGemmMultipleD
{
static constexpr auto I0 = ck::Number<0>{};
......@@ -210,7 +211,8 @@ struct CK_DeviceGemmMultipleD
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
LoopSched,
PipelineVer>;
// return block_id to E matrix tile idx (m0, n0) mapping
template <class EGridDesc_M_N>
......
......@@ -124,7 +124,12 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, decltype(c_grid_desc_m_n)>(
c_grid_desc_m_n);
const C0MatrixMask c0_matrix_mask(n);
//using C0MatrixMask = ck::conditional_t<gemm.get_MOUT(),
// C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
// C0MatrixMask_impl<MaskDisabledPredicate>>;
// template<>
// C0MatrixMask_impl c0_matrix_mask<MaskDisabledPredicate>{n};
typename G::C0MM_Wrapper cw(n);
const auto K = a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{});
......@@ -159,7 +164,7 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map,
c0_matrix_mask);
cw.c0_matrix_mask_);
}
template <class G, index_int BlocksPerBatch, class... Ts>
......
......@@ -113,24 +113,78 @@ struct BlockToCTileMap_M00_N0_M01Adapt
CGridDesc_M_N c_grid_desc_m_n_;
};
// // to track the points which need to be set to -inf on C0
// // Note: no need to reset M padding value, because they will not be stored out.
// struct C0MatrixMask
// {
// __device__ C0MatrixMask(ck::index_t NRaw) : NRaw_(NRaw) {}
// __device__ bool IsUpperTriangle(ck::index_t m, ck::index_t n) const { return n > m; }
// __device__ bool IsNOutOfBound(/*ck::index_t m, */ ck::index_t n) const { return n >= NRaw_; }
// __device__ bool IsMaskedElement(ck::index_t m, ck::index_t n) const
// {
// return IsUpperTriangle(m, n) || IsNOutOfBound(n);
// }
// private:
// // ck::index_t MRaw_;
// ck::index_t NRaw_;
// };
struct MaskDisabledPredicate
{
__host__ __device__ constexpr bool operator()(ck::index_t /*m*/, ck::index_t /*n*/) const
{
return false;
};
__host__ __device__ constexpr bool
IsTileSkippable(ck::index_t /*m*/, ck::index_t /*n*/, ck::index_t /*m_tile*/, ck::index_t /*n_tile*/) const
{
return false;
}
};
struct MaskOutUpperTrianglePredicate
{
__host__ __device__ constexpr bool operator()(ck::index_t m, ck::index_t n) const { return n > m; }
__host__ __device__ constexpr bool
IsTileSkippable(ck::index_t m, ck::index_t n, ck::index_t m_tile, ck::index_t /*n_tile*/) const
{
return operator()(m + m_tile - 1, n);
}
};
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
__device__ C0MatrixMask(ck::index_t NRaw) : NRaw_(NRaw) {}
__host__ __device__ C0MatrixMask_impl(ck::index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
__device__ bool IsUpperTriangle(ck::index_t m, ck::index_t n) const { return n > m; }
__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ ck::index_t n) const
{
return n >= NRaw_;
}
__device__ bool IsNOutOfBound(/*ck::index_t m, */ ck::index_t n) const { return n >= NRaw_; }
__host__ __device__ constexpr bool IsMaskedElement(ck::index_t m, ck::index_t n) const
{
return predicate_(m, n) || IsNOutOfBound(n);
}
__device__ bool IsMaskedElement(ck::index_t m, ck::index_t n) const
__host__ __device__ constexpr bool
IsTileSkippable(ck::index_t m, ck::index_t n, ck::index_t m_tile, ck::index_t n_tile) const
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
return predicate_.IsTileSkippable(m, n, m_tile, n_tile);
}
private:
// ck::index_t MRaw_;
// index_t MRaw_;
ck::index_t NRaw_;
MaskOutPredicate predicate_;
};
template <typename ALayout,
......@@ -212,6 +266,19 @@ struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CElementwiseOperation c_element_op{};
AccElementwiseOperation acc_element_op{alpha};
//static constexpr auto get_MOUT() { return MaskOutUpperTriangle; };
using C0MatrixMask = ck::conditional_t<MaskOutUpperTriangle,
C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
C0MatrixMask_impl<MaskDisabledPredicate>>;
struct C0MM_Wrapper
{
__device__ C0MM_Wrapper(const unsigned int n) : c0_matrix_mask_{static_cast<ck::index_t>(n)} {}
C0MatrixMask c0_matrix_mask_;
};
template <typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
......
......@@ -127,6 +127,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops{},
dead_code_elimination{},
auto_contiguous{},
fuse_ck_gemm_softmax_gemm{&ctx},
dead_code_elimination{},
optimize_module{},
enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
dead_code_elimination{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment