Commit 82c58e44 authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent 84189dd5
......@@ -124,7 +124,7 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, decltype(c_grid_desc_m_n)>(
c_grid_desc_m_n);
//using C0MatrixMask = ck::conditional_t<gemm.get_MOUT(),
// using C0MatrixMask = ck::conditional_t<gemm.get_MOUT(),
// C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
// C0MatrixMask_impl<MaskDisabledPredicate>>;
// template<>
......
......@@ -140,8 +140,10 @@ struct MaskDisabledPredicate
return false;
};
__host__ __device__ constexpr bool
IsTileSkippable(ck::index_t /*m*/, ck::index_t /*n*/, ck::index_t /*m_tile*/, ck::index_t /*n_tile*/) const
__host__ __device__ constexpr bool IsTileSkippable(ck::index_t /*m*/,
ck::index_t /*n*/,
ck::index_t /*m_tile*/,
ck::index_t /*n_tile*/) const
{
return false;
}
......@@ -149,7 +151,10 @@ struct MaskDisabledPredicate
struct MaskOutUpperTrianglePredicate
{
__host__ __device__ constexpr bool operator()(ck::index_t m, ck::index_t n) const { return n > m; }
__host__ __device__ constexpr bool operator()(ck::index_t m, ck::index_t n) const
{
return n > m;
}
__host__ __device__ constexpr bool
IsTileSkippable(ck::index_t m, ck::index_t n, ck::index_t m_tile, ck::index_t /*n_tile*/) const
......@@ -163,7 +168,10 @@ struct MaskOutUpperTrianglePredicate
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
__host__ __device__ C0MatrixMask_impl(ck::index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
__host__ __device__ C0MatrixMask_impl(ck::index_t NRaw)
: NRaw_(NRaw), predicate_(MaskOutPredicate{})
{
}
__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ ck::index_t n) const
{
......@@ -266,7 +274,7 @@ struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CElementwiseOperation c_element_op{};
AccElementwiseOperation acc_element_op{alpha};
//static constexpr auto get_MOUT() { return MaskOutUpperTriangle; };
// static constexpr auto get_MOUT() { return MaskOutUpperTriangle; };
using C0MatrixMask = ck::conditional_t<MaskOutUpperTriangle,
C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
......@@ -274,7 +282,9 @@ struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
struct C0MM_Wrapper
{
__device__ C0MM_Wrapper(const unsigned int n) : c0_matrix_mask_{static_cast<ck::index_t>(n)} {}
__device__ C0MM_Wrapper(const unsigned int n) : c0_matrix_mask_{static_cast<ck::index_t>(n)}
{
}
C0MatrixMask c0_matrix_mask_;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment