"vscode:/vscode.git/clone" did not exist on "19ff47dfb48ef53d6a7485687e2a3a63cfe611de"
Commit 82c58e44 authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent 84189dd5
......@@ -153,7 +153,7 @@ template <typename ALayout,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler(),
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct CK_DeviceGemmMultipleD
{
......
......@@ -124,7 +124,7 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, decltype(c_grid_desc_m_n)>(
c_grid_desc_m_n);
//using C0MatrixMask = ck::conditional_t<gemm.get_MOUT(),
// using C0MatrixMask = ck::conditional_t<gemm.get_MOUT(),
// C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
// C0MatrixMask_impl<MaskDisabledPredicate>>;
// template<>
......
......@@ -140,8 +140,10 @@ struct MaskDisabledPredicate
return false;
};
__host__ __device__ constexpr bool
IsTileSkippable(ck::index_t /*m*/, ck::index_t /*n*/, ck::index_t /*m_tile*/, ck::index_t /*n_tile*/) const
__host__ __device__ constexpr bool IsTileSkippable(ck::index_t /*m*/,
ck::index_t /*n*/,
ck::index_t /*m_tile*/,
ck::index_t /*n_tile*/) const
{
return false;
}
......@@ -149,7 +151,10 @@ struct MaskDisabledPredicate
struct MaskOutUpperTrianglePredicate
{
__host__ __device__ constexpr bool operator()(ck::index_t m, ck::index_t n) const { return n > m; }
__host__ __device__ constexpr bool operator()(ck::index_t m, ck::index_t n) const
{
return n > m;
}
__host__ __device__ constexpr bool
IsTileSkippable(ck::index_t m, ck::index_t n, ck::index_t m_tile, ck::index_t /*n_tile*/) const
......@@ -163,7 +168,10 @@ struct MaskOutUpperTrianglePredicate
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
__host__ __device__ C0MatrixMask_impl(ck::index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
__host__ __device__ C0MatrixMask_impl(ck::index_t NRaw)
: NRaw_(NRaw), predicate_(MaskOutPredicate{})
{
}
__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ ck::index_t n) const
{
......@@ -266,15 +274,17 @@ struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CElementwiseOperation c_element_op{};
AccElementwiseOperation acc_element_op{alpha};
//static constexpr auto get_MOUT() { return MaskOutUpperTriangle; };
// static constexpr auto get_MOUT() { return MaskOutUpperTriangle; };
using C0MatrixMask = ck::conditional_t<MaskOutUpperTriangle,
C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
C0MatrixMask_impl<MaskDisabledPredicate>>;
C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
C0MatrixMask_impl<MaskDisabledPredicate>>;
struct C0MM_Wrapper
{
__device__ C0MM_Wrapper(const unsigned int n) : c0_matrix_mask_{static_cast<ck::index_t>(n)} {}
__device__ C0MM_Wrapper(const unsigned int n) : c0_matrix_mask_{static_cast<ck::index_t>(n)}
{
}
C0MatrixMask c0_matrix_mask_;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment