"driver/include/device.hpp" did not exist on "c82b833d8e76094a3702046d81872132d5c4b15a"
Commit 84189dd5 authored by Alan Turner's avatar Alan Turner
Browse files

Update to use new gemm instances and new gsg

parent 6ac41ed8
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -153,7 +153,8 @@ template <typename ALayout,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct CK_DeviceGemmMultipleD
{
static constexpr auto I0 = ck::Number<0>{};
......@@ -210,7 +211,8 @@ struct CK_DeviceGemmMultipleD
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
LoopSched,
PipelineVer>;
// return block_id to E matrix tile idx (m0, n0) mapping
template <class EGridDesc_M_N>
......
......@@ -124,7 +124,12 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, decltype(c_grid_desc_m_n)>(
c_grid_desc_m_n);
const C0MatrixMask c0_matrix_mask(n);
//using C0MatrixMask = ck::conditional_t<gemm.get_MOUT(),
// C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
// C0MatrixMask_impl<MaskDisabledPredicate>>;
// template<>
// C0MatrixMask_impl c0_matrix_mask<MaskDisabledPredicate>{n};
typename G::C0MM_Wrapper cw(n);
const auto K = a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{});
......@@ -159,7 +164,7 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map,
c0_matrix_mask);
cw.c0_matrix_mask_);
}
template <class G, index_int BlocksPerBatch, class... Ts>
......
......@@ -113,24 +113,78 @@ struct BlockToCTileMap_M00_N0_M01Adapt
CGridDesc_M_N c_grid_desc_m_n_;
};
// // to track the points which need to be set to -inf on C0
// // Note: no need to reset M padding value, because they will not be stored out.
// struct C0MatrixMask
// {
// __device__ C0MatrixMask(ck::index_t NRaw) : NRaw_(NRaw) {}
// __device__ bool IsUpperTriangle(ck::index_t m, ck::index_t n) const { return n > m; }
// __device__ bool IsNOutOfBound(/*ck::index_t m, */ ck::index_t n) const { return n >= NRaw_; }
// __device__ bool IsMaskedElement(ck::index_t m, ck::index_t n) const
// {
// return IsUpperTriangle(m, n) || IsNOutOfBound(n);
// }
// private:
// // ck::index_t MRaw_;
// ck::index_t NRaw_;
// };
// Mask predicate for the "no masking" case: both queries answer false
// unconditionally, so no element is masked and no tile is reported as skippable.
struct MaskDisabledPredicate
{
    // Element query; the (m, n) arguments are accepted but ignored.
    __host__ __device__ constexpr bool operator()(ck::index_t, ck::index_t) const { return false; }

    // Tile query; the (m, n, m_tile, n_tile) arguments are accepted but ignored.
    __host__ __device__ constexpr bool
    IsTileSkippable(ck::index_t, ck::index_t, ck::index_t, ck::index_t) const
    {
        return false;
    }
};
// Mask predicate that flags every element whose n index is strictly greater
// than its m index.
struct MaskOutUpperTrianglePredicate
{
    // True when (m, n) satisfies n > m.
    __host__ __device__ constexpr bool operator()(ck::index_t m, ck::index_t n) const
    {
        return m < n;
    }

    // Applies the element test to (m + m_tile - 1, n); n_tile is unused.
    // NOTE(review): presumably m and n are the tile's first indices and m_tile
    // its extent along m, so this checks the tile's last m against its first n
    // -- confirm against the caller.
    __host__ __device__ constexpr bool
    IsTileSkippable(ck::index_t m, ck::index_t n, ck::index_t m_tile, ck::index_t /*n_tile*/) const
    {
        const ck::index_t last_m = m + m_tile - 1;
        return (*this)(last_m, n);
    }
};
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
__device__ C0MatrixMask(ck::index_t NRaw) : NRaw_(NRaw) {}
__host__ __device__ C0MatrixMask_impl(ck::index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
__device__ bool IsUpperTriangle(ck::index_t m, ck::index_t n) const { return n > m; }
__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ ck::index_t n) const
{
return n >= NRaw_;
}
__device__ bool IsNOutOfBound(/*ck::index_t m, */ ck::index_t n) const { return n >= NRaw_; }
__host__ __device__ constexpr bool IsMaskedElement(ck::index_t m, ck::index_t n) const
{
return predicate_(m, n) || IsNOutOfBound(n);
}
__device__ bool IsMaskedElement(ck::index_t m, ck::index_t n) const
__host__ __device__ constexpr bool
IsTileSkippable(ck::index_t m, ck::index_t n, ck::index_t m_tile, ck::index_t n_tile) const
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
return predicate_.IsTileSkippable(m, n, m_tile, n_tile);
}
private:
// ck::index_t MRaw_;
// index_t MRaw_;
ck::index_t NRaw_;
MaskOutPredicate predicate_;
};
template <typename ALayout,
......@@ -212,6 +266,19 @@ struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CElementwiseOperation c_element_op{};
AccElementwiseOperation acc_element_op{alpha};
//static constexpr auto get_MOUT() { return MaskOutUpperTriangle; };
// Picks the C0 mask type at compile time from MaskOutUpperTriangle: the
// upper-triangle predicate when it is true, the disabled predicate otherwise
// (assuming ck::conditional_t follows std::conditional_t semantics).
// NOTE(review): MaskOutUpperTriangle is declared outside this view --
// presumably a bool template parameter of the enclosing struct; confirm.
using C0MatrixMask = ck::conditional_t<MaskOutUpperTriangle,
C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
C0MatrixMask_impl<MaskDisabledPredicate>>;
// Holder that constructs the selected C0MatrixMask from an unsigned length.
struct C0MM_Wrapper
{
    // The mask object itself; public so that users can read it directly.
    C0MatrixMask c0_matrix_mask_;

    // Converts n to ck::index_t and hands it to the mask's constructor.
    __device__ C0MM_Wrapper(unsigned int n) : c0_matrix_mask_(static_cast<ck::index_t>(n)) {}
};
template <typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
......
......@@ -127,6 +127,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops{},
dead_code_elimination{},
auto_contiguous{},
fuse_ck_gemm_softmax_gemm{&ctx},
dead_code_elimination{},
optimize_module{},
enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
dead_code_elimination{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment