"composable_kernel/include/utility/Array.hpp" did not exist on "a9031464271a961b4b23d9cbf0e5d944dc8a78bf"
Commit 2185affb authored by Tejash Shah's avatar Tejash Shah
Browse files

Added fp16 support in implicit gemm

parent c15ff3c8
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_FP16_BFP16_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_FP16_BFP16_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
namespace ck {
// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccDataType,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t N1,
index_t N2,
index_t ES,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_E_N1_B_N2_ES,
class InBlockCopyClusterLengths_E_N1_B_N2_ES,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopySrcDataPerRead_B,
index_t InBlockCopyDstDataPerWrite_N2,
class WeiBlockCopySubLengths_E_K_ES,
class WeiBlockCopyClusterLengths_E_K_ES,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// this is a mess
// TODO: find a more elegant way of specifying (or calculating) performance parameters
static_assert(N2 == GemmNPerThreadSubC, "wrong!");
static_assert((N1 * N2 * BPerBlock) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
0,
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
constexpr index_t N0 = N / (N1 * N2);
constexpr index_t B = N0 * Ho * Wo;
// ES=1 for float32, =2 for bfloat16, =4 for float16
static_assert(C % ES == 0, "C needs to be multiple of vectorized C (ES)");
constexpr auto nonVectorizedC = C / ES;
constexpr index_t E = nonVectorizedC * Y * X;
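// Illustrative worked example (hypothetical values, not from this commit): with
// float16, ES = 4, so C = 64 gives nonVectorizedC = 64 / 4 = 16 and a 3x3 filter
// (Y = X = 3) gives E = 16 * 3 * 3 = 144. Likewise, with N = 128, N1 = 2, N2 = 4
// and Ho = Wo = 28, N0 = 128 / (2 * 4) = 16 and B = 16 * 28 * 28 = 12544.
// E is the reduction (gemm-K) dimension and B contributes to the gemm-N dimension
// of the implicit GEMM below.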
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
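// Illustrative example (hypothetical values, not from this commit): with
// KBlockWork = 4 and BBlockWork = 8, the packed [KBlockWork, BBlockWork]
// descriptor maps block id 13 to the multi-index (13 / 8, 13 % 8) = (1, 5),
// assuming GetMultiIndexFrom1dIndex unrolls with the last dimension fastest,
// so this block starts at k = 1 * KPerBlock and b = 5 * BPerBlock globally.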
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo, {2C/4C}]
constexpr auto in_n0_n1_n2_h_w_2cor4c_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
.Fold(I1, Number<nonVectorizedC>{})
.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 3, 5, 6>{})
.ReorderGivenNew2Old(Sequence<0, 1, 2, 4, 5, 3>{});
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilations::Get(I0)>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilations::Get(I1)>{})
.Fold(I1, Number<nonVectorizedC>{})
.Extract(Sequence<2, 3, 4>{});
// merged tensor descriptor in device memory [E, N1, B, N2, {2E/4E}], src of blockwise
// copy
constexpr auto in_e_n1_b_n2_2eor4e_global_merged_desc = make_ConstantMergedTensorDescriptor(
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_2cor4c_global_desc),
Sequence<0, 1, 2>{},
Sequence<4>{},
Sequence<3, 6, 7>{},
Sequence<5>{},
Sequence<8>{});
// memory layout descriptor in LDS [E, N1, B, N2, {2C/4C}], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_n1_b_n2_2eor4e_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, N1, BPerBlock, N2, ES>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(in_e_n1_b_n2_2eor4e_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not satisfied");
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
BlockSize,
Float,
decltype(in_e_n1_b_n2_2eor4e_global_merged_desc),
decltype(in_e_n1_b_n2_2eor4e_block_desc),
decltype(in_e_n1_b_n2_2eor4e_block_desc.GetLengths()),
InBlockCopySubLengths_E_N1_B_N2_ES,
InBlockCopyClusterLengths_E_N1_B_N2_ES,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0, 0}, {0, 0, 0, 0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_2eor4e_global_desc =
wei_k_c_y_x_global_desc.Fold(I1, Number<nonVectorizedC>{})
.Unfold(I2, I4)
.ReorderGivenNew2Old(Sequence<2, 0, 1>{});
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_2eor4e_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock, ES>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
Float,
decltype(wei_e_k_2eor4e_global_desc),
decltype(wei_e_k_2eor4e_block_desc),
decltype(wei_e_k_2eor4e_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K_ES,
WeiBlockCopyClusterLengths_E_K_ES,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global, 0}, {0, 0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS, of type float / bfloat16 vec2 / float16 vec4
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS, of type float / bfloat16 vec2 /
// float16 vec4
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<EPerBlock>{}, Number<KPerBlock>{});
constexpr auto b_e_n1bn2_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<EPerBlock>{}, Number<N1 * BPerBlock * N2>{});
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
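// Worked example (hypothetical values, for illustration only): with
// KPerBlock = 128, GemmMPerThreadSubC = 4, GemmMLevel0Cluster = 4 and
// GemmMLevel1Cluster = 4, one pass of the thread cluster covers 4 * 4 * 4 = 64
// rows of the C block, so GemmMRepeat = 128 / 64 = 2 and each thread owns two
// 4-row sub-tiles.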
// c_thread_mtx definition: this is a mess
// TODO: find a more elegant way of defining c_thread_mtx
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k2_n1n2_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space = math::integer_least_multiple(
in_e_n1_b_n2_2eor4e_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_2eor4e_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
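// Sizing note with an illustrative example (hypothetical values, not from this
// commit): if the four per-access widths above are 4, 1, 4 and 4 elements,
// max_align = lcm(4, 1, 4, 4) = 4; an element space of, say, 4097 is padded up
// to 4100 by integer_least_multiple, and each double buffer above then holds
// 2 * 4100 Float elements, so the even and odd halves both start on an aligned
// boundary.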
// register allocation for output
AccDataType p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
// hcc compilation error: loop not unrolled: the optimizer was unable to perform the
// requested transformation;
// the transformation might be disabled or specified as part of an unsupported
// transformation
// ordering [-Werror,-Wpass-failed=transform-warning]
//#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_2eor4e_global_desc.GetStride(I0);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
// even iteration
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_2eor4e_global_desc.GetStride(I0);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
// copy output: register to global memory
{
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
// define tensor descriptor for threadwise copy
// output memory layout descriptor in register
constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
make_ConstantTensorDescriptor_packed(
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
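// Layout note with an illustrative example (hypothetical values, not from this
// commit): the block's K range is viewed as [K0, K1, K2] with
// K2 = GemmMPerThreadSubC and K1 = GemmMLevel0Cluster * GemmMLevel1Cluster, so
// K0 = KPerBlock / (K1 * K2) equals GemmMRepeat above. For KPerBlock = 128,
// K2 = 4 and K1 = 16, each thread's register tile is [2, 1, 4, N1, 1, 1, 1, N2].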
// output tensor descriptor in register, src of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
// output memory layout descriptor in device memory, dst of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
.Fold(I0, Number<N1>{}, Number<N2>{});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
// output merged global tensor descriptor, for calculating origin of thread tensor
// in global memory
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
Sequence<3>{},
Sequence<1>{},
Sequence<0, 4, 5>{},
Sequence<2>{});
// origin of dst in device memory
Float* p_out_thread_on_global =
p_out_global +
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
threadwise_generic_tensor_slice_copy_v1(
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
p_out_thread,
{0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
p_out_thread_on_global,
{0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
arithmetic_sequence_gen<0, 8, 1>::type{},
Number<1>{});
}
}
};
} // namespace ck
#endif
......@@ -15,6 +15,7 @@ namespace ck {
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccDataType,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
......@@ -50,9 +51,10 @@ template <index_t GridSize,
index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
__device__ void
Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// this is a mess
// TODO: find a more elegant way of specifying (or calculating) performance parameters
......@@ -84,12 +86,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
constexpr index_t N0 = N / (N1 * N2);
......@@ -98,14 +94,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
constexpr index_t E = C * Y * X;
// sanity-check for vectorized memory load
static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1,
"wrong! global vector load of input tensor is wrong");
static_assert((X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
......@@ -125,15 +113,15 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
constexpr auto in_n0_n1_n2_h_w_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilations::Get(I0)>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilations::Get(I1)>{})
.Extract(Sequence<1, 2, 3>{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
......@@ -260,7 +248,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
AccDataType p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
......@@ -332,7 +320,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
......
......@@ -37,7 +37,7 @@ struct ConstantMergedTensorDescriptor
return OriginalTensorDesc{};
}
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
template <index_t IDim>
__host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
......@@ -52,7 +52,7 @@ struct ConstantMergedTensorDescriptor
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
__host__ __device__ static constexpr index_t GetLength(Number<IDim>)
{
constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
......@@ -60,7 +60,7 @@ struct ConstantMergedTensorDescriptor
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetStride(Number<IDim>)
__host__ __device__ static constexpr index_t GetStride(Number<IDim>)
{
static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
"wrong! stride of a merged dimension is undefined");
......@@ -75,7 +75,7 @@ struct ConstantMergedTensorDescriptor
return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
}
__host__ __device__ static constexpr auto GetElementSize()
__host__ __device__ static constexpr index_t GetElementSize()
{
return OriginalTensorDesc::GetElementSize();
}
......
......@@ -43,22 +43,22 @@ struct ConstantTensorDescriptor
return Sequence<IDim>{};
}
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
__host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
__host__ __device__ static constexpr auto GetStrides() { return Strides{}; }
template <class IDim>
__host__ __device__ static constexpr auto GetLength(IDim)
template <index_t I>
__host__ __device__ static constexpr index_t GetLength(Number<I>)
{
return Lengths::Get(IDim{});
return Lengths::Get(Number<I>{});
}
template <class IDim>
__host__ __device__ static constexpr auto GetStride(IDim)
template <index_t I>
__host__ __device__ static constexpr index_t GetStride(Number<I>)
{
return Strides::Get(IDim{});
return Strides::Get(Number<I>{});
}
struct lambda_AreDimensionsContinuous
......@@ -102,18 +102,17 @@ struct ConstantTensorDescriptor
return false;
}
__host__ __device__ static constexpr auto GetElementSize()
__host__ __device__ static constexpr index_t GetElementSize()
{
return Number<accumulate_on_sequence(
Lengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
return accumulate_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{});
}
__host__ __device__ static constexpr auto GetElementSpace()
__host__ __device__ static constexpr index_t GetElementSpace()
{
constexpr index_t element_space_unaligned = accumulate_on_sequence(
(GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
return Number<element_space_unaligned>{};
return element_space_unaligned;
}
// emulate constexpr lambda
......@@ -157,14 +156,13 @@ struct ConstantTensorDescriptor
}
template <index_t... Is>
__host__ __device__ static constexpr auto GetOffsetFromMultiIndex(Sequence<Is...>)
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
{
static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");
constexpr auto multi_id = Sequence<Is...>{};
return Number<accumulate_on_sequence(
multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{})>{};
return accumulate_on_sequence(multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{});
}
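// Illustrative example (not part of this commit): for a packed descriptor with
// Lengths = [2, 3, 4] and Strides = [12, 4, 1], GetOffsetFromMultiIndex of
// (1, 2, 3) is 1 * 12 + 2 * 4 + 3 * 1 = 23, and GetElementSpace() above returns
// (1 * 12 + 2 * 4 + 3 * 1) + 1 = 24 elements.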
// emulate constexpr lambda
......
......@@ -142,47 +142,80 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
// assertion for inline asm
static_assert(is_same<FloatA, float>{} && is_same<FloatB, float>{} &&
is_same<FloatC, float>{},
"Run_amd_asm only deal with float");
static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
MPerThread == 8 && NPerThread == 8,
"Run_amd_asm cannot deal with this GEMM shape yet");
static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_amd_asm only do float4 read");
using Float4 = vector_type<float, 4>::MemoryType;
Float4* reg_a = reinterpret_cast<Float4*>(p_a_thread);
Float4* reg_b = reinterpret_cast<Float4*>(p_b_thread);
Float4* reg_c = reinterpret_cast<Float4*>(p_c_thread);
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
reg_b[1] =
*reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
reg_a[1] =
*reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
// If A and B datatype is float
static_if<std::is_same<FloatA, float>::value &&
std::is_same<FloatB, float>::value>{}([&](auto) {
using Float4 = vector_type<float, 4>::MemoryType;
Float4* reg_a = reinterpret_cast<Float4*>(p_a_thread);
Float4* reg_b = reinterpret_cast<Float4*>(p_b_thread);
Float4* reg_c = reinterpret_cast<Float4*>(p_c_thread);
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
reg_b[1] =
*reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
reg_a[1] =
*reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
#pragma unroll
for(index_t k = 1; k < K; ++k)
{
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
for(index_t k = 1; k < K; ++k)
{
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
reg_b[1] = *reinterpret_cast<const Float4*>(
&p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
reg_a[1] = *reinterpret_cast<const Float4*>(
&p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
}
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
reg_b[1] = *reinterpret_cast<const Float4*>(
&p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
reg_a[1] = *reinterpret_cast<const Float4*>(
&p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
}).Else([&](auto) { // A and B are float16 (half); bfloat16 is dispatched to Run_source in Run()
using Half4x4 = vector_type<vector_type<half, 4>, 4>;
using Float4 = vector_type<float, 4>::MemoryType;
Half4x4* reg_a = reinterpret_cast<Half4x4*>(p_a_thread);
Half4x4* reg_b = reinterpret_cast<Half4x4*>(p_b_thread);
Float4* reg_c = reinterpret_cast<Float4*>(p_c_thread);
reg_a[0] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA]);
reg_b[0] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB]);
reg_b[1] =
*reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
reg_a[1] =
*reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
}
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
#pragma unroll
for(index_t k = 1; k < K; ++k)
{
reg_a[0] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA + k * M]);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
reg_b[0] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB + k * N]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
reg_b[1] = *reinterpret_cast<const Half4x4*>(
&p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
reg_a[1] = *reinterpret_cast<const Half4x4*>(
&p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
}
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
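// Note on the fp16 math above (descriptive comment added for clarity, not part
// of the original commit): each LDS "element" of A and B here is a
// vector_type<half, 4> (the ES = 4 packing done by the gridwise kernel), while
// the C accumulators stay float4. One outerProduct4x4 call updates a 4x4 float
// tile as c[m][n] += sum over v of a[m].h[v] * b[n].h[v], and the four calls
// per k step cover the thread's full 8x8 tile.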
});
}
#endif
......@@ -204,11 +237,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto a_thread_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThread>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
constexpr auto b_thread_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThread>{}, Number<NPerThread>{});
// thread A-sub, B-sub for copy
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
......@@ -415,7 +448,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
#if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
Run_amd_asm(p_a_block, p_b_block, p_c_thread);
static_if<std::is_same<FloatA, ushort>::value && std::is_same<FloatB, ushort>::value>{}(
[&](auto) { Run_source(p_a_block, p_b_block, p_c_thread); })
.Else([&](auto) { // A and B are float or float16: use the inline-asm path
Run_amd_asm(p_a_block, p_b_block, p_c_thread);
});
#else
Run_source(p_a_block, p_b_block, p_c_thread);
#endif
......
......@@ -10,13 +10,15 @@
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
#endif
#define JOINTCAT(x, y) x##y
#define ASSERT_MSG_ARG1(msg, var1) JOINTCAT(msg, var1)
#define ASSERT_MSG_ARG2(msg, var1, var2) ASSERT_MSG_ARG1(JOINTCAT(msg, var1), var2)
namespace ck {
// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
// memory layout (ordering of dimensions) can be different between src and dst.
// on a merged dimension that contains multiple original dimensions,
// its sub-length needs to evenly divide the length of the last original dimension
// so each thread is effectively reading a normal (not merged) tensor
// memory layout (ordering of dimensions) can be different between src and dst
// For now, only support SubLengths[...] == 1 on a merged dimension
template <index_t BlockSize,
class Float,
class SrcDesc,
......@@ -77,15 +79,18 @@ struct BlockwiseGenericTensorSliceCopy_v1
// thread cluster
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
DataClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
DataClusterLengths{}.ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
// BlockSize
static_assert(BlockSize == thread_cluster_desc.GetElementSize(), "wrong! BlockSize");
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
"wrong! block size doesn't match with thread cluster size.");
// divide work
constexpr auto data_per_cluster_per_dims = SubLengths{} * DataClusterLengths{};
static_for<0, nDim, 1>{}([&](auto IDim) {
static_for<0, nDim, 1>{}([&](auto IDim_) {
constexpr auto IDim = decltype(IDim_){};
static_assert(SliceLengths::Get(IDim) % SubLengths::Get(IDim) == 0,
"wrong! cannot evenly divide sliced tensor into sub-tensor");
......@@ -93,23 +98,15 @@ struct BlockwiseGenericTensorSliceCopy_v1
"wrong! cannot evenly divide sliced tensor into cluster");
});
// on a merged dimension that contains multiple original dimensions,
// its sub-length needs to evenly divide the length of the last original dimension,
// so each thread is effectively reading a normal (not merged) tensor
static_for<0, nDim, 1>{}([&](auto IDim) {
constexpr auto sub_length = SubLengths::Get(IDim);
constexpr auto idim_original_src = SrcDesc::GetContainedOriginalDimensions(IDim).Back();
static_assert(SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_original_src) %
sub_length ==
0,
"wrong!");
constexpr auto idim_original_dst = DstDesc::GetContainedOriginalDimensions(IDim).Back();
static_assert(DstDesc::GetOriginalTensorDescriptor().GetLength(idim_original_dst) %
sub_length ==
0,
"wrong!");
// for now, only support SubLengths == 1 on a merged dimension that contains
// multiple original dimensions
static_for<0, nDim, 1>{}([&](auto IDim_) {
constexpr auto IDim = decltype(IDim_){};
static_assert(SubLengths::Get(IDim) == 1 ||
(!SrcDesc::ContainMultipleOriginalDimensions(IDim) &&
!DstDesc::ContainMultipleOriginalDimensions(IDim)),
"wrong! only support Sub-Length == 1 on a merged dimension");
});
// calculate mThreadSrcOffset, mThreadDstOffset
......@@ -129,25 +126,31 @@ struct BlockwiseGenericTensorSliceCopy_v1
dst_block_data_multi_id_begin + thread_data_multi_id_begin);
// partial offset on each dimension
static_for<0, nDim, 1>{}([&](auto IDim) {
static_for<0, nDim, 1>{}([&](auto IDim_) {
constexpr auto IDim = decltype(IDim_){};
constexpr index_t idim = IDim;
constexpr auto src_partial_original_dims =
SrcDesc::GetContainedOriginalDimensions(IDim);
constexpr auto src_partial_original_desc =
SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
mThreadSrcPartialOffsets(IDim) = src_partial_original_desc.GetOffsetFromMultiIndex(
mThreadSrcPartialOffsets(idim) = src_partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
});
static_for<0, nDim, 1>{}([&](auto IDim) {
static_for<0, nDim, 1>{}([&](auto IDim_) {
constexpr auto IDim = decltype(IDim_){};
constexpr index_t idim = IDim;
constexpr auto dst_partial_original_dims =
DstDesc::GetContainedOriginalDimensions(IDim);
constexpr auto dst_partial_original_desc =
DstDesc::GetOriginalTensorDescriptor().Extract(dst_partial_original_dims);
mThreadDstPartialOffsets(IDim) = dst_partial_original_desc.GetOffsetFromMultiIndex(
mThreadDstPartialOffsets(idim) = dst_partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mThreadDstOriginalMultiId, dst_partial_original_dims));
});
......@@ -181,8 +184,10 @@ struct BlockwiseGenericTensorSliceCopy_v1
constexpr auto thread_tensor_desc =
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
constexpr auto src_thread_data_multi_id_begin =
repeat_multi_id * data_per_cluster_per_dims;
......@@ -195,25 +200,19 @@ struct BlockwiseGenericTensorSliceCopy_v1
constexpr index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
#else
ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
const auto src_thread_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
const index_t src_offset =
SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
SrcDesc{}.GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
const index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
#endif
// By positioning the origin of the per-thread window at the point where the
// multi-index of the SrcDesc (which might be a merged tensor) is all-zero, this
// threadwise slice copy assumes each thread is copying a normal (not merged) tensor.
// The user needs to guarantee this is true.
// By setting SubLengths = 1 at the merged dimension, this is always true;
// If in the future, you want to enable SubLengths > 1 at the merged dimension,
// special care in implementation is needed
threadwise_generic_tensor_slice_copy_v1(SrcDesc{},
p_src + src_offset + mThreadSrcOffset,
make_zero_array<index_t, nDim>(),
......@@ -238,8 +237,10 @@ struct BlockwiseGenericTensorSliceCopy_v1
constexpr auto thread_tensor_desc =
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
constexpr auto clipboard_data_multi_id_begin =
repeat_multi_id * thread_sub_tensor_lengths;
......@@ -249,9 +250,10 @@ struct BlockwiseGenericTensorSliceCopy_v1
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
constexpr index_t dst_offset =
DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
#else
ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
const auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
......@@ -259,16 +261,9 @@ struct BlockwiseGenericTensorSliceCopy_v1
const index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
#endif
// By positioning the origin of the per-thread window at the point where the
// multi-index of the SrcDesc (which might be a merged tensor) is all-zero, this
// threadwise slice copy assumes each thread is copying a normal (not merged) tensor.
// The user needs to guarantee this is true.
// By setting SubLengths = 1 at the merged dimension, this is always true;
// If in the future, you want to enable SubLengths > 1 at the merged dimension,
// special care in implementation is needed
threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc,
p_clipboard + clipboard_offset,
make_zero_array<index_t, nDim>(),
......@@ -302,7 +297,8 @@ struct BlockwiseGenericTensorSliceCopy_v1
__device__ void MoveSlicingWindowOnSourceTensor(
Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
{
constexpr auto IDim = Number<IDim_>{};
constexpr auto IDim = Number<IDim_>{};
constexpr index_t idim = IDim;
static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto) {
// logic for a merged dimension, also works for non-merged dimension, but its logic may
......@@ -325,21 +321,22 @@ struct BlockwiseGenericTensorSliceCopy_v1
old_src_partial_original_multi_id, StepSize, direction);
// update "mThreadSrcOriginalMultiId"
static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I) {
constexpr auto IDimOriginal = src_partial_original_dims[I];
static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I_) {
constexpr auto I = decltype(I_){};
constexpr index_t idim_original = src_partial_original_dims.Get(I);
mThreadSrcOriginalMultiId(IDimOriginal) = new_src_partial_original_multi_id[I];
mThreadSrcOriginalMultiId(idim_original) = new_src_partial_original_multi_id[I];
});
// calculate new partial offset on this merged dimension
const index_t old_src_partial_offset = mThreadSrcPartialOffsets[IDim];
const index_t old_src_partial_offset = mThreadSrcPartialOffsets[idim];
const index_t new_src_partial_offset =
src_partial_original_desc.GetOffsetFromMultiIndex(
new_src_partial_original_multi_id);
// update "mThreadSrcPartialOffsets"
mThreadSrcPartialOffsets(IDim) = new_src_partial_offset;
mThreadSrcPartialOffsets(idim) = new_src_partial_offset;
// update "mThreadSrcOffset", do "+" before "-" to avoid underflow
mThreadSrcOffset = (mThreadSrcOffset + new_src_partial_offset) - old_src_partial_offset;
......@@ -354,20 +351,20 @@ struct BlockwiseGenericTensorSliceCopy_v1
// of the boundary of the tensor being sliced. Otherwise, there might be hazards like
// unsigned integer underflow. There is NO runtime sanity check to prevent the hazard
constexpr auto IDimOriginal = SrcDesc::GetContainedOriginalDimensions(IDim).Front();
constexpr index_t idim_original = SrcDesc::GetContainedOriginalDimensions(IDim).Front();
static_if<PositiveDirection>{}([&](auto fwd) {
mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
mThreadSrcOriginalMultiId(IDimOriginal) += StepSize;
mThreadSrcOriginalMultiId(idim_original) += StepSize;
mThreadSrcPartialOffsets(IDim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
mThreadSrcPartialOffsets(idim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
}).Else([&](auto fwd) {
mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
mThreadSrcOriginalMultiId(IDimOriginal) -= StepSize;
mThreadSrcOriginalMultiId(idim_original) -= StepSize;
mThreadSrcPartialOffsets(IDim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
mThreadSrcPartialOffsets(idim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
});
});
}
......
......@@ -3,6 +3,7 @@
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "float_types.h"
namespace ck {
......@@ -34,22 +35,60 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
{
static_assert(NCol % DataPerRead == 0, "wrong! should be NCol % DataPerRead == 0");
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
constexpr auto src_mtx = SrcMatrix{};
constexpr auto dst_mtx = DstMatrix{};
for(index_t i = 0; i < NRow; ++i)
{
for(index_t j = 0; j < NCol; j += DataPerRead)
// Depending on the datatype (float/half/bfloat16), carry out the data movement
// in the appropriate vectorized form:
// float - 4, half - 4, bfloat16 - 2
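// Added note (illustrative, not from the original commit): src_index/dst_index
// below still count whole matrix elements, but for half each element is a
// 4-wide vector and for bfloat16 a 2-wide vector, hence the index is scaled by
// 4 or 2 before the vector load/store; e.g. element (i, j) = (0, 1) of a half
// matrix occupies the underlying halves [4, 7].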
static_if<std::is_same<Float, float>::value>{}([&](auto) {
using vector_t = typename vector_type<float, DataPerRead>::MemoryType;
for(index_t i = 0; i < NRow; ++i)
{
const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
for(index_t j = 0; j < NCol; j += DataPerRead)
{
const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
}
}
}
}).Else([&](auto) {
static_if<std::is_same<Float, half>::value>{}([&](auto) {
// float16 path: each matrix element is a 4-wide half vector
using vector_t = typename vector_type<Float, 4>::MemoryType;
for(index_t i = 0; i < NRow; ++i)
{
for(index_t j = 0; j < NCol; ++j)
{
const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
*reinterpret_cast<vector_t*>(&p_dst[dst_index * 4]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index * 4]);
}
}
}).Else([&](auto) {
using vector_t = typename vector_type<Float, 2>::MemoryType;
for(index_t i = 0; i < NRow; ++i)
{
for(index_t j = 0; j < NCol; ++j)
{
const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
*reinterpret_cast<vector_t*>(&p_dst[dst_index * 2]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index * 2]);
}
}
});
});
}
template <class MatrixA,
......@@ -90,7 +129,32 @@ __device__ void threadwise_gemm(MatrixA,
const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);
p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex];
static_if<std::is_same<FloatA, float>::value>{}([&](auto) {
p_c_thread[cindex] += CVT_FLOAT2ACCUM(p_a_thread[aindex]) *
CVT_FLOAT2ACCUM(p_b_thread[bindex]);
}).Else([&](auto) {
static_if<std::is_same<FloatA, half>::value>{}([&](auto) {
// float16 path: each A/B element packs 4 halves, dotted into a float accumulator
float acc = 0.0;
for(index_t v = 0; v < 4; ++v)
{
acc += CVT_FLOAT2ACCUM(p_a_thread[aindex * 4 + v]) *
CVT_FLOAT2ACCUM(p_b_thread[bindex * 4 + v]);
}
p_c_thread[cindex] += acc;
}).Else([&](auto) {
// bfloat16 path: each A/B element packs 2 ushorts, dotted into a float accumulator
float acc = 0.0;
for(index_t v = 0; v < 2; ++v)
{
acc += CVT_FLOAT2ACCUM(p_a_thread[aindex * 2 + v]) *
CVT_FLOAT2ACCUM(p_b_thread[bindex * 2 + v]);
}
p_c_thread[cindex] += acc;
});
});
}
}
}
......
......@@ -4,6 +4,7 @@
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "float_types.h"
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
......@@ -12,7 +13,8 @@
namespace ck {
// user needs to make sure the alignment requirement is satisfied when setting DataPerAccess > 1
template <class Float,
template <class SrcFloat,
class DesFloat,
class SrcDesc,
class DstDesc,
class SliceLengths,
......@@ -20,10 +22,10 @@ template <class Float,
index_t DataPerAccess>
__device__ void threadwise_generic_tensor_slice_copy_v1(
SrcDesc,
const Float* __restrict__ p_src,
const SrcFloat* __restrict__ p_src,
Array<index_t, SrcDesc::GetNumOfDimension()> src_multi_id_begin,
DstDesc,
Float* __restrict__ p_dst,
DesFloat* __restrict__ p_dst,
Array<index_t, DstDesc::GetNumOfDimension()> dst_multi_id_begin,
SliceLengths,
DimAccessOrder,
......@@ -64,7 +66,8 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
constexpr auto access_lengths = slice_lengths_in_access_order.Modify(
Number<nDim - 1>{}, Number<num_access_on_lowest_access_dimension>{});
using vector_t = typename vector_type<Float, DataPerAccess>::MemoryType;
using vector_src_t = typename vector_type<SrcFloat, DataPerAccess>::MemoryType;
using vector_dest_t = typename vector_type<DesFloat, DataPerAccess>::MemoryType;
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1
static_ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
......@@ -82,8 +85,15 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
const index_t dst_index =
DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
static_if<std::is_same<vector_src_t, vector_dest_t>::value>{}([&](auto) {
*reinterpret_cast<vector_dest_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_src_t*>(&p_src[src_index]);
}).Else([&](auto) {
for(unsigned int data_idx = 0; data_idx < DataPerAccess; ++data_idx)
{
p_dst[dst_index + data_idx] = CVT_ACCUM2FLOAT(p_src[src_index + data_idx]);
}
});
});
#else
ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
......@@ -99,8 +109,16 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
const index_t dst_index =
DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
static_if<std::is_same<vector_src_t, vector_dest_t>::value>{}([&](auto) {
*reinterpret_cast<vector_dest_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_src_t*>(&p_src[src_index]);
//printf("%f ", static_cast<float>(p_dst[dst_index]));
}).Else([&](auto) {
for(unsigned int data_idx = 0; data_idx < DataPerAccess; ++data_idx)
{
p_dst[dst_index + data_idx] = CVT_ACCUM2FLOAT(p_src[src_index + data_idx]);
}
});
});
#endif
}
......
......@@ -16,32 +16,31 @@ struct Sequence
static constexpr index_t mSize = sizeof...(Is);
__host__ __device__ static constexpr auto GetSize() { return Number<mSize>{}; }
__host__ __device__ static constexpr index_t GetSize() { return mSize; }
__host__ __device__ static constexpr index_t GetImpl(index_t I)
template <index_t I>
__host__ __device__ static constexpr index_t Get(Number<I>)
{
static_assert(I < mSize, "wrong! I too large");
// the last dummy element is to prevent the compiler from complaining about an empty array when mSize == 0
const index_t mData[mSize + 1] = {Is..., 0};
return mData[I];
}
template <index_t I>
__host__ __device__ static constexpr auto Get(Number<I>)
__host__ __device__ constexpr auto operator[](Number<I>) const
{
static_assert(I < mSize, "wrong! I too large");
return Number<GetImpl(Number<I>{})>{};
return Number<Get(Number<I>{})>{};
}
template <index_t I>
__host__ __device__ constexpr auto operator[](Number<I>) const
// make sure I is constexpr
__host__ __device__ constexpr index_t operator[](index_t I) const
{
return Get(Number<I>{});
const index_t mData[mSize + 1] = {Is..., 0};
return mData[I];
}
// make sure I is constexpr if you want a constexpr return type
__host__ __device__ constexpr index_t operator[](index_t I) const { return GetImpl(I); }
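// Usage note (added for illustration; not part of this commit):
//   Sequence<2, 3, 4>::Get(Number<1>{})  -> 3 (now a plain index_t)
//   Sequence<2, 3, 4>{}[Number<2>{}]     -> Number<4>{} (still a type-level value)
//   Sequence<2, 3, 4>{}[i]               -> index_t; constexpr only when i is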
template <index_t... IRs>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
{
......@@ -55,16 +54,16 @@ struct Sequence
__host__ __device__ static constexpr auto Reverse();
__host__ __device__ static constexpr auto Front()
__host__ __device__ static constexpr index_t Front()
{
static_assert(mSize > 0, "wrong!");
return Get(Number<0>{});
const index_t mData[mSize + 1] = {Is..., 0};
return mData[0];
}
__host__ __device__ static constexpr auto Back()
__host__ __device__ static constexpr index_t Back()
{
static_assert(mSize > 0, "wrong!");
return Get(Number<mSize - 1>{});
const index_t mData[mSize + 1] = {Is..., 0};
return mData[mSize - 1];
}
__host__ __device__ static constexpr auto PopFront();
......
......@@ -118,6 +118,58 @@ __device__ void outerProduct4x4(const vector_type<float, 4>::MemoryType& a,
outerProduct1x4(a.w, b, c3);
}
__device__ void outerProduct1x4(const half2* a, const half2* b, float* c)
{
asm volatile("\n \
v_dot2_f32_f16 %0, %4, %6, %0\n \
v_dot2_f32_f16 %1, %4, %8, %1\n \
v_dot2_f32_f16 %2, %4, %10, %2\n \
v_dot2_f32_f16 %3, %4, %12, %3\n \
v_dot2_f32_f16 %0, %5, %7, %0\n \
v_dot2_f32_f16 %1, %5, %9, %1\n \
v_dot2_f32_f16 %2, %5, %11, %2\n \
v_dot2_f32_f16 %3, %5, %13, %3\n \
"
: "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3]) // Dest registers
: "v"(a[0]),
"v"(a[1]), // 1st Src registers for 2 half2 registers
"v"(b[0]),
"v"(b[1]),
"v"(b[2]),
"v"(b[3]), // 2nd Src registers for 2 half2 registers
"v"(b[4]),
"v"(b[5]),
"v"(b[6]),
"v"(b[7]), // 2nd Src registers for 2 half2 registers
"0"(c[0]),
"1"(c[1]),
"2"(c[2]),
"3"(c[3])); // 3rd Src Acc registers for 2 half2 registers
}
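// Descriptive note (added for clarity; instruction semantics per the GFX9 ISA,
// not stated in the original commit): v_dot2_f32_f16 d, s0, s1, s2 computes
// d = s0.h[0] * s1.h[0] + s0.h[1] * s1.h[1] + s2 in fp32. Here a[0]/a[1] hold
// one half4 of A as two half2 packs and b[0..7] hold four half4 columns of B,
// so each c[n] accumulates the 4-wide fp16 dot product of a with the n-th
// column of B.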
__device__ void outerProduct1x4Half(const vector_type<half, 4>& a,
const vector_type<vector_type<half, 4>, 4>& b,
vector_type<float, 4>::MemoryType& c)
{
outerProduct1x4(reinterpret_cast<const half2*>(&a),
reinterpret_cast<const half2*>(&b),
reinterpret_cast<float*>(&c));
}
__device__ void outerProduct4x4(const vector_type<vector_type<half, 4>, 4>& a,
const vector_type<vector_type<half, 4>, 4>& b,
vector_type<float, 4>::MemoryType& c0,
vector_type<float, 4>::MemoryType& c1,
vector_type<float, 4>::MemoryType& c2,
vector_type<float, 4>::MemoryType& c3)
{
const vector_type<half, 4>* reg_a = reinterpret_cast<const vector_type<half, 4>*>(&a);
outerProduct1x4Half(reg_a[0], b, c0);
outerProduct1x4Half(reg_a[1], b, c1);
outerProduct1x4Half(reg_a[2], b, c2);
outerProduct1x4Half(reg_a[3], b, c3);
}
__device__ void outerProduct8x8(const vector_type<float, 4>::MemoryType* a,
const vector_type<float, 4>::MemoryType* b,
vector_type<float, 4>::MemoryType* c)
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2019 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef BFLOAT16_DEVICE_HPP
#define BFLOAT16_DEVICE_HPP
#ifdef __cplusplus
extern "C" {
#endif
#ifdef __HIP_PLATFORM_HCC__
#define EXECUTION_SPECIFIER __device__
#else
#define EXECUTION_SPECIFIER
#endif // MIOPEN_BACKEND_HIP
typedef union
{
uint u32;
ushort2 ushortx2;
// Composable kernels are written in HIP. The language doesn't support
// ushort2.hi or ushort2.lo.
#ifdef __HIP_PLATFORM_HCC__
ushort ushortvec[2];
#endif // MIOPEN_BACKEND_HIP
float f32;
} cvt_bf16_fp32_t;
EXECUTION_SPECIFIER float bfloat16_to_float(ushort src_val)
{
cvt_bf16_fp32_t target_val;
#ifdef __HIP_PLATFORM_HCC__
target_val.ushortx2 = make_ushort2(0, src_val);
#else
target_val.ushortx2 = (ushort2)(0, src_val);
#endif
return target_val.f32;
}
EXECUTION_SPECIFIER ushort float_to_bfloat16(float src_val)
{
cvt_bf16_fp32_t target_val;
target_val.f32 = src_val;
// BF16 round and NaN preservation code matches
// https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/include/rocblas_bfloat16.h
if((~target_val.u32 & 0x7f800000) == 0) // Inf or NaN
{
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bloat16's mantissa bits are all 0.
if((target_val.u32 & 0xffff) != 0)
{
target_val.u32 |= 0x10000; // Preserve signaling NaN
}
}
else
{
#ifdef MIOPEN_USE_RNE_BFLOAT16
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
#ifdef __HIP_PLATFORM_HCC__
target_val.u32 += (0x7fff + (target_val.ushortvec[0] & 1));
#else
target_val.u32 +=
(0x7fff + (target_val.ushortx2.hi & 1)); // Round to nearest, round to even
#endif // MIOPEN_BACKEND_HIP
#endif // MIOPEN_USE_RNE_BFLOAT16
}
#ifdef __HIP_PLATFORM_HCC__
return target_val.ushortvec[0];
#else
return target_val.ushortx2.hi;
#endif // MIOPEN_BACKEND_HIP
}
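// Worked example (illustrative, not part of this commit): bfloat16 keeps the
// upper 16 bits of a binary32 value, so 1.0f (0x3F800000) converts to 0x3F80
// and back to exactly 1.0f, while 3.14159f (0x40490FDB) becomes 0x4049, i.e.
// roughly 3.1406. With MIOPEN_USE_RNE_BFLOAT16 defined, the low 16 bits are
// rounded to nearest-even instead of being truncated.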
#ifdef __cplusplus
}
#endif
#endif // BFLOAT16_DEVICE_HPP
#ifndef CK_COMMON_HEADER_HPP
#define CK_COMMON_HEADER_HPP
#define MIOPEN_USE_FP16 1
#define MIOPEN_USE_BFP16 0
#define MIOPEN_USE_FP32 0
#define __HIP_PLATFORM_HCC__ 1
#include "config.hpp"
#include "utility.hpp"
#include "integral_constant.hpp"
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2019 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef FLOAT_TYPES_HPP
#define FLOAT_TYPES_HPP
#include "bfloat16_dev.hpp"
#define PPCAT_NX(A, B) A##B
#define PPCAT(A, B) PPCAT_NX(A, B)
#define TWO 2
#define FOUR 4
#define EIGHT 8
#if MIOPEN_USE_FP16 == 1
#ifdef __HIP_PLATFORM_HCC__
#define FLOAT half
#define FLOAT_ACCUM float
#else
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define _FLOAT half
#define _FLOAT_ACCUM float
#endif // __HIP_PLATFORM_HCC__
#define SIZEOF_FLOAT 2 /* sizeof is unavailable for preprocessor */
#ifndef HALF_MAX
#define MAX_VAL 65504 /* max value */
#else
#define MAX_VAL HALF_MAX
#endif // HALF_MAX
#endif // MIOPEN_USE_FP16
#if MIOPEN_USE_FP32 == 1
#ifdef __HIP_PLATFORM_HCC__
#define FLOAT float
#define FLOAT_ACCUM float
#else
#define _FLOAT float
#define _FLOAT_ACCUM float
#endif // __HIP_PLATFORM_HCC__
#define SIZEOF_FLOAT 4 /* sizeof is unavailable for preprocessor */
#ifndef FLT_MAX
#define MAX_VAL 3.402823466e+38F /* max value */
#else
#define MAX_VAL FLT_MAX
#endif // FLT_MAX
#endif // MIOPEN_USE_FP32
#if MIOPEN_USE_BFP16 == 1
#ifdef __HIP_PLATFORM_HCC__
#define FLOAT ushort
#define FLOAT_ACCUM float
#else
#define _FLOAT ushort
#define _FLOAT_ACCUM float
#endif // __HIP_PLATFORM_HCC__
#define SIZEOF_FLOAT 2 /* sizeof is unavailable for preprocessor */
#define MAX_VAL 0x7F7F /* max value */
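// (0x7F7F is the largest finite bfloat16 bit pattern: sign 0, exponent 0xFE, mantissa 0x7F, ~3.39e38)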
#endif // MIOPEN_USE_BFP16
#if MIOPEN_USE_FP16 == 1
#ifdef __HIP_PLATFORM_HCC__
#define CVT_FLOAT2ACCUM(x) (static_cast<FLOAT_ACCUM>(x))
#define CVT_ACCUM2FLOAT(x) (static_cast<FLOAT>(x))
#else
#define CVT_FLOAT2ACCUM(x) ((_FLOAT_ACCUM)(x))
#define CVT_ACCUM2FLOAT(x) ((_FLOAT)(x))
#endif // __HIP_PLATFORM_HCC__
#endif // MIOPEN_USE_FP16
#if MIOPEN_USE_FP32 == 1
#ifdef __HIP_PLATFORM_HCC__
#define CVT_FLOAT2ACCUM(x) (static_cast<FLOAT_ACCUM>(x))
#define CVT_ACCUM2FLOAT(x) (static_cast<FLOAT>(x))
#else
#define CVT_FLOAT2ACCUM(x) ((_FLOAT_ACCUM)(x))
#define CVT_ACCUM2FLOAT(x) ((_FLOAT)(x))
#endif
#endif // MIOPEN_USE_FP32
#if MIOPEN_USE_BFP16 == 1
#define CVT_FLOAT2ACCUM(x) bfloat16_to_float(x)
#define CVT_ACCUM2FLOAT(x) float_to_bfloat16(x)
#endif
#ifndef __HIP_PLATFORM_HCC__
#define _FLOAT2 PPCAT(_FLOAT, TWO)
#endif
#endif // FLOAT_TYPES_HPP
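// Illustrative sketch, not part of float_types.hpp: with the configuration used in this commit
// (MIOPEN_USE_FP16 == 1 and __HIP_PLATFORM_HCC__ defined), FLOAT expands to half and FLOAT_ACCUM
// to float, so the CVT_* macros are plain static_casts; on the OpenCL path, PPCAT(_FLOAT, TWO)
// token-pastes to half2 after macro expansion. The function below is a hypothetical example of
// the "accumulate in fp32, store in fp16" pattern these macros enable.
__device__ FLOAT_ACCUM sum_fp16_sketch(const FLOAT* p, int n)
{
    FLOAT_ACCUM acc = 0.0f; // accumulate in float to limit fp16 rounding error
    for(int i = 0; i < n; ++i)
    {
        acc += CVT_FLOAT2ACCUM(p[i]); // half -> float per element
    }
    return acc; // a caller storing the result as half would apply CVT_ACCUM2FLOAT(acc)
}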
......@@ -13,64 +13,30 @@ struct integral_constant
__host__ __device__ constexpr value_type operator()() const noexcept { return value; }
};
template <class T, T X, T Y>
__host__ __device__ constexpr auto operator+(integral_constant<T, X>, integral_constant<T, Y>)
{
    return integral_constant<T, X + Y>{};
}
template <class T, T X, T Y>
__host__ __device__ constexpr auto operator*(integral_constant<T, X>, integral_constant<T, Y>)
{
    return integral_constant<T, X * Y>{};
}
template <index_t N>
using Number = integral_constant<index_t, N>;
template <class X, class Y>
struct is_same : public integral_constant<bool, false>
{
};
template <class X>
struct is_same<X, X> : public integral_constant<bool, true>
{
};
} // namespace ck
#endif
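// Illustrative sketch, not part of integral_constant.hpp: the operators above make Number
// arithmetic a purely compile-time affair, for example:
static_assert(decltype(ck::Number<3>{} + ck::Number<4>{})::value == 7,
              "Number addition happens at compile time");
static_assert(decltype(ck::Number<3>{} * ck::Number<4>{})::value == 12,
              "Number multiplication happens at compile time");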
......@@ -42,16 +42,20 @@ struct integer_divide_ceiler
}
};
template <class X, class Y>
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
template <class T>
__host__ __device__ constexpr T integer_divide_ceil(T a, T b)
{
return (x + y - 1) / y;
static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
return (a + b - 1) / b;
}
template <class X, class Y>
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
template <class T>
__host__ __device__ constexpr T integer_least_multiple(T a, T b)
{
return y * integer_divide_ceil(x, y);
static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
return b * integer_divide_ceil(a, b);
}
template <class T>
......
#ifndef CK_VECTOR_TYPE_HPP
#define CK_VECTOR_TYPE_HPP
#include "cuda_fp16.h"
#include "config.hpp"
#include "integral_constant.hpp"
......@@ -9,12 +10,15 @@ namespace ck {
template <class T, index_t N>
struct vector_type
{
T vector[N];
};
template <>
struct vector_type<float, 1>
{
using MemoryType = float;
__host__ __device__ static constexpr index_t GetSize() { return 1; }
template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
......@@ -29,6 +33,8 @@ struct vector_type<float, 2>
{
using MemoryType = float2_t;
__host__ __device__ static constexpr index_t GetSize() { return 2; }
union Data
{
MemoryType vector;
......@@ -42,13 +48,6 @@ struct vector_type<float, 2>
*(reinterpret_cast<float*>(&v) + I) = s;
}
__host__ __device__ static MemoryType Pack(float s0, float s1)
{
Data data;
data.scalar[0] = s0;
data.scalar[1] = s1;
return data.vector;
}
};
template <>
......@@ -56,6 +55,8 @@ struct vector_type<float, 4>
{
using MemoryType = float4_t;
__host__ __device__ static constexpr index_t GetSize() { return 4; }
template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
{
......@@ -64,6 +65,116 @@ struct vector_type<float, 4>
}
};
template <>
struct vector_type<half, 1>
{
using MemoryType = half;
__host__ __device__ static constexpr index_t GetSize() { return 1; }
template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
{
static_assert(I < 1, "wrong");
*(reinterpret_cast<half*>(&v) + I) = s;
}
};
template <>
struct vector_type<half, 2>
{
using MemoryType = half2;
union Data
{
MemoryType vector;
half scalar[2];
};
__host__ __device__ static constexpr index_t GetSize() { return 2; }
template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
{
static_assert(I < 2, "wrong");
*(reinterpret_cast<half*>(&v) + I) = s;
}
};
template <>
struct vector_type<half, 4>
{
typedef struct MemoryType
{
half2 vector[2];
} MemoryType;
__host__ __device__ static constexpr index_t GetSize() { return 4; }
template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
{
static_assert(I < 4, "wrong");
*(reinterpret_cast<half*>(&v) + I) = s;
}
};
template <>
struct vector_type<ushort, 1>
{
using MemoryType = ushort;
__host__ __device__ static constexpr index_t GetSize() { return 1; }
template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
{
static_assert(I < 1, "wrong");
*(reinterpret_cast<ushort*>(&v) + I) = s;
}
};
template <>
struct vector_type<ushort, 2>
{
using MemoryType = ushort2;
union Data
{
MemoryType vector;
ushort scalar[2];
};
__host__ __device__ static constexpr index_t GetSize() { return 2; }
template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
{
static_assert(I < 2, "wrong");
*(reinterpret_cast<ushort*>(&v) + I) = s;
}
};
template <>
struct vector_type<ushort, 4>
{
typedef struct MemoryType
{
ushort2 vector[2];
} MemoryType;
__host__ __device__ static constexpr index_t GetSize() { return 4; }
template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
{
static_assert(I < 4, "wrong");
*(reinterpret_cast<ushort*>(&v) + I) = s;
}
};
} // namespace ck
#endif
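// Illustrative sketch, not part of vector_type.hpp: filling a vector_type<half, 4> value
// element by element with SetScalar. The function name is hypothetical; it only shows how the
// Number<I> index selects the scalar slot inside the packed MemoryType.
__device__ void splat_half4_sketch(ck::vector_type<half, 4>::MemoryType& v, half s)
{
    using vec4 = ck::vector_type<half, 4>;
    vec4::SetScalar(v, s, ck::Number<0>{});
    vec4::SetScalar(v, s, ck::Number<1>{});
    vec4::SetScalar(v, s, ck::Number<2>{});
    vec4::SetScalar(v, s, ck::Number<3>{});
}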
#pragma once
#include <unistd.h>
#define MIOPEN_USE_FP16 1
#define MIOPEN_USE_BFP16 0
#define MIOPEN_USE_FP32 0
#define __HIP_PLATFORM_HCC__ 1
#include "float_types.h"
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp"
#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp"
#include "gridwise_convolution_implicit_gemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp"
#define CK_PARAM_TUNABLE_K_PER_BLOCK 64
using namespace ck;
......@@ -24,6 +35,10 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
ConvDilations,
index_t nrepeat)
{
// read params: problem description
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
......@@ -59,16 +74,22 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = K % 128 == 0 ? 128 : (K % 64 == 0 ? 64 : 32);
constexpr index_t BlockSize = K % 128 == 0 ? 256 : (K % 64 == 0 ? 128 : 64);
#if MIOPEN_USE_FP16 == 1
// ES is set to 4 because the dot4 operator is supported for fp16 on MI100
constexpr index_t ES = 4;
#elif MIOPEN_USE_BFP16 == 1
// ES is set to 2 because the dot2 operator is supported for bfp16 on MI100
constexpr index_t ES = 2;
#else
// do nothing
#endif
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
#if MIOPEN_USE_FP32 == 1
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>;            // [E, N1, N2, B]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]
#elif MIOPEN_USE_FP16 == 1 || MIOPEN_USE_BFP16 == 1
// ES: the E dimension is folded into two dimensions, E and ES
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2, 4>; // [E, N1, N2, B, ES]
using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2, 4>;            // [E, N1, N2, B, ES]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3, 4>;            // [E, N1, B, N2, ES]
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0, 2>; // [K, E, ES]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0, 2>;            // [K, E, ES]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1, 2>;            // [E, K, ES]
#endif
#if CK_PARAM_TUNABLE_K_PER_BLOCK == 32
constexpr index_t EPerBlock = 4;
constexpr index_t GemmMLevel0Cluster = 1;
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1;
#if MIOPEN_USE_FP32 == 1
// all_of(X_Per_Block % (X_Sub_Length * X_Cluster_Length) == 0)
// accumulate(X_Cluster_Lengths, multiply) == BlockSize
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>;
using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>;
#elif MIOPEN_USE_FP16 == 1 || MIOPEN_USE_BFP16 == 1
using InBlockCopySubLengths_E_N1_B_N2_ES = Sequence<1, 2, 1, 4, ES>;
using InBlockCopyClusterLengths_E_N1_B_N2_ES = Sequence<4, 1, 16, 1, 1>;
using WeiBlockCopySubLengths_E_K_ES = Sequence<2, 1, ES>;
using WeiBlockCopyClusterLengths_E_K_ES = Sequence<2, 32, 1>;
#endif // MIOPEN_USE_FP32 == 1
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif CK_PARAM_TUNABLE_K_PER_BLOCK == 64
constexpr index_t EPerBlock = 8;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1;
#if MIOPEN_USE_FP32 == 1
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 16, 1>;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>;
#elif MIOPEN_USE_FP16 == 1 || MIOPEN_USE_BFP16 == 1
// ES: the E dimension is folded into two dimensions, E and ES
using InBlockCopySubLengths_E_N1_B_N2_ES = Sequence<1, 2, 1, 4, ES>;
using InBlockCopyClusterLengths_E_N1_B_N2_ES = Sequence<8, 1, 16, 1, 1>;
using WeiBlockCopySubLengths_E_K_ES = Sequence<4, 1, ES>;
using WeiBlockCopyClusterLengths_E_K_ES = Sequence<2, 64, 1>;
#endif // MIOPEN_USE_FP32 == 1
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif CK_PARAM_TUNABLE_K_PER_BLOCK == 128
constexpr index_t EPerBlock = 8;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1;
#if MIOPEN_USE_FP32 == 1
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
#elif MIOPEN_USE_FP16 == 1 || MIOPEN_USE_BFP16 == 1
// ES: the E dimension is folded into two dimensions, E and ES
using InBlockCopySubLengths_E_N1_B_N2_ES = Sequence<1, 1, 1, 4, ES>;
using InBlockCopyClusterLengths_E_N1_B_N2_ES = Sequence<8, 2, 16, 1, 1>;
using WeiBlockCopySubLengths_E_K_ES = Sequence<4, 1, ES>;
using WeiBlockCopyClusterLengths_E_K_ES = Sequence<2, 128, 1>;
#endif // MIOPEN_USE_FP32 == 1
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#else
static_assert(false, "wrong! Only KPerBlock values of 32, 64 and 128 are supported");
#endif // CK_PARAM_TUNABLE_K_PER_BLOCK
constexpr index_t GridSize =
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
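// Worked example (assumptions flagged, not from the original source): for the 3x3 filter,
// 28x28 image, K = 192 problem configured in main() below, and assuming unit strides/dilations
// with no padding and the usual N1 = 2, N2 = 4 split (N1/N2 are not shown in this excerpt),
// the selection above yields KPerBlock = 64 and BlockSize = 128, and
//     Ho = Wo = 26, B = (32 * 26 * 26) / (2 * 4) = 2704,
//     GridSize = ceil(2704 / 16) * ceil(192 / 64) = 169 * 3 = 507.
// Compile-time check of that arithmetic only (not of the kernel itself):
static_assert((32 * 26 * 26) / (2 * 4) == 2704, "B for the sample problem");
static_assert(((2704 + 16 - 1) / 16) * ((192 + 64 - 1) / 64) == 507, "GridSize for the sample problem");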
......@@ -171,47 +203,86 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
#if MIOPEN_USE_FP32 == 1
GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer<
GridSize,
BlockSize,
FLOAT,
FLOAT_ACCUM,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
BPerBlock,
KPerBlock,
EPerBlock,
N1,
N2,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2,
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>{};
#elif MIOPEN_USE_FP16 == 1 || MIOPEN_USE_BFP16 == 1
GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer<
GridSize,
BlockSize,
half,
float,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
BPerBlock,
KPerBlock,
EPerBlock,
N1,
N2,
ES,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_E_N1_B_N2_ES,
InBlockCopyClusterLengths_E_N1_B_N2_ES,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2,
WeiBlockCopySubLengths_E_K_ES,
WeiBlockCopyClusterLengths_E_K_ES,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>{};
#endif
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
......
......@@ -790,13 +790,13 @@ int main(int argc, char* argv[])
#elif 1
// 1x1 filter, 7x7 image
// cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
constexpr index_t N = 32;
constexpr index_t C = 128;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
......@@ -817,8 +817,8 @@ int main(int argc, char* argv[])
ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
using in_data_t = half;
using out_data_t = half;
Tensor<in_data_t> in_nchw(make_TensorDescriptor(in_nchw_desc));
Tensor<in_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
Tensor<out_data_t> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
......