"src/include/blockwise_4d_tensor_op.hpp" did not exist on "a0584426ff5b6b8b448c971b97c9b1a4d86ba010"
Commit 82c58e44 authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent 84189dd5
...@@ -153,7 +153,7 @@ template <typename ALayout, ...@@ -153,7 +153,7 @@ template <typename ALayout,
ck::index_t CShuffleNXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock, ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler(), ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1> ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct CK_DeviceGemmMultipleD struct CK_DeviceGemmMultipleD
{ {
......
...@@ -124,7 +124,7 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1) ...@@ -124,7 +124,7 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, decltype(c_grid_desc_m_n)>( BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, decltype(c_grid_desc_m_n)>(
c_grid_desc_m_n); c_grid_desc_m_n);
//using C0MatrixMask = ck::conditional_t<gemm.get_MOUT(), // using C0MatrixMask = ck::conditional_t<gemm.get_MOUT(),
// C0MatrixMask_impl<MaskOutUpperTrianglePredicate>, // C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
// C0MatrixMask_impl<MaskDisabledPredicate>>; // C0MatrixMask_impl<MaskDisabledPredicate>>;
// template<> // template<>
......
...@@ -140,8 +140,10 @@ struct MaskDisabledPredicate ...@@ -140,8 +140,10 @@ struct MaskDisabledPredicate
return false; return false;
}; };
__host__ __device__ constexpr bool __host__ __device__ constexpr bool IsTileSkippable(ck::index_t /*m*/,
IsTileSkippable(ck::index_t /*m*/, ck::index_t /*n*/, ck::index_t /*m_tile*/, ck::index_t /*n_tile*/) const ck::index_t /*n*/,
ck::index_t /*m_tile*/,
ck::index_t /*n_tile*/) const
{ {
return false; return false;
} }
...@@ -149,7 +151,10 @@ struct MaskDisabledPredicate ...@@ -149,7 +151,10 @@ struct MaskDisabledPredicate
struct MaskOutUpperTrianglePredicate struct MaskOutUpperTrianglePredicate
{ {
__host__ __device__ constexpr bool operator()(ck::index_t m, ck::index_t n) const { return n > m; } __host__ __device__ constexpr bool operator()(ck::index_t m, ck::index_t n) const
{
return n > m;
}
__host__ __device__ constexpr bool __host__ __device__ constexpr bool
IsTileSkippable(ck::index_t m, ck::index_t n, ck::index_t m_tile, ck::index_t /*n_tile*/) const IsTileSkippable(ck::index_t m, ck::index_t n, ck::index_t m_tile, ck::index_t /*n_tile*/) const
...@@ -163,7 +168,10 @@ struct MaskOutUpperTrianglePredicate ...@@ -163,7 +168,10 @@ struct MaskOutUpperTrianglePredicate
template <typename MaskOutPredicate> template <typename MaskOutPredicate>
struct C0MatrixMask_impl struct C0MatrixMask_impl
{ {
__host__ __device__ C0MatrixMask_impl(ck::index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {} __host__ __device__ C0MatrixMask_impl(ck::index_t NRaw)
: NRaw_(NRaw), predicate_(MaskOutPredicate{})
{
}
__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ ck::index_t n) const __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ ck::index_t n) const
{ {
...@@ -266,15 +274,17 @@ struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -266,15 +274,17 @@ struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CElementwiseOperation c_element_op{}; CElementwiseOperation c_element_op{};
AccElementwiseOperation acc_element_op{alpha}; AccElementwiseOperation acc_element_op{alpha};
//static constexpr auto get_MOUT() { return MaskOutUpperTriangle; }; // static constexpr auto get_MOUT() { return MaskOutUpperTriangle; };
using C0MatrixMask = ck::conditional_t<MaskOutUpperTriangle, using C0MatrixMask = ck::conditional_t<MaskOutUpperTriangle,
C0MatrixMask_impl<MaskOutUpperTrianglePredicate>, C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
C0MatrixMask_impl<MaskDisabledPredicate>>; C0MatrixMask_impl<MaskDisabledPredicate>>;
struct C0MM_Wrapper struct C0MM_Wrapper
{ {
__device__ C0MM_Wrapper(const unsigned int n) : c0_matrix_mask_{static_cast<ck::index_t>(n)} {} __device__ C0MM_Wrapper(const unsigned int n) : c0_matrix_mask_{static_cast<ck::index_t>(n)}
{
}
C0MatrixMask c0_matrix_mask_; C0MatrixMask c0_matrix_mask_;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment