Commit 1b15b21a authored by Chao Liu

update static_tensor for dealing with invalid element

parent 2fd5e6ae
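In brief, the new invalid-element handling: reading an invalid (out-of-window) element returns either a numerical zero or a caller-supplied scalar, and writing an invalid element is absorbed by a scratch member instead of the old `ignore` sentinel. A minimal standalone sketch of that pattern (plain C++ with an illustrative bounds check and hypothetical names; the real StaticTensor indexes through a TensorDesc):

#include <cstdio>

template <typename T, int N, bool InvalidElementUseNumericalZeroValue>
struct StaticTensorSketch
{
    constexpr StaticTensorSketch() : invalid_element_scalar_value_{0} {}

    constexpr StaticTensorSketch(T invalid_element_value)
        : invalid_element_scalar_value_{invalid_element_value}
    {
    }

    // read: a valid element comes from data_, an invalid one from the fallback scalar
    constexpr T operator[](int i) const
    {
        if(i >= 0 && i < N)
        {
            return data_[i];
        }
        else if constexpr(InvalidElementUseNumericalZeroValue)
        {
            return zero_scalar_value_;
        }
        else
        {
            return invalid_element_scalar_value_;
        }
    }

    // write: an invalid element is redirected into a scratch scalar and dropped
    T& operator()(int i) { return (i >= 0 && i < N) ? data_[i] : ignored_element_scalar_; }

    T data_[N] = {};
    static constexpr T zero_scalar_value_ = T{0};
    const T invalid_element_scalar_value_;
    T ignored_element_scalar_{};
};

int main()
{
    StaticTensorSketch<float, 4, false> t{-1.0f};
    t(1) = 2.0f; // valid write
    t(9) = 5.0f; // invalid write, absorbed by ignored_element_scalar_
    std::printf("%f %f\n", t[1], t[9]); // prints 2.000000 -1.000000
}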
#ifndef CK_STATIC_TENSOR_HPP
#define CK_STATIC_TENSOR_HPP

-#include "ignore.hpp"

namespace ck {

// StaticTensor for Scalar
@@ -17,10 +15,10 @@ struct StaticTensor
    static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension();
    static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();

-   __host__ __device__ constexpr StaticTensor() : invalid_element_value_{0} {}
+   __host__ __device__ constexpr StaticTensor() : invalid_element_scalar_value_{0} {}

    __host__ __device__ constexpr StaticTensor(T invalid_element_value)
-       : invalid_element_value_{invalid_element_value}
+       : invalid_element_scalar_value_{invalid_element_value}
    {
    }

@@ -44,11 +42,11 @@ struct StaticTensor
    {
        if constexpr(InvalidElementUseNumericalZeroValue)
        {
-           return T{0};
+           return zero_scalar_value_;
        }
        else
        {
-           return invalid_element_value_;
+           return invalid_element_scalar_value_;
        }
    }
}

@@ -71,12 +69,14 @@ struct StaticTensor
    }
    else
    {
-       return ignore;
+       return ignored_element_scalar_;
    }
}

    StaticBuffer<AddressSpace, T, element_space_size_, true> data_;
-   T invalid_element_value_ = T{0};
+   static constexpr T zero_scalar_value_ = T{0};
+   const T invalid_element_scalar_value_;
+   T ignored_element_scalar_;
};

// StaticTensor for vector
@@ -97,10 +97,13 @@ struct StaticTensorTupleOfVectorBuffer
    using V = vector_type<S, ScalarPerVector>;

-   __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() : invalid_element_value_{0} {}
+   __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer()
+       : invalid_element_scalar_value_{0}
+   {
+   }

    __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value)
-       : invalid_element_value_{invalid_element_value}
+       : invalid_element_scalar_value_{invalid_element_value}
    {
    }

@@ -125,11 +128,11 @@ struct StaticTensorTupleOfVectorBuffer
    {
        if constexpr(InvalidElementUseNumericalZeroValue)
        {
-           return S{0};
+           return zero_scalar_value_;
        }
        else
        {
-           return invalid_element_value_;
+           return invalid_element_scalar_value_;
        }
    }
}

@@ -153,7 +156,7 @@ struct StaticTensorTupleOfVectorBuffer
    }
    else
    {
-       return ignore;
+       return ignored_element_scalar_;
    }
}

@@ -186,7 +189,7 @@ struct StaticTensorTupleOfVectorBuffer
    else
    {
        // TODO: is this right way to initialize a vector?
-       return X{invalid_element_value_};
+       return X{invalid_element_scalar_value_};
    }
}

@@ -237,7 +240,9 @@ struct StaticTensorTupleOfVectorBuffer
}

    StaticBufferTupleOfVector<AddressSpace, S, num_of_vector_, ScalarPerVector, true> data_;
-   S invalid_element_value_ = S{0};
+   static constexpr S zero_scalar_value_ = S{0};
+   const S invalid_element_scalar_value_ = S{0};
+   S ignored_element_scalar_;
};

template <AddressSpaceEnum_t AddressSpace,
...
@@ -114,6 +114,16 @@ struct BlockwiseTensorSliceTransfer_v4
        }
    }

+   template <typename SrcBuffer, typename DstBuffer>
+   __device__ void Run(const SrcDesc& src_desc,
+                       const SrcBuffer& src_buf,
+                       const DstDesc& dst_desc,
+                       DstBuffer& dst_buf)
+   {
+       RunRead(src_desc, src_buf);
+       RunWrite(dst_desc, dst_buf);
+   }
+
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
...
@@ -10,6 +10,8 @@
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"

+#define DEBUG_USE_C_SHUFFLE 0
+
namespace ck {

template <typename GridwiseGemm,
@@ -17,7 +19,11 @@ template <typename GridwiseGemm,
          typename FloatC,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
+#if !DEBUG_USE_C_SHUFFLE
          typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
+#else
+          typename CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
+#endif
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
@@ -33,24 +39,30 @@ __global__ void
               FloatC* __restrict__ p_c_grid,
               const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
               const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
+#if !DEBUG_USE_C_SHUFFLE
               const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+#else
+               const CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
+                   c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
+#endif
               const AElementwiseOperation a_element_op,
               const BElementwiseOperation b_element_op,
               const CElementwiseOperation c_element_op,
               const Block2CTileMap block_2_ctile_map)
{
-   constexpr index_t shared_block_size =
-       GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
-
-   __shared__ FloatAB p_shared_block[shared_block_size];
+   __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
                                                  p_c_grid,
-                                                 p_shared_block,
+                                                 p_shared,
                                                  a_grid_desc_k0_m_k1,
                                                  b_grid_desc_k0_n_k1,
+#if !DEBUG_USE_C_SHUFFLE
                                                  c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+#else
+                                                  c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
+#endif
                                                  a_element_op,
                                                  b_element_op,
                                                  c_element_op,
@@ -220,6 +232,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
        return has_main_k0_block_loop;
    }

+#if !DEBUG_USE_C_SHUFFLE
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
    {
@@ -269,6 +282,33 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
        return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
    }
+#else
+   __host__ __device__ static constexpr auto
+   MakeCGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl(
+       const CGridDesc_M_N& c_grid_desc_m_n)
+   {
+       const auto M = c_grid_desc_m_n.GetLength(I0);
+       const auto N = c_grid_desc_m_n.GetLength(I1);
+
+       constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
+       constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);
+
+       const index_t MBlock = M / (MWave * MPerXdl * MRepeat);
+       const index_t NBlock = N / (NWave * NPerXdl * NRepeat);
+
+       const auto c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl =
+           transform_tensor_descriptor(
+               c_grid_desc_m_n,
+               make_tuple(make_unmerge_transform(
+                              make_tuple(MBlock, Number<MRepeat>{}, Number<MWave * MPerXdl>{})),
+                          make_unmerge_transform(
+                              make_tuple(NBlock, Number<NRepeat>{}, Number<NWave * NPerXdl>{}))),
+               make_tuple(Sequence<0>{}, Sequence<1>{}),
+               make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
+
+       return c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl;
+   }
+#endif
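// Worked example (illustrative numbers, not part of this commit): with
// MPerBlock = 128, MRepeat = 2, MPerXdl = 32, the transform above gives
// MWave = 128 / (2 * 32) = 2, so an M = 1024 problem unmerges as
// M -> (MBlock, MRepeat, MWave * MPerXdl) = (1024 / 128, 2, 64) = (8, 2, 64);
// N is decomposed the same way.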
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
@@ -305,20 +345,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
        return c_blockid_to_m0_n0_block_cluster_adaptor;
    }

+#if !DEBUG_USE_C_SHUFFLE
    using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
        decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
-   using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
+#else
+   using CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl =
+       remove_cvref_t<decltype(
+           MakeCGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl(
+               CGridDesc_M_N{}))>;
+#endif
+
+   using Block2CTileMap = remove_cvref_t<decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

    template <bool HasMainKBlockLoop>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
-       FloatAB* __restrict__ p_shared_block,
+       void* __restrict__ p_shared,
        const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
        const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+#if !DEBUG_USE_C_SHUFFLE
        const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+#else
+       const CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl&
+           c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
+#endif
        const AElementwiseOperation& a_element_op,
        const BElementwiseOperation& b_element_op,
        const CElementwiseOperation& c_element_op,
@@ -328,8 +379,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
+#if !DEBUG_USE_C_SHUFFLE
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
+#else
+       auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+           p_c_grid,
+           c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
+               .GetElementSpaceSize());
+#endif

        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
@@ -459,8 +518,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
        constexpr auto a_block_space_size =
            math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

-       FloatAB* p_a_block = p_shared_block;
-       FloatAB* p_b_block = p_shared_block + a_block_space_size;
+#if !DEBUG_USE_C_SHUFFLE
+       FloatAB* p_a_block = static_cast<FloatAB*>(p_shared);
+       FloatAB* p_b_block = static_cast<FloatAB*>(p_shared) + a_block_space_size;
+
+       auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+           p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
+       auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+           p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
+#else
+       auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+           static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
+       auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+           static_cast<FloatAB*>(p_shared) + a_block_space_size,
+           b_block_desc_k0_n_k1.GetElementSpaceSize());
+#endif
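// Rationale (inferred, not stated in the commit): p_shared is now a raw
// __shared__ char array sized in bytes by GetSharedMemoryNumberOfByte(), so the
// same LDS allocation can be viewed as FloatAB for the A/B tiles here and as
// FloatAcc for the C-shuffle staging buffer further down, instead of dedicating
// a typed array to each use.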
        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
@@ -474,11 +548,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
        constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
        constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};

-       auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
-           p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
-       auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
-           p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
-
        // preload data into LDS
        {
            a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
@@ -530,7 +599,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }

-#if 1
+#if !DEBUG_USE_C_SHUFFLE
        // output: register to global memory
        {
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
@@ -620,7 +689,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
            constexpr index_t MPerBlock_CCopy = MWave * MPerXdl;
            constexpr index_t NPerBlock_CCopy = NWave * NPerXdl;

-           // hacky
+           // TODO: hacky
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
@@ -636,29 +705,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);

-           constexpr auto c_block_desc_mwavemperxdl_nwavenperxdl =
-               make_naive_tensor_descriptor_packed(Number<MPerBlock_CCopy>{},
-                                                   Number<NPerBlock_CCopy>{});
-
-           constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
-               c_block_desc_mwavemperxdl_nwavenperxdl,
-               make_tuple(make_unmerge_transform(make_tuple(I1, Number<MWave>{}, M2, M3, M4)),
-                          make_unmerge_transform(make_tuple(I1, Number<NWave>{}, N2))),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
-
            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

-           const index_t n_thread_data_on_grid =
-               n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
-
            const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
@@ -679,13 +733,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                    make_multi_index(n_thread_data_on_block));

            // VGPR to LDS
-           auto c_thread_copy_vgpr2lds =
+           auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                                   FloatAcc,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   ck::tensor_operation::element_wise::PassThrough,
-                                                  Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
+                                                  Sequence<I1, I1, I1, I1, M2, M3, M4, N2>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
@@ -695,37 +749,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                   c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                   make_multi_index(0,
                                    0,
-                                   m_thread_data_on_grid_idx[I1],
-                                   n_thread_data_on_grid_idx[I1],
-                                   m_thread_data_on_grid_idx[I2],
-                                   m_thread_data_on_grid_idx[I3],
-                                   m_thread_data_on_grid_idx[I4],
-                                   n_thread_data_on_grid_idx[I2]),
+                                   m_thread_data_on_block_idx[I1],
+                                   n_thread_data_on_block_idx[I1],
+                                   m_thread_data_on_block_idx[I2],
+                                   m_thread_data_on_block_idx[I3],
+                                   m_thread_data_on_block_idx[I4],
+                                   n_thread_data_on_block_idx[I2]),
                   ck::tensor_operation::element_wise::PassThrough{}};

-           // hardcoded
+           // TODO: this is hardcoded, only works for BlockSize = 256. fix it!
            constexpr index_t MThread_CCopy = 16;
            constexpr index_t NThread_CCopy = 16;
            constexpr index_t MPerThread_CCopy = MPerBlock_CCopy / MThread_CCopy;
            constexpr index_t NPerThread_CCopy = NPerBlock_CCopy / NThread_CCopy;

-           constexpr auto c_block_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl =
-               make_naive_tensor_descriptor_packed(
-                   I1, I1, Number<MPerBlock_CCopy>{}, I1, I1, Number<NPerBlock_CCopy>{});
+           constexpr auto c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl =
+               make_naive_tensor_descriptor_packed(make_tuple(
+                   I1, I1, Number<MPerBlock_CCopy>{}, I1, I1, Number<NPerBlock_CCopy>{}));
+
+           auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+               static_cast<FloatAcc*>(p_shared),
+               c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
+                   .GetElementSpaceSize());

-           auto c_block_copy = BlockwiseTensorSliceTransfer_v4<
+           auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v4<
               BlockSize,                                       // index_t BlockSize,
               ck::tensor_operation::element_wise::PassThrough, // SrcElementwiseOperation,
               CGlobalMemoryDataOperation,                      // DstInMemOp,
               Sequence<1, 1, MPerBlock_CCopy, 1, 1, NPerBlock_CCopy>,   // BlockSliceLengths,
               Sequence<1, 1, MPerThread_CCopy, 1, 1, NPerThread_CCopy>, // ThreadSliceLengths,
-              Sequence<1, 1, MPerThread, 1, 1, NPerThread>, // typename ThreadClusterLengths,
+              Sequence<1, 1, MThread_CCopy, 1, 1, NThread_CCopy>, // ThreadClusterLengths,
               Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
               FloatAcc,                   // typename SrcData,
               FloatC,                     // typename DstData,
-              decltype(c_block_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl),
-              decltype(c_global_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl),
+              decltype(c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl),
+              decltype(c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl),
               Sequence<0, 1, 2, 3, 4, 5>, // typename SrcDimAccessOrder,
               Sequence<0, 1, 2, 3, 4, 5>, // typename DstDimAccessOrder,
               5,                          // index_t SrcVectorDim,
@@ -736,12 +795,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
               1,     // index_t DstScalarStrideInVector,
               true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
               false> // bool ThreadTransferDstResetCoordinateAfterRun>
-              {
-                  c_block_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl,
-                  make_multi_index(0, 0, 0, 0, 0, 0),
-                  c_global_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl,
-                  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0)
-              }
+              {c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
+               make_multi_index(0, 0, 0, 0, 0, 0),
+               c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
+               make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
+               ck::tensor_operation::element_wise::PassThrough{}};

            constexpr auto mrepeat_forward_step = make_multi_index(0, 1, 0, 0, 0, 0);
            constexpr auto nrepeat_forward_step = make_multi_index(0, 0, 0, 0, 1, 0);
@@ -750,37 +808,46 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
            // make sure all ds_read from GEMM is completed
            block_sync_lds();

-           static_for<0, MRepeat, 1>{}([&](auto mrepeat) {
-               static_for<0, NRepeat, 1>{}([&](auto nrepeat) {
-                   // VGPR to LDS
-                   c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                                 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                                 c_thread_buf,
-                                                 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                                 c_block_buf);
+           static_for<0, MRepeat, 1>{}([&](auto mrepeat_iter) {
+               constexpr auto mrepeat = mrepeat_iter;
+
+               static_for<0, NRepeat, 1>{}([&](auto nrepeat_iter) {
+                   constexpr bool nrepeat_forward_sweep = (mrepeat % 2 == 0);
+
+                   constexpr index_t nrepeat_value =
+                       nrepeat_forward_sweep ? nrepeat_iter : (NRepeat - nrepeat_iter - 1);
+                   constexpr auto nrepeat = Number<nrepeat_value>{};
+
+                   // VGPR to LDS
+                   c_thread_copy_vgpr_to_lds.Run(
+                       c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                       make_tuple(mrepeat, nrepeat, I0, I0, I0, I0, I0, I0),
+                       c_thread_buf,
+                       c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                       c_block_buf);

+                   // make sure ds_write from c_thread_copy_vgpr_to_lds is completed
                    block_sync_lds();

                    // LDS to global
                    c_block_copy_lds_to_global.Run(
-                       c_block_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl,
+                       c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
                        c_block_buf,
-                       c_global_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl,
-                       c_global_buf);
-
-                   constexpr bool nrepeat_forward_sweep = mrepeat % 2 == 0;
+                       c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
+                       c_grid_buf);

                    // move on nrepeat dimension
-                   if constexpr(nrepeat_forward_sweep && nrepeat < NRepeat - 1)
+                   if constexpr(nrepeat_forward_sweep && (nrepeat < NRepeat - 1))
                    {
-                       c_block_copy.MoveDstSliceWindow(
-                           c_global_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl,
+                       c_block_copy_lds_to_global.MoveDstSliceWindow(
+                           c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
                            nrepeat_forward_step);
                    }
-                   else if constexpr((!nrepeat_forward_sweep) & nrepeat > 1)
+                   else if constexpr((!nrepeat_forward_sweep) && (nrepeat > 1))
                    {
-                       c_block_copy.MoveDstSliceWindow(
-                           c_global_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl,
+                       c_block_copy_lds_to_global.MoveDstSliceWindow(
+                           c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
                            nrepeat_backward_step);
                    }
                });
@@ -789,7 +856,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                if constexpr(mrepeat < MRepeat - 1)
                {
                    c_block_copy_lds_to_global.MoveDstSliceWindow(
-                       c_global_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl,
+                       c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
                        mrepeat_forward_step);
                }
            });
...
@@ -165,6 +165,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_src_access_idx[I0];

+           // TODO: BUG: should start at 1
            static_for<0, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
            });
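// (Inferred reading of the TODO above: tmp is already seeded with
// ordered_src_access_idx[I0], so a fold that starts at j = 0 folds index 0 in
// twice; the inner loop should run over j = 1..i-1 instead.)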
@@ -412,6 +413,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_dst_access_idx[I0];

+           // TODO: BUG: should start at 1
            static_for<0, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
            });
@@ -512,7 +514,8 @@ struct ThreadwiseTensorSliceTransfer_v3r2
    template <typename DstBuffer>
    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
-       constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
+       // TODO: why need remove_cvref_t ?
+       constexpr index_t ntransform_dst = remove_cvref_t<DstDesc>::GetNumOfTransform();

        constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
@@ -545,6 +548,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
        forward_sweep_(I0) = true;

+       // TODO: BUG: should start at 1
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_src_access_lengths[I0] - 1;
@@ -608,6 +612,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_dst_access_lengths[I0] - 1;

+           // TODO: BUG: should start at 1
            static_for<0, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
            });
...
@@ -35,8 +35,8 @@
#include "dynamic_buffer.hpp"
#include "is_known_at_compile_time.hpp"
#include "transpose_vectors.hpp"
#include "inner_product.hpp"
+#include "element_wise_operation.hpp"

// TODO: remove this
#if CK_USE_AMD_INLINE_ASM
...
@@ -24,12 +24,16 @@
#define CK_MIN_BLOCK_PER_CU 2
#endif

-// buffer resourse
+// GPU-specific parameters
#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
    defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A)
+// buffer resource
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
+// wave size
+#define CK_GPU_WAVE_SIZE 64
#elif defined(CK_AMD_GPU_GFX1030)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
+#define CK_GPU_WAVE_SIZE 32
#endif

// FMA instruction
...
@@ -5,8 +5,12 @@
namespace ck {

+__device__ constexpr index_t get_wave_size() { return CK_GPU_WAVE_SIZE; }
+
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

+__device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size(); }
+
__device__ index_t get_block_1d_id() { return blockIdx.x; }

} // namespace ck
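// Example (illustrative, not part of the commit): on the gfx9 targets above
// CK_GPU_WAVE_SIZE is 64, so within a block threads 0-63 report
// get_wave_local_1d_id() == 0, threads 64-127 report 1, and so on; on gfx1030
// the divisor is 32.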
...
@@ -29,8 +29,8 @@ template <typename InDataType,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t K1,
-         ck::index_t MPerXDL,
-         ck::index_t NPerXDL,
+         ck::index_t MPerXdl,
+         ck::index_t NPerXdl,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ABlockTransferThreadSliceLengths_K0_M_K1,
@@ -266,8 +266,8 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
        MPerBlock,
        NPerBlock,
        K0PerBlock,
-       MPerXDL,
-       NPerXDL,
+       MPerXdl,
+       NPerXdl,
        K1,
        MXdlPerWave,
        NXdlPerWave,
@@ -299,10 +299,12 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
        ABlockLdsAddExtraM,
        BBlockLdsAddExtraN>;

+#if !DEBUG_USE_C_SHUFFLE
    using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
        decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
    using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
+#endif

    // Argument
    struct Argument : public BaseArgument
@@ -331,7 +333,11 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
          a_grid_desc_k0_m_k1_{},
          b_grid_desc_k0_n_k1_{},
          c_grid_desc_m_n_{},
+#if !DEBUG_USE_C_SHUFFLE
          c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
+#else
+          c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_{},
+#endif
          block_2_ctile_map_{},
          M01_{M01},
          N01_{N01},
@@ -358,8 +364,15 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
        if(GridwiseGemm::CheckValidity(
               a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
        {
+#if !DEBUG_USE_C_SHUFFLE
            c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
                GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
+#else
+           c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_ =
+               GridwiseGemm::
+                   MakeCGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl(
+                       c_grid_desc_m_n_);
+#endif

            block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
        }
@@ -372,8 +385,15 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
        CGridDesc_M_N c_grid_desc_m_n_;
-       CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
-       Block2CTileMap block_2_ctile_map_;
+#if !DEBUG_USE_C_SHUFFLE
+       typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
+           c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
+#else
+       typename GridwiseGemm::
+           CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
+               c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_;
+#endif
+       typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
        index_t M01_;
        index_t N01_;
        InElementwiseOperation in_element_op_;
@@ -427,11 +447,17 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
                    CDataType,
                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                   remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+#if !DEBUG_USE_C_SHUFFLE
+                   remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+#else
+                   remove_reference_t<
+                       typename GridwiseGemm::
+                           CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl>,
+#endif
                    InElementwiseOperation,
                    WeiElementwiseOperation,
                    OutElementwiseOperation,
-                   remove_reference_t<DeviceOp::Block2CTileMap>,
+                   remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
                    true>;

                ave_time = launch_and_time_kernel(kernel,
@@ -444,7 +470,11 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
                    arg.p_c_grid_,
                    arg.a_grid_desc_k0_m_k1_,
                    arg.b_grid_desc_k0_n_k1_,
+#if !DEBUG_USE_C_SHUFFLE
                    arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+#else
+                   arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_,
+#endif
                    arg.in_element_op_,
                    arg.wei_element_op_,
                    arg.out_element_op_,
@@ -458,11 +488,17 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
                    CDataType,
                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                   remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+#if !DEBUG_USE_C_SHUFFLE
+                   remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+#else
+                   remove_reference_t<
+                       typename GridwiseGemm::
+                           CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl>,
+#endif
                    InElementwiseOperation,
                    WeiElementwiseOperation,
                    OutElementwiseOperation,
-                   remove_reference_t<DeviceOp::Block2CTileMap>,
+                   remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
                    false>;

                ave_time = launch_and_time_kernel(kernel,
@@ -475,7 +511,11 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
                    arg.p_c_grid_,
                    arg.a_grid_desc_k0_m_k1_,
                    arg.b_grid_desc_k0_n_k1_,
+#if !DEBUG_USE_C_SHUFFLE
                    arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+#else
+                   arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_,
+#endif
                    arg.in_element_op_,
                    arg.wei_element_op_,
                    arg.out_element_op_,
...
#ifndef ELEMENT_WISE_OPERATION_HPP
#define ELEMENT_WISE_OPERATION_HPP

namespace ck {
namespace tensor_operation {
namespace element_wise {

struct PassThrough
{
    template <typename T>
    __host__ __device__ constexpr T operator()(T v) const
    {
        return v;
    }
};

struct AddRelu
{
    template <typename T1>
    __host__ constexpr float operator()(float v0, T1 v1) const
    {
        float b = v0 + v1;
        float c = b > 0 ? b : 0;
        return c;
    }

    template <typename T1>
    __device__ constexpr float operator()(float v0, T1 v1) const
    {
#if 0
        float a = v1 + v0;
        float b = max(a, float(0));
        return b;
#else
        float b = v1 + v0;
        float c = b > 0 ? b : 0;
        return c;
#endif
    }
};

struct AddReluAdd
{
    template <typename T1, typename T2>
    __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
        float b = v0 + v1;
        float c = b > 0 ? b : 0;
        float d = c + v2;
        return d;
    }

    template <typename T1, typename T2>
    __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
#if 0
        float a = v1 + v0;
        float b = max(a, float(0));
        float c = b + v2;
        return c;
#else
        float b = v1 + v2;
        float c = (v0 > -v1) ? b + v0 : v2;
        return c;
#endif
    }
};

struct AddLeakyReluAdd
{
    template <typename T1, typename T2>
    __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
        float a = v0 + v1;
        float b = 0.1 * a;
        float c = b > 0 ? b : 0;
        float d = c + v2;
        return d;
    }

    template <typename T1, typename T2>
    __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
#if 0
        // this use not too many registers, but use fp64 mul
        float a = v0 + v1;
        float b = 0.1 * a;
        float c = b > 0 ? b : 0;
        float d = c + v2;
        return d;
#elif 0
        // this spill register
        float a = v0 + v1;
        float b = float(0.1) * a;
        float c = b > 0 ? b : 0;
        float d = c + v2;
        return d;
#elif 0
        // this use lots of registers (but no spill)
        constexpr float alpha = 0.1;
        constexpr float alpha_inv = 1.0 / alpha;

        float a = v2 * alpha_inv;
        float b = v1 + v0;
        float c = b > 0 ? b : 0;
        float d = alpha * (a + c);
        return d;
#elif 1
        // this use lots of registers (but no spill), 89 Tflops
        constexpr float alpha = 0.1;
        constexpr float alpha_inv = 1.0 / alpha;

        float a = v2 * alpha_inv;
        float b = v1 + v0;
        float c = max(b, float(0));
        float d = alpha * (a + c);
        return d;
#elif 1
        // this spill registers, 89 Tflops
        float a = v0 + v1;
        float alpha = 0.1;

        float b;
        asm volatile("\n \
                v_mul_f32_e32 %0, %1, %2 \n \
                "
                     : "=v"(b)
                     : "s"(alpha), "v"(a));

        float c = b > 0 ? b : 0;
        float d = c + v2;
        return d;
#endif
    }
};
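// Note on the variant selected above (derivation added by the editor, not in
// the source): for alpha > 0, relu(alpha * x) == alpha * relu(x), so
//     relu(alpha * (v0 + v1)) + v2 == alpha * (v2 * (1 / alpha) + relu(v0 + v1)),
// which is exactly what the alpha_inv form computes; the per-element multiply
// by 0.1 is traded for one compile-time reciprocal and a final scale.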
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
#endif
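A quick host-side sanity check of the operators above (a sketch, assuming this header is on the include path and a HIP-aware compiler so the __host__/__device__ qualifiers are defined):

#include "element_wise_operation.hpp"
#include <cassert>

int main()
{
    ck::tensor_operation::element_wise::PassThrough pass{};
    ck::tensor_operation::element_wise::AddReluAdd add_relu_add{};

    assert(pass(3.0f) == 3.0f);
    assert(add_relu_add(2.0f, -5.0f, 4.0f) == 4.0f); // relu(2 - 5) + 4 = 4
    assert(add_relu_add(2.0f, 1.0f, 4.0f) == 7.0f);  // relu(2 + 1) + 4 = 7
    return 0;
}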