Commit 82072168 authored by Jing Zhang's avatar Jing Zhang
Browse files

add is_detected

parent 3cf22191
......@@ -126,13 +126,19 @@ struct ThreadGroupTensorSliceTransfer_v7r2
}
}
// Detection probe for ck::is_detected: the alias is well-formed only when T
// exposes an IsTuple() member (cf. Tuple::IsTuple), i.e. when T is a ck Tuple.
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
// Write the thread group's slice into the destination buffer(s).
// Accepts either a Tuple of destination buffers or a single buffer; a single
// buffer is adapted via tie(...) so the threadwise transfer always receives
// a tuple-shaped argument.
template <typename DstBuffers>
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
{
    // Only threads mapped onto the thread cluster participate in the write.
    const bool thread_participates =
        ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
        ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize();

    if(!thread_participates)
    {
        return;
    }

    if constexpr(!is_detected<is_tuple, decltype(dst_bufs)>::value)
    {
        // Single buffer: wrap it so the callee sees a tuple.
        threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs));
    }
    else
    {
        // Already a Tuple of buffers: forward as-is.
        threadwise_transfer_.RunWrite(dst_descs, dst_bufs);
    }
}
......
......@@ -687,70 +687,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
(as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) /
KPerBlock);
#if 1
{
const auto a_grid_desc = as_grid_desc_ak0_m_ak1;
const auto b_grid_desc = bs_grid_desc_bk0_n_bk1;
const auto a_block_copy_step = a_block_slice_copy_step;
const auto b_block_copy_step = b_block_slice_copy_step;
const auto a_block_desc = a_block_desc_ak0_m_ak1;
const auto b_block_desc = b_block_desc_bk0_n_bk1;
const auto a_grid_bufs = as_grid_buf;
const auto b_grid_bufs = bs_grid_buf;
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_bufs);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_bufs);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(tie(a_block_desc), tie(a_block_buf));
b_blockwise_copy.RunWrite(tie(b_block_desc), tie(b_block_buf));
const auto num_loop = num_k_block_main_loop;
// main body
if constexpr(HasMainKBlockLoop)
{
index_t k = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_bufs);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_bufs);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(tie(a_block_desc), tie(a_block_buf));
b_blockwise_copy.RunWrite(tie(b_block_desc), tie(b_block_buf));
++k;
} while(k < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
#else
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
......@@ -770,7 +706,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
#endif
// shuffle C and write out
{
......
......@@ -7,18 +7,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include <type_traits>
template <typename T, typename = void>
struct has_vec_len : std::false_type
{
};
template <typename T>
struct has_vec_len<T, std::void_t<decltype(std::declval<T>().vec_len)>> : std::true_type
{
};
#include "ck/utility/is_detected.hpp"
namespace ck {
......@@ -143,6 +132,9 @@ struct ThreadwiseTensorSliceTransfer_v7r2
Number<num>{});
}
// Detection probe for ck::is_detected: well-formed only when T exposes a
// (possibly static) data member `vec_len`.
// NOTE: this must be member *access*, not a call -- the consumer reads
// `decltype(element_op_)::vec_len` as a value, so probing `.vec_len()` would
// only detect types whose vec_len is invocable and would miss the intended
// element ops (the trait this replaces also detected plain `.vec_len`).
template <typename T>
using has_vec_len = decltype(std::declval<T&>().vec_len);
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
template <typename SrcBuffers,
......@@ -167,7 +159,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
is_src_valid);
});
if constexpr(has_vec_len<decltype(element_op_)>::value)
if constexpr(is_detected<has_vec_len, decltype(element_op_)>::value)
{
constexpr auto elem_op_vec_len = decltype(element_op_)::vec_len;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// Required for std::false_type, std::true_type and std::void_t; previously
// this header relied on a transitive include.
#include <type_traits>

namespace ck {
namespace detail {

// Primary template: chosen when Op<Args...> is ill-formed (not detected).
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
    using value_t = std::false_type;
    using type    = Default;
};

// Partial specialization: chosen via SFINAE when Op<Args...> is well-formed.
template <class Default, template <class...> class Op, class... Args>
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
{
    using value_t = std::true_type;
    using type    = Op<Args...>;
};

} // namespace detail

// Placeholder type used when detection fails; intentionally unusable
// (cannot be constructed, destroyed, copied, or assigned).
struct nonesuch
{
    ~nonesuch()                     = delete;
    nonesuch(nonesuch const&)       = delete;
    void operator=(nonesuch const&) = delete;
};

// is_detected<Op, Args...> is std::true_type when Op<Args...> names a valid
// type, std::false_type otherwise. Mirrors std::experimental::is_detected
// from the Library Fundamentals TS v2.
template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;

} // namespace ck
......@@ -177,6 +177,8 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
}
// Tag: a Tuple is a static (compile-time addressable) buffer.
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
// Tag probed by is_detected-based traits (e.g. the transfer classes' is_tuple
// probe) to tell a Tuple of buffers apart from a single buffer.
__host__ __device__ static constexpr bool IsTuple() { return true; }
};
template <>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment