Commit 4b456610 authored by root

merge

parents 1014e6c9 4d93ce0e
@@ -219,6 +219,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
KernelTimer timer;
timer.Start();
+std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
+<< " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
+<< std::endl;
for(index_t j = 0; j < nrepeat; ++j)
{
...
@@ -3,10 +3,10 @@
template <typename GridwiseOp, typename... Xs>
__global__ void
-#if 0
-__launch_bounds__(256, 2)
+#if 1
+__launch_bounds__(64, 2)
#endif
run_gridwise_operation(Xs... xs)
{
GridwiseOp{}.Run(xs...);
}
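Note on the attribute enabled here (illustrative values, not a prescription from this commit): __launch_bounds__(max_threads_per_block, min_blocks_per_cu) tells the compiler the kernel will be launched with at most the given number of threads per block and asks it to budget registers so at least the given number of blocks can stay resident per compute unit. A minimal sketch:

template <typename GridwiseOp, typename... Xs>
__global__ void __launch_bounds__(64, 2) run_gridwise_operation_sketch(Xs... xs) // hypothetical name
{
GridwiseOp{}.Run(xs...);
}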
...
@@ -154,6 +154,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
decltype(b_thread_mtx),
decltype(c_thread_mtx)>{};
// loop over k
+#pragma unroll
for(index_t cyx_begin = 0; cyx_begin < CYXPerBlock; cyx_begin += CYXPerThreadLoop)
{
a_thread_copy.Run(p_a_block + a_block_mtx.CalculateOffset(make_tuple(cyx_begin, 0)) +
...
@@ -12,8 +12,9 @@
namespace ck {
template <index_t BlockSize,
-typename Float,
-typename AccFloat,
+typename FloatAB,
+typename FloatAcc,
+typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
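Splitting the former single Float parameter into FloatAB (the A/B inputs and their LDS staging), FloatAcc (the per-thread accumulator), and FloatC (the output) lets the gridwise GEMM run mixed-precision configurations; the driver changes later in this commit exercise it with packed int8 inputs, int32 accumulation, and int8 output.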
@@ -52,7 +53,7 @@ template <index_t BlockSize,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
-struct GridwiseDynamicGemm_km_kn_mn_v1
+struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
@@ -78,17 +79,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
-return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
+return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-const Float* __restrict__ p_a_global,
+const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc,
-const Float* __restrict__ p_b_global,
+const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-Float* __restrict__ p_c_global,
-Float* __restrict__ p_shared_block,
+FloatC* __restrict__ p_c_global,
+FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
@@ -144,8 +145,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
-Float,
-Float,
+FloatAB,
+FloatAB,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
ABlockTransferSrcAccessOrder,
@@ -173,8 +174,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
-Float,
-Float,
+FloatAB,
+FloatAB,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
BBlockTransferSrcAccessOrder,
@@ -235,11 +236,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
-Float* p_a_block_double = p_shared_block;
-Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
+FloatAB* p_a_block_double = p_shared_block;
+FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
-AccFloat p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
+FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
// zero out threadwise output
threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
@@ -269,11 +270,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
if constexpr(HasMainKBlockLoop)
{
-Float* p_a_block_even = p_a_block_double;
-Float* p_b_block_even = p_b_block_double;
-Float* p_a_block_odd = p_a_block_double + a_block_space_size;
-Float* p_b_block_odd = p_b_block_double + b_block_space_size;
+FloatAB* p_a_block_even = p_a_block_double;
+FloatAB* p_b_block_even = p_b_block_double;
+FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
+FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
index_t k_block_data_begin = 0;
@@ -400,8 +401,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
ThreadwiseDynamicTensorSliceTransfer_v1r3<
-AccFloat,
-Float,
+FloatAcc,
+FloatC,
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
@@ -429,17 +430,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// pass tensor descriptor by reference
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-const Float* __restrict__ p_a_global,
+const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc,
-const Float* __restrict__ p_b_global,
+const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-Float* __restrict__ p_c_global,
+FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
-constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
+constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
-__shared__ Float p_shared_block[shared_block_size];
+__shared__ FloatAB p_shared_block[shared_block_size];
Run(a_k_m_global_desc,
p_a_global,
@@ -452,14 +453,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
-// pass tensor descriptors by their pointers
+// pass tensor descriptors by pointers
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_k_m_global_desc,
-const Float* __restrict__ p_a_global,
+const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_k_n_global_desc,
-const Float* __restrict__ p_b_global,
+const FloatAB* __restrict__ p_b_global,
const CGlobalDesc* p_c_m0_m1_n0_n1_global_desc,
-Float* __restrict__ p_c_global,
+FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
@@ -480,11 +481,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// pass tensor descriptors by void*
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_k_m_global_desc,
-const Float* __restrict__ p_a_global,
+const FloatAB* __restrict__ p_a_global,
const void* p_b_k_n_global_desc,
-const Float* __restrict__ p_b_global,
+const FloatAB* __restrict__ p_b_global,
const void* p_c_m0_m1_n0_n1_global_desc,
-Float* __restrict__ p_c_global,
+FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
...
@@ -537,12 +537,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
-constexpr auto a_cyx_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+constexpr auto a_cyx_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<CYX>{}, Number<K>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
-math::integer_least_multiple(a_cyx_k_block_desc.GetElementSpaceSize(), max_lds_align);
+math::integer_least_multiple(a_cyx_k_desc.GetElementSpaceSize(), max_lds_align);
return a_block_space_size * sizeof(Float);
}
...
@@ -181,7 +181,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
i * dst_scalar_step_in_vector);
-dst_vector.Scalars()(i) = p_src[Number<src_offset>{}];
+dst_vector.Scalars()(i) = type_convert<DstData>{}(p_src[Number<src_offset>{}]);
});
const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
...
@@ -161,19 +161,7 @@ struct ThreadwiseGemm_km_kn_mn_v1
__device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
-constexpr bool has_amd_asm = is_same<FloatC, float>{} &&
-((is_same<FloatA, float>{} && is_same<FloatB, float>{}) ||
-(is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
-(is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));
-if constexpr(has_amd_asm)
-{
-Run_amd_asm(p_a, p_b, p_c);
-}
-else
-{
-Run_source(p_a, p_b, p_c);
-}
+Run_amd_asm(p_a, p_b, p_c);
#else
Run_source(p_a, p_b, p_c);
#endif
...
@@ -5,7 +5,8 @@
namespace ck {
-// outer-product: c[i,j] += inner_product(a[i], b[j])
+// c0 += inner_product(a, b0)
+// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
{
#if CK_USE_AMD_V_FMAC_F32
@@ -25,7 +26,10 @@ __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, floa
#endif
}
-// outer-product: c[i,j] += inner_product(a[i], b[j])
+// c0 += inner_product(a, b0)
+// c1 += inner_product(a, b1)
+// c2 += inner_product(a, b2)
+// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(
float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
{
@@ -50,7 +54,8 @@ __device__ void amd_assembly_outer_product_1x4(
#endif
}
-// outer-product: c[i,j] += inner_product(a[i], b[j])
+// c0 += inner_product(a, b0)
+// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
{
@@ -58,15 +63,12 @@ amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, flo
v_dot2_f32_f16 %0, %2, %3, %0\n \
v_dot2_f32_f16 %1, %2, %4, %1\n \
"
-: "=v"(c0), "=v"(c1) // Dest registers
-: "v"(a), // 1st Src register for 1 half2 registers
-"v"(b0), // 2nd Src register
-"v"(b1),
-"0"(c0), // 3rd Src register
-"1"(c1));
+: "=v"(c0), "=v"(c1)
+: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
}
-// outer-product: c[i,j] += inner_product(a[i], b[j])
+// c0 += inner_product(a, b0)
+// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{
@@ -81,18 +83,21 @@ amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, flo
v_dot2_f32_f16 %0, %3, %5, %0\n \
v_dot2_f32_f16 %1, %3, %7, %1\n \
"
-: "=v"(c0), "=v"(c1) // Dest registers
+: "=v"(c0), "=v"(c1)
: "v"(p_a_half2[0]),
-"v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers
+"v"(p_a_half2[1]),
"v"(p_b0_half2[0]),
"v"(p_b0_half2[1]),
"v"(p_b1_half2[0]),
-"v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers
+"v"(p_b1_half2[1]),
"0"(c0),
-"1"(c1)); // 3rd Src Acc registers for 2 half2 registers
+"1"(c1));
}
-// outer-product: c[i,j] += inner_product(a[i], b[j])
+// c0 += inner_product(a, b0)
+// c1 += inner_product(a, b1)
+// c2 += inner_product(a, b2)
+// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half2_t a,
half2_t b0,
half2_t b1,
@@ -109,19 +114,14 @@ __device__ void amd_assembly_outer_product_1x4(half2_t a,
v_dot2_f32_f16 %2, %4, %7, %2\n \
v_dot2_f32_f16 %3, %4, %8, %3\n \
"
-: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers
-: "v"(a), // 1st Src register for 1 half2 registers
-"v"(b0), // 2nd Src register
-"v"(b1),
-"v"(b2),
-"v"(b3),
-"0"(c0), // 3rd Src register
-"1"(c1),
-"2"(c2),
-"3"(c3));
+: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
+: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
}
-// outer-product: c[i,j] += inner_product(a[i], b[j])
+// c0 += inner_product(a, b0)
+// c1 += inner_product(a, b1)
+// c2 += inner_product(a, b2)
+// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half4_t a,
half4_t b0,
half4_t b1,
@@ -149,21 +149,70 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
v_dot2_f32_f16 %2, %5, %11, %2\n \
v_dot2_f32_f16 %3, %5, %13, %3\n \
"
-: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers
+: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(p_a_half2[0]),
-"v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers
+"v"(p_a_half2[1]),
"v"(p_b0_half2[0]),
"v"(p_b0_half2[1]),
"v"(p_b1_half2[0]),
-"v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers
+"v"(p_b1_half2[1]),
"v"(p_b2_half2[0]),
"v"(p_b2_half2[1]),
"v"(p_b3_half2[0]),
-"v"(p_b3_half2[1]), // 2nd Src registers for 2 half2 registers
+"v"(p_b3_half2[1]),
"0"(c0),
"1"(c1),
"2"(c2),
-"3"(c3)); // 3rd Src Acc registers for 2 half2 registers
+"3"(c3));
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
{
#if 1
asm volatile("\n \
v_dot4_i32_i8 %0, %2, %3, %0\n \
v_dot4_i32_i8 %1, %2, %4, %1\n \
"
: "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#else
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
#endif
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(int8x4_t a,
int8x4_t b0,
int8x4_t b1,
int8x4_t b2,
int8x4_t b3,
int32_t& c0,
int32_t& c1,
int32_t& c2,
int32_t& c3)
{
#if 1
asm volatile("\n \
v_dot4_i32_i8 %0, %4, %5, %0\n \
v_dot4_i32_i8 %1, %4, %6, %1\n \
v_dot4_i32_i8 %2, %4, %7, %2\n \
v_dot4_i32_i8 %3, %4, %8, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
#else
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
c2 = __builtin_amdgcn_sdot4(a, b2, c2, false);
c3 = __builtin_amdgcn_sdot4(a, b3, c3, false);
#endif
}
} // namespace ck
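For reference, a hedged scalar sketch (not code from this commit) of what v_dot4_i32_i8 and __builtin_amdgcn_sdot4 compute in the int8x4_t overloads above: the accumulator gains the sum of the four packed signed-byte products.

// Scalar reference for the packed int8 dot product (illustrative helper, name is hypothetical).
inline int32_t sdot4_reference(int32_t a_packed, int32_t b_packed, int32_t c)
{
for(int i = 0; i < 4; ++i)
{
const int8_t a_i = static_cast<int8_t>(a_packed >> (8 * i));
const int8_t b_i = static_cast<int8_t>(b_packed >> (8 * i));
c += static_cast<int32_t>(a_i) * static_cast<int32_t>(b_i);
}
return c;
}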
...
@@ -140,10 +140,5 @@ enum InMemoryDataOperation
// index type
using index_t = int32_t;
-typedef int32_t int32x2_t __attribute__((ext_vector_type(2)));
-// int32x4_t use by buffer_load and buffer_store llvm intrinsic
-typedef int32_t int32x4_t __attribute__((ext_vector_type(4)));
} // namespace ck
#endif
@@ -3,172 +3,6 @@
namespace ck {
// For some reason, HIP compiler need this definition to generate optimal ISA
// fp32
typedef float float2_t __attribute__((ext_vector_type(2)));
typedef float float4_t __attribute__((ext_vector_type(4)));
typedef float float8_t __attribute__((ext_vector_type(8)));
typedef float float16_t __attribute__((ext_vector_type(16)));
typedef float float32_t __attribute__((ext_vector_type(32)));
// fp16
typedef _Float16 half_t;
typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
typedef _Float16 half4_t __attribute__((ext_vector_type(4)));
typedef _Float16 half8_t __attribute__((ext_vector_type(8)));
// bfp16
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
struct c_vec32_4_t
{
union VecType
{
struct
{
float32_t x;
float32_t y;
float32_t z;
float32_t w;
} s;
float n[128];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
c.s.y = 0;
c.s.z = 0;
c.s.w = 0;
return c;
}
};
struct c_vec32_2_t
{
union VecType
{
struct
{
float32_t x;
float32_t y;
} s;
float n[64];
} l;
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
c.s.y = 0;
return c;
}
};
struct c_vec32_2_2_t
{
union VecType
{
struct
{
c_vec32_2_t x;
c_vec32_2_t y;
} s;
float n[128];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x.l.s.x = 0;
c.s.x.l.s.y = 0;
c.s.y.l.s.x = 0;
c.s.y.l.s.y = 0;
return c;
}
};
struct c_vec32_1_t
{
union VecType
{
struct
{
float32_t x;
} s;
float n[32];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
return c;
}
};
struct c_vec16_1_t
{
union VecType
{
struct
{
float16_t x;
} s;
float n[16];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
return c;
}
};
struct c_vec4_2_t
{
union VecType
{
struct
{
float4_t x;
float4_t y;
} s;
float n[8];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
c.s.y = 0;
return c;
}
};
struct c_vec4_1_t
{
union VecType
{
struct
{
float4_t x;
} s;
float n[4];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
return c;
}
};
template <typename T, index_t N>
struct vector_type;
@@ -183,7 +17,9 @@ struct vector_type<T, 1>
StaticallyIndexedArray<T, 1> d1x1_;
} data_;
-__host__ __device__ constexpr vector_type() : data_{T{0}} {}
+__host__ __device__ constexpr vector_type() : data_{type{0}} {}
+__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 1; }
@@ -215,7 +51,9 @@ struct vector_type<T, 2>
StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_;
-__host__ __device__ constexpr vector_type() : data_{d2_t{0}} {}
+__host__ __device__ constexpr vector_type() : data_{type{0}} {}
+__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 2; }
@@ -253,7 +91,9 @@ struct vector_type<T, 4>
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
-__host__ __device__ constexpr vector_type() : data_{d4_t{0}} {}
+__host__ __device__ constexpr vector_type() : data_{type{0}} {}
+__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 4; }
@@ -297,7 +137,9 @@ struct vector_type<T, 8>
StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_;
-__host__ __device__ constexpr vector_type() : data_{d8_t{0}} {}
+__host__ __device__ constexpr vector_type() : data_{type{0}} {}
+__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 8; }
@@ -326,6 +168,114 @@ struct vector_type<T, 8>
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
};
template <>
struct vector_type<int8_t, 2>
{
using d1_t = int8_t;
typedef int16_t d2_t;
using type = d2_t;
union
{
d2_t d2_;
StaticallyIndexedArray<d1_t, 2> d1x2_;
StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 2; }
__host__ __device__ constexpr const auto& Vector() const { return data_.d2_; }
__host__ __device__ constexpr auto& Vector() { return data_.d2_; }
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x2_; }
__host__ __device__ constexpr auto& Scalars() { return data_.d1x2_; }
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x2_; }
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x1_; }
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x2_; }
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x1_; }
};
template <>
struct vector_type<int8_t, 4>
{
using d1_t = int8_t;
typedef int16_t d2_t;
typedef int32_t d4_t;
using type = d4_t;
union
{
d4_t d4_;
StaticallyIndexedArray<d1_t, 4> d1x4_;
StaticallyIndexedArray<d2_t, 2> d2x2_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 4; }
__host__ __device__ constexpr const auto& Vector() const { return data_.d4_; }
__host__ __device__ constexpr auto& Vector() { return data_.d4_; }
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x4_; }
__host__ __device__ constexpr auto& Scalars() { return data_.d1x4_; }
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x4_; }
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x2_; }
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x1_; }
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x4_; }
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x2_; }
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; }
};
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
// fp16
using half_t = _Float16;
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
// bfp16
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
// i8
// hack for int8x4_t, because compiler does not have native support for int8x4_t
// int8x4_t is defined as int32_t
using int8x4_t = typename vector_type<int8_t, 4>::type;
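A hedged usage sketch of the alias above (helper name hypothetical): vector_type<int8_t, 4> wraps a union, so the same 32 bits can be read as the packed value via Vector() or as four int8_t lanes via Scalars().

__device__ int32_t sum_int8_lanes(int8x4_t packed) // illustrative, not in the commit
{
const vector_type<int8_t, 4> v{packed};
int32_t s = 0;
static_for<0, 4, 1>{}([&](auto i) { s += v.Scalars()[i]; });
return s;
}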
// data type conversion
template <typename T>
struct type_convert
@@ -356,113 +306,37 @@ struct inner_product_with_conversion
{
static constexpr auto convert = type_convert<T>();
-__device__ T operator()(float4_t a, float4_t b) const
+template <typename X, index_t N>
+__device__ T operator()(typename vector_type<X, N>::type a,
+typename vector_type<X, N>::type b) const
{
-const float* p_a_float = reinterpret_cast<const float*>(&a);
-const float* p_b_float = reinterpret_cast<const float*>(&b);
+const vector_type<X, N> a_vector{a};
+const vector_type<X, N> b_vector{b};
T acc = 0;
-for(index_t v = 0; v < 4; ++v)
-{
-acc += convert(p_a_float[v]) * convert(p_b_float[v]);
-}
-return acc;
-}
-__device__ T operator()(float2_t a, float2_t b) const
-{
-const float* p_a_float = reinterpret_cast<const float*>(&a);
-const float* p_b_float = reinterpret_cast<const float*>(&b);
-T acc = 0;
-for(index_t v = 0; v < 2; ++v)
-{
-acc += convert(p_a_float[v]) * convert(p_b_float[v]);
-}
+static_for<0, N, 1>{}([&](auto i) {
+acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
+});
return acc;
}
-__device__ T operator()(float a, float b) const { return convert(a) * convert(b); }
+__device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); }
-__device__ T operator()(half2_t a, half2_t b) const
+// hack for int8x4_t, because compiler does not have native support for int8x4_t
+// int8x4_t is defined as int32_t
+__device__ T operator()(int8x4_t a, int8x4_t b) const
{
-const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
-const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
+const vector_type<int8_t, 4> a_vector{a};
+const vector_type<int8_t, 4> b_vector{b};
T acc = 0;
-for(index_t v = 0; v < 2; ++v)
-{
-acc += convert(p_a_half[v]) * convert(p_b_half[v]);
-}
-return acc;
-}
+static_for<0, 4, 1>{}([&](auto i) {
+acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
+});
-__device__ T operator()(half4_t a, half4_t b) const
-{
-const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
-const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
-T acc = 0;
-for(index_t v = 0; v < 4; ++v)
-{
-acc += convert(p_a_half[v]) * convert(p_b_half[v]);
-}
-return acc;
-}
-__device__ T operator()(half8_t a, half8_t b) const
-{
-const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
-const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
-T acc = 0;
-for(index_t v = 0; v < 8; ++v)
-{
-acc += convert(p_a_half[v]) * convert(p_b_half[v]);
-}
-return acc;
-}
-__device__ T operator()(ushort2_t a, ushort2_t b) const
-{
-const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
-const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
-T acc = 0;
-for(index_t v = 0; v < 2; ++v)
-{
-acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
-}
-return acc;
-}
-__device__ T operator()(ushort4_t a, ushort4_t b) const
-{
-const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
-const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
-T acc = 0;
-for(index_t v = 0; v < 4; ++v)
-{
-acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
-}
-return acc;
-}
-__device__ T operator()(ushort8_t a, ushort8_t b) const
-{
-const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
-const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
-T acc = 0;
-for(index_t v = 0; v < 8; ++v)
-{
-acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
-}
return acc;
}
};
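A minimal usage sketch of the rewritten functor (wrapper name hypothetical): the int8x4_t overload unpacks both operands and accumulates four converted products, so an int32 dot product of two packed int8 quadruples reduces to:

__device__ int32_t dot_int8x4(int8x4_t a, int8x4_t b) // illustrative wrapper
{
return inner_product_with_conversion<int32_t>{}(a, b);
}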
...
@@ -39,7 +39,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
-#if 0
+#if 1
// run-time variables
const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
@@ -368,6 +368,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
#endif
<BlockSize,
+TDevice,
TDevice,
TDevice,
GemmMPerBlock,
...
@@ -3,33 +3,36 @@
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
-template <class T,
+template <class TInWei,
+ck::index_t InWeiVectorSize,
+class TAcc,
+class TOut,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations,
class InLeftPads,
-class InRightPads>
-void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc,
-const Tensor<T>& in_nchw,
-WeiDesc,
-const Tensor<T>& wei_kcyx,
-OutDesc,
-Tensor<T>& out_nkhw,
-ConvStrides,
-ConvDilations,
-InLeftPads,
-InRightPads,
-ck::index_t nrepeat)
+class InRightPads,
+class T>
+void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
+InDesc,
+const Tensor<T>& in_n_c_hi_wi,
+WeiDesc,
+const Tensor<T>& wei_k_c_y_x,
+OutDesc,
+Tensor<T>& out_n_k_ho_wo,
+ConvStrides,
+ConvDilations,
+InLeftPads,
+InRightPads,
+ck::index_t nrepeat)
{
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk"
<< std::endl;
using namespace ck;
-using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
@@ -48,12 +51,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
constexpr auto Y = WeiDesc::GetLengths()[I2];
constexpr auto X = WeiDesc::GetLengths()[I3];
+constexpr auto C0 = C / Number<InWeiVectorSize>{};
+constexpr auto C1 = Number<InWeiVectorSize>{};
#if 0
// run-time variables
-constexpr auto in_n_hi_wi_c_desc =
-make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C));
+constexpr auto in_n_hi_wi_c0_desc =
+make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
-constexpr auto wei_k_y_x_c_desc =
-make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, Y, X, C));
+constexpr auto wei_k_y_x_c0_desc =
+make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, Y, X, C0));
constexpr auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Ho, Wo, K));
@@ -63,10 +69,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
const auto in_right_pads = to_multi_index(InRightPads{});
#else
// compile-time variables
-constexpr auto in_n_hi_wi_c_desc =
-make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Hi, Wi, C));
+constexpr auto in_n_hi_wi_c0_desc =
+make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Hi, Wi, C0));
-constexpr auto wei_k_y_x_c_desc =
-make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y, X, C));
+constexpr auto wei_k_y_x_c0_desc =
+make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y, X, C0));
constexpr auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Ho, Wo, K));
@@ -76,38 +82,36 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif
-Tensor<float> in_nhwc(
+Tensor<TInWei> in_n_hi_wi_c(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{})));
-Tensor<float> wei_kyxc(
+Tensor<TInWei> wei_k_y_x_c(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{})));
-Tensor<float> out_nhwk(
+Tensor<TOut> out_n_ho_wo_k(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{})));
auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
-in_nhwc(n, hi, wi, c) = in_nchw(n, c, hi, wi);
+in_n_hi_wi_c(n, hi, wi, c) = in_n_c_hi_wi(n, c, hi, wi);
};
auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
-wei_kyxc(k, y, x, c) = wei_kcyx(k, c, y, x);
+wei_k_y_x_c(k, y, x, c) = wei_k_c_y_x(k, c, y, x);
};
auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
-out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo);
+out_n_ho_wo_k(n, ho, wo, k) = out_n_k_ho_wo(n, k, ho, wo);
};
-make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency());
-make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency());
-make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency());
+make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)();
+make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)();
+make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)();
-std::size_t data_sz = sizeof(T);
-DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace());
-DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace());
-DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace());
+DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
+DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
+DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
-in_nhwc_device_buf.ToDevice(in_nhwc.mData.data());
-wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data());
-out_nhwk_device_buf.ToDevice(out_nhwk.mData.data());
+in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
+wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
+out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
#if 1
// cdata = 16, BlockSize = 64, 16x64x4
@@ -378,8 +382,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
#endif
<BlockSize,
-TDevice,
-TDevice,
+typename vector_type<TInWei, InWeiVectorSize>::type,
+TAcc,
+TOut,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
@@ -400,22 +405,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
GemmBBlockTransferDstScalarPerVector_GemmN,
GemmCThreadTransferDstScalarPerVector_GemmM1>{};
-conv_driver.Run(wei_k_y_x_c_desc,
-in_n_hi_wi_c_desc,
+conv_driver.Run(wei_k_y_x_c0_desc,
+in_n_hi_wi_c0_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
-static_cast<TDevice*>(wei_kyxc_device_buf.GetDeviceBuffer()),
-static_cast<TDevice*>(in_nhwc_device_buf.GetDeviceBuffer()),
-static_cast<TDevice*>(out_nhwk_device_buf.GetDeviceBuffer()));
+static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
+wei_k_y_x_c_device_buf.GetDeviceBuffer()),
+static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
+in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
+static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()));
-out_nhwk_device_buf.FromDevice(out_nhwk.mData.data());
+#if 1
+out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
+#endif
auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
-out_nkhw(n, k, ho, wo) = out_nhwk(n, ho, wo, k);
+out_n_k_ho_wo(n, k, ho, wo) = out_n_ho_wo_k(n, ho, wo, k);
};
-make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)(std::thread::hardware_concurrency());
+make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)();
}
@@ -68,16 +68,16 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
#endif
// cdata = 16, BlockSize = 64, 16x64x4
-constexpr index_t BlockSize = 128;
+constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 16;
-constexpr index_t HPerBlock = 8;
-constexpr index_t WPerBlock = 8;
+constexpr index_t HPerBlock = 16;
+constexpr index_t WPerBlock = 16;
constexpr index_t CYXPerBlock = 4;
-constexpr index_t KPerThread = 8;
-constexpr index_t HPerThread = 1;
-constexpr index_t WPerThread = 1;
+constexpr index_t KPerThread = 16;
+constexpr index_t HPerThread = 2;
+constexpr index_t WPerThread = 2;
constexpr index_t CYXPerThread = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
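Assuming the block tile is partitioned as (KPerBlock/KPerThread) x (HPerBlock/HPerThread) x (WPerBlock/WPerThread) threads, the retuned constants give 1 x 8 x 8 = 64, which lines up with BlockSize dropping from 128 to 64.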
...
@@ -158,7 +158,7 @@ struct ParallelTensorFunctor
return indices;
}
-void operator()(std::size_t num_thread) const
+void operator()(std::size_t num_thread = std::thread::hardware_concurrency()) const
{
std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread;
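With the defaulted argument, call sites can simply invoke the functor as functor() and still fan out across std::thread::hardware_concurrency() threads; the shortened make_ParallelTensorFunctor(...)() calls elsewhere in this commit rely on this.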
...
@@ -25,7 +25,21 @@ int main(int argc, char* argv[])
#if 0
constexpr index_t N = 1;
-constexpr index_t C = 4;
+constexpr index_t C = 16;
+constexpr index_t HI = 1;
+constexpr index_t WI = 64;
+constexpr index_t K = 16;
+constexpr index_t Y = 3;
+constexpr index_t X = 3;
+using ConvStrides = Sequence<1, 1>;
+using ConvDilations = Sequence<1, 1>;
+using LeftPads = Sequence<1, 1>;
+using RightPads = Sequence<1, 1>;
+#elif 0
+constexpr index_t N = 1;
+constexpr index_t C = 16;
constexpr index_t HI = 1080;
constexpr index_t WI = 1920;
constexpr index_t K = 16;
@@ -35,11 +49,11 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
-constexpr index_t C = 4;
+constexpr index_t C = 16;
constexpr index_t HI = 540;
constexpr index_t WI = 960;
constexpr index_t K = 16;
@@ -49,11 +63,11 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
-constexpr index_t C = 4;
+constexpr index_t C = 16;
constexpr index_t HI = 270;
constexpr index_t WI = 480;
constexpr index_t K = 16;
@@ -65,20 +79,6 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
-#elif 0
-constexpr index_t N = 1;
-constexpr index_t C = 4;
-constexpr index_t HI = 1080;
-constexpr index_t WI = 1920;
-constexpr index_t K = 16;
-constexpr index_t Y = 3;
-constexpr index_t X = 3;
-using ConvStrides = Sequence<1, 1>;
-using ConvDilations = Sequence<1, 1>;
-using LeftPads = Sequence<1, 1>;
-using RightPads = Sequence<1, 1>;
#elif 1
constexpr index_t N = 1;
constexpr index_t C = 4;
@@ -95,7 +95,7 @@ int main(int argc, char* argv[])
using RightPads = Sequence<1, 1>;
#elif 0
constexpr index_t N = 1;
-constexpr index_t C = 4;
+constexpr index_t C = 16;
constexpr index_t HI = 540;
constexpr index_t WI = 960;
constexpr index_t K = 16;
@@ -109,7 +109,7 @@ int main(int argc, char* argv[])
using RightPads = Sequence<1, 1>;
#elif 0
constexpr index_t N = 1;
-constexpr index_t C = 4;
+constexpr index_t C = 16;
constexpr index_t HI = 270;
constexpr index_t WI = 480;
constexpr index_t K = 16;
@@ -631,12 +631,16 @@ int main(int argc, char* argv[])
print_array("ConvStrides", to_multi_index(ConvStrides{}));
print_array("ConvDilations", to_multi_index(ConvDilations{}));
-#if 1
+#if 0
using in_data_t = float;
+constexpr index_t in_vector_size = 1;
using out_data_t = float;
+using acc_data_t = float;
#else
-using in_data_t = half_float::half;
-using out_data_t = half_float::half;
+using in_data_t = int8_t;
+constexpr index_t in_vector_size = 4;
+using acc_data_t = int32_t;
+using out_data_t = int8_t;
#endif
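With in_data_t = int8_t and in_vector_size = 4, four input channels pack into one int8x4_t (stored as an int32_t), matching the C0 = C / InWeiVectorSize channel split introduced in the NHWC device function above.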
Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc));
@@ -646,14 +650,15 @@ int main(int argc, char* argv[])
std::size_t num_thread = std::thread::hardware_concurrency();
-if(argc != 3)
+if(argc != 4)
{
-printf("arg1: do_verification, arg2: nrepeat\n");
+printf("arg1: do_verification, arg2: do_log, arg3: nrepeat\n");
exit(1);
}
bool do_verification = atoi(argv[1]);
-index_t nrepeat = atoi(argv[2]);
+bool do_log = atoi(argv[2]);
+index_t nrepeat = atoi(argv[3]);
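Invocation therefore takes three arguments, e.g. (binary name illustrative): ./conv_driver 1 0 10 verifies against the host reference, skips logging, and times 10 repeats.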
if(do_verification)
{
@@ -662,7 +667,7 @@ int main(int argc, char* argv[])
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-wei_kcyx.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
+wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
@@ -751,36 +756,42 @@ int main(int argc, char* argv[])
LeftPads{},
RightPads{},
nrepeat);
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
in_vector_size,
acc_data_t,
out_data_t>(
in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
#endif
if(do_verification)
{
-#if 0
-if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
-ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
-{
-host_winograd_3x3_convolution(
-in_nchw, wei_kcyx, out_nkhw_host, LeftPads{}, RightPads{});
-}
-else
-#endif
-{
-host_direct_convolution(in_nchw,
-wei_kcyx,
-out_nkhw_host,
-ConvStrides{},
-ConvDilations{},
-LeftPads{},
-RightPads{});
-}
+host_direct_convolution(in_nchw,
+wei_kcyx,
+out_nkhw_host,
+ConvStrides{},
+ConvDilations{},
+LeftPads{},
+RightPads{});
check_error(out_nkhw_host, out_nkhw_device);
-#if 0
-// LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
-// LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
-LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
-LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
-#endif
+if(do_log)
+{
+LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
+LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
+LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
+LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
+}
}
}