Commit 59968d8d authored by Chao Liu

adding int8x4

parent 4f31669f
...@@ -12,8 +12,9 @@ ...@@ -12,8 +12,9 @@
namespace ck { namespace ck {
template <index_t BlockSize, template <index_t BlockSize,
typename Float, typename FloatAB,
typename AccFloat, typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation, InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc, typename AGlobalDesc,
typename BGlobalDesc, typename BGlobalDesc,
...@@ -52,7 +53,7 @@ template <index_t BlockSize, ...@@ -52,7 +53,7 @@ template <index_t BlockSize,
typename CGlobalIteratorHacks, typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks, typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks> typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_mn_v1 struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
{ {
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
...@@ -78,17 +79,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -78,17 +79,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto b_block_space_size = constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float); return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
} }
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc, __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const Float* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc, const BGlobalDesc& b_k_n_global_desc,
const Float* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc, const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
Float* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
...@@ -144,8 +145,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -144,8 +145,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
ABlockTransferThreadSliceLengths_K_M, ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M, ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
Float, FloatAB,
Float, FloatAB,
decltype(a_k_m_global_desc), decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc), decltype(a_k_m_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -173,8 +174,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -173,8 +174,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
BBlockTransferThreadSliceLengths_K_N, BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N, BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
Float, FloatAB,
Float, FloatAB,
decltype(b_k_n_global_desc), decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc), decltype(b_k_n_block_desc),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -235,11 +236,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -235,11 +236,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto b_block_space_size = constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
Float* p_a_block_double = p_shared_block; FloatAB* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space_size; FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output // register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()]; FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
// zero out threadwise output // zero out threadwise output
threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread); threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
...@@ -269,11 +270,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -269,11 +270,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
Float* p_a_block_even = p_a_block_double; FloatAB* p_a_block_even = p_a_block_double;
Float* p_b_block_even = p_b_block_double; FloatAB* p_b_block_even = p_b_block_double;
Float* p_a_block_odd = p_a_block_double + a_block_space_size; FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
Float* p_b_block_odd = p_b_block_double + b_block_space_size; FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
index_t k_block_data_begin = 0; index_t k_block_data_begin = 0;
...@@ -400,8 +401,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -400,8 +401,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{})); Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseDynamicTensorSliceTransfer_v1r3<
AccFloat, FloatAcc,
Float, FloatC,
decltype(c_m0_m1_n0_n1_thread_desc), decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc), decltype(c_m0_m1_n0_n1_global_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>, Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
...@@ -429,17 +430,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -429,17 +430,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// pass tensor descriptor by reference // pass tensor descriptor by reference
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc, __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const Float* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc, const BGlobalDesc& b_k_n_global_desc,
const Float* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc, const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float); constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ Float p_shared_block[shared_block_size]; __shared__ FloatAB p_shared_block[shared_block_size];
Run(a_k_m_global_desc, Run(a_k_m_global_desc,
p_a_global, p_a_global,
...@@ -452,14 +453,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -452,14 +453,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
// pass tensor descriptors by their pointers // pass tensor descriptors by pointers
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_k_m_global_desc, __device__ void Run(const AGlobalDesc* p_a_k_m_global_desc,
const Float* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_k_n_global_desc, const BGlobalDesc* p_b_k_n_global_desc,
const Float* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const CGlobalDesc* p_c_m0_m1_n0_n1_global_desc, const CGlobalDesc* p_c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
...@@ -480,11 +481,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -480,11 +481,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// pass tensor descriptors by void* // pass tensor descriptors by void*
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_k_m_global_desc, __device__ void Run(const void* p_a_k_m_global_desc,
const Float* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const void* p_b_k_n_global_desc, const void* p_b_k_n_global_desc,
const Float* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const void* p_c_m0_m1_n0_n1_global_desc, const void* p_c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
......
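Splitting the old single Float parameter into FloatAB, FloatAcc and FloatC is what makes the int8 path possible: the A/B tiles can carry int8 data while the accumulator and the output stay int32. A minimal sketch of the three-type pattern, assuming a plain scalar loop (gemm_sketch and its arguments are illustrative, not part of the kernel above):

    // FloatAB  - element type of the A/B data staged through LDS
    // FloatAcc - accumulator type kept in registers
    // FloatC   - element type written back to global memory
    template <typename FloatAB, typename FloatAcc, typename FloatC>
    __device__ void gemm_sketch(const FloatAB* a, const FloatAB* b, FloatC* c, int k_len)
    {
        FloatAcc acc = 0;
        for(int k = 0; k < k_len; ++k)
        {
            acc += static_cast<FloatAcc>(a[k]) * static_cast<FloatAcc>(b[k]);
        }
        *c = static_cast<FloatC>(acc); // e.g. FloatAB = int8_t, FloatAcc = FloatC = int32_t
    }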
...@@ -182,7 +182,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -182,7 +182,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx + src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
i * dst_scalar_step_in_vector); i * dst_scalar_step_in_vector);
dst_vector.Scalars()(i) = p_src[Number<src_offset>{}]; dst_vector.Scalars()(i) = type_convert<DstData>{}(p_src[Number<src_offset>{}]);
}); });
const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
......
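The new type_convert<DstData> call above is what lets the transfer write an accumulator of one type (e.g. int32_t) into a destination of another. A minimal sketch of what such a conversion functor looks like, assuming a templated call operator (type_convert_sketch is illustrative; the repo's type_convert is defined in the data-type header shown further below):

    template <typename TDst>
    struct type_convert_sketch
    {
        template <typename TSrc>
        __host__ __device__ TDst operator()(TSrc v) const
        {
            return static_cast<TDst>(v); // e.g. int32_t accumulator -> int8_t output element
        }
    };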
...@@ -75,9 +75,9 @@ struct ThreadwiseGemm_km_kn_mn_v1 ...@@ -75,9 +75,9 @@ struct ThreadwiseGemm_km_kn_mn_v1
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto M = CDesc{}[I0]; constexpr auto M = CDesc{}.GetLength(I0);
constexpr auto N = CDesc{}[I1]; constexpr auto N = CDesc{}.GetLength(I1);
constexpr auto K = ADesc{}[I0]; constexpr auto K = ADesc{}.GetLength(I0);
static_for<0, K, 1>{}([&](auto k) { static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) { static_for<0, M, 1>{}([&](auto m) {
...@@ -161,19 +161,7 @@ struct ThreadwiseGemm_km_kn_mn_v1 ...@@ -161,19 +161,7 @@ struct ThreadwiseGemm_km_kn_mn_v1
__device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{ {
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM #if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
constexpr bool has_amd_asm = is_same<FloatC, float>{} &&
((is_same<FloatA, float>{} && is_same<FloatB, float>{}) ||
(is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
(is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));
if constexpr(has_amd_asm)
{
Run_amd_asm(p_a, p_b, p_c); Run_amd_asm(p_a, p_b, p_c);
}
else
{
Run_source(p_a, p_b, p_c);
}
#else #else
Run_source(p_a, p_b, p_c); Run_source(p_a, p_b, p_c);
#endif #endif
......
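With the has_amd_asm guard removed, the choice between Run_amd_asm and Run_source is now made solely by CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM (which the config change further below sets to 0 by default). For orientation, a source-path threadwise GEMM body looks roughly like the triple loop below; this is a hedged sketch with illustrative row-major KxM / KxN / MxN thread tiles, not the repo's exact Run_source:

    template <typename FloatA, typename FloatB, typename FloatC, int K, int M, int N>
    __device__ void threadwise_gemm_source_sketch(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
        // A is K x M, B is K x N, C is M x N, all thread-private and fully unrollable
        for(int k = 0; k < K; ++k)
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                    p_c[m * N + n] += static_cast<FloatC>(p_a[k * M + m]) *
                                      static_cast<FloatC>(p_b[k * N + n]);
    }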
...@@ -31,6 +31,7 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_siz ...@@ -31,6 +31,7 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_siz
return wave_buffer_resource.data; return wave_buffer_resource.data;
} }
// fp32 load
__device__ float __device__ float
__llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, __llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset, index_t voffset,
...@@ -49,6 +50,7 @@ __llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, ...@@ -49,6 +50,7 @@ __llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
// fp32 store
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_fp32(float vdata, __llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc, int32x4_t rsrc,
...@@ -70,213 +72,247 @@ __llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, ...@@ -70,213 +72,247 @@ __llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
// buffer_load requires: // i32 load
// 1) p_src_wave must be in global memory space __device__ int32_t
// 2) p_src_wave to be a wavewise pointer. __llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
// It is user's responsibility to make sure that is true. index_t voffset,
template <typename T, index_t VectorSize> index_t soffset,
__device__ typename vector_type<T, VectorSize>::type index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
amd_buffer_load_v2(const T* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_elemenst_space);
// buffer_store requires:
// 1) p_dst_wave must be global memory
// 2) p_dst_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
__device__ void amd_buffer_store_v2(const typename vector_type<T, VectorSize>::type src_thread_data,
T* p_dst_wave,
const index_t dst_thread_data_offset,
const bool dst_thread_data_valid,
const index_t dst_data_range);
template <>
__device__ float amd_buffer_load_v2<float, 1>(const float* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_data_range);
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); __device__ int32x2_t
__llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK __device__ int32x4_t
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; __llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
return __llvm_amdgcn_raw_buffer_load_fp32( // i32 store
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0); __device__ void
#else __llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
float tmp = int32x4_t rsrc,
__llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource, src_thread_addr_offset, 0, 0); index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
return src_thread_data_valid ? tmp : float(0); // i16 store
#endif __device__ void
} __llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
template <> template <typename T, index_t N>
__device__ float2_t amd_buffer_load_v2<float, 2>(const float* p_src_wave, __device__ typename vector_type<T, N>::type
index_t src_thread_data_offset, amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
bool src_thread_data_valid, index_t src_thread_addr_offset,
index_t src_data_range) index_t src_wave_addr_offset)
{ {
const int32x4_t src_wave_buffer_resource = static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
make_wave_buffer_resource(p_src_wave, src_data_range); (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
"wrong! not implemented");
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
if constexpr(is_same<T, float>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_fp32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
return __llvm_amdgcn_raw_buffer_load_fp32x2( return __llvm_amdgcn_raw_buffer_load_fp32x2(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else }
float2_t tmp = __llvm_amdgcn_raw_buffer_load_fp32x2( else if constexpr(N == 4)
src_wave_buffer_resource, src_thread_addr_offset, 0, 0); {
return src_thread_data_valid ? tmp : float2_t(0);
#endif
}
template <>
__device__ float4_t amd_buffer_load_v2<float, 4>(const float* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_data_range);
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return __llvm_amdgcn_raw_buffer_load_fp32x4( return __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else }
float4_t tmp = __llvm_amdgcn_raw_buffer_load_fp32x4( else if constexpr(N == 8)
src_wave_buffer_resource, src_thread_addr_offset, 0, 0); {
return src_thread_data_valid ? tmp : float4_t(0);
#endif
}
template <>
__device__ float8_t amd_buffer_load_v2<float, 8>(const float* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_data_range);
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
vector_type<float, 8> tmp; vector_type<float, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_wave_buffer_resource, src_thread_addr_offset, 4 * sizeof(float), 0);
src_addr_shift + src_thread_addr_offset + 4 * sizeof(float),
0,
0);
return tmp.Vector(); return tmp.Vector();
#else }
vector_type<float, 8> tmp; }
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_i32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
return __llvm_amdgcn_raw_buffer_load_i32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 4)
{
return __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 8)
{
vector_type<int32_t, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset + 4 * sizeof(float), 0, 0); src_wave_buffer_resource, src_thread_addr_offset, 4 * sizeof(int32_t), 0);
return src_thread_data_valid ? tmp.Vector() : float8_t(0); return tmp.Vector();
#endif }
}
} }
template <> template <typename T, index_t N>
__device__ void amd_buffer_store_v2<float, 1>(const float src_thread_data, __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type src_thread_data,
float* p_dst_wave, int32x4_t dst_wave_buffer_resource,
const index_t dst_thread_data_offset, index_t dst_thread_addr_offset,
const bool dst_thread_data_valid, index_t dst_wave_addr_offset)
const index_t dst_data_range)
{ {
const int32x4_t dst_wave_buffer_resource = static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
make_wave_buffer_resource(p_dst_wave, dst_data_range); (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 2)),
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); "wrong! not implemented");
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK if constexpr(is_same<T, float>::value)
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_raw_buffer_store_fp32(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0, 0);
#else
if(dst_thread_data_valid)
{ {
__llvm_amdgcn_buffer_store_fp32( if constexpr(N == 1)
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0, 0); {
__llvm_amdgcn_raw_buffer_store_fp32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
__llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{
__llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
__llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
else if constexpr(is_same<T, int8_t>::value)
{
if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
} }
#endif
} }
template <> // buffer_load requires:
__device__ void amd_buffer_store_v2<float, 2>(const float2_t src_thread_data, // 1) p_src_wave must be in global memory space
float* p_dst_wave, // 2) p_src_wave to be a wavewise pointer.
const index_t dst_thread_data_offset, // It is user's responsibility to make sure that is true.
const bool dst_thread_data_valid, template <typename T, index_t N>
const index_t dst_data_range) __device__ typename vector_type<T, N>::type amd_buffer_load_v2(const T* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_element_space)
{ {
const int32x4_t dst_wave_buffer_resource = const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_data_range); make_wave_buffer_resource(p_src_wave, src_element_space);
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_raw_buffer_store_fp32x2( return amd_buffer_load_impl_v2<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else #else
if(dst_thread_data_valid) using vector_t = typename vector_type<T, N>::type;
{
__llvm_amdgcn_raw_buffer_store_fp32x2( vector_t tmp =
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0, 0); amd_buffer_load_impl_v2<T, N>(src_wave_buffer_resource, src_thread_addr_offset, 0);
}
return src_thread_data_valid ? tmp : vector_t(0);
#endif #endif
} }
template <> // buffer_store requires:
__device__ void amd_buffer_store_v2<float, 4>(const float4_t src_thread_data, // 1) p_dst_wave must be global memory
float* p_dst_wave, // 2) p_dst_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ void amd_buffer_store_v2(const typename vector_type<T, N>::type src_thread_data,
T* p_dst_wave,
const index_t dst_thread_data_offset, const index_t dst_thread_data_offset,
const bool dst_thread_data_valid, const bool dst_thread_data_valid,
const index_t dst_data_range) const index_t dst_element_space)
{ {
const int32x4_t dst_wave_buffer_resource = const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_data_range); make_wave_buffer_resource(p_dst_wave, dst_element_space);
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_raw_buffer_store_fp32x4( amd_buffer_store_impl_v2<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else #else
if(dst_thread_data_valid) if(dst_thread_data_valid)
{ {
__llvm_amdgcn_raw_buffer_store_fp32x4( amd_buffer_store_impl_v2<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0, 0); src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
} }
#endif #endif
} }
......
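After this refactor the per-type amd_buffer_load_v2 / amd_buffer_store_v2 specializations collapse into the two generic front ends plus the _impl_v2 helpers. A hedged usage sketch (copy_float4_sketch and its arguments are illustrative); offsets are in elements, and the valid flag drives the out-of-bounds handling path:

    __device__ void copy_float4_sketch(const float* p_src_wave,
                                       float* p_dst_wave,
                                       index_t thread_offset,
                                       bool valid,
                                       index_t element_space)
    {
        // p_src_wave / p_dst_wave are wavewise base pointers into global memory,
        // per the requirements stated in the comments above
        float4_t v = amd_buffer_load_v2<float, 4>(p_src_wave, thread_offset, valid, element_space);
        amd_buffer_store_v2<float, 4>(v, p_dst_wave, thread_offset, valid, element_space);
    }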
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
namespace ck { namespace ck {
// outer-product: c[i,j] += inner_product(a[i], b[j]) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1) __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
{ {
#if CK_USE_AMD_V_FMAC_F32 #if CK_USE_AMD_V_FMAC_F32
...@@ -25,7 +26,10 @@ __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, floa ...@@ -25,7 +26,10 @@ __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, floa
#endif #endif
} }
// outer-product: c[i,j] += inner_product(a[i], b[j]) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4( __device__ void amd_assembly_outer_product_1x4(
float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3) float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
{ {
...@@ -50,7 +54,8 @@ __device__ void amd_assembly_outer_product_1x4( ...@@ -50,7 +54,8 @@ __device__ void amd_assembly_outer_product_1x4(
#endif #endif
} }
// outer-product: c[i,j] += inner_product(a[i], b[j]) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void __device__ void
amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1) amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
{ {
...@@ -58,15 +63,12 @@ amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, flo ...@@ -58,15 +63,12 @@ amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, flo
v_dot2_f32_f16 %0, %2, %3, %0\n \ v_dot2_f32_f16 %0, %2, %3, %0\n \
v_dot2_f32_f16 %1, %2, %4, %1\n \ v_dot2_f32_f16 %1, %2, %4, %1\n \
" "
: "=v"(c0), "=v"(c1) // Dest registers : "=v"(c0), "=v"(c1)
: "v"(a), // 1st Src register for 1 half2 registers : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
"v"(b0), // 2nd Src register
"v"(b1),
"0"(c0), // 3rd Src register
"1"(c1));
} }
// outer-product: c[i,j] += inner_product(a[i], b[j]) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void __device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1) amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{ {
...@@ -81,18 +83,21 @@ amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, flo ...@@ -81,18 +83,21 @@ amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, flo
v_dot2_f32_f16 %0, %3, %5, %0\n \ v_dot2_f32_f16 %0, %3, %5, %0\n \
v_dot2_f32_f16 %1, %3, %7, %1\n \ v_dot2_f32_f16 %1, %3, %7, %1\n \
" "
: "=v"(c0), "=v"(c1) // Dest registers : "=v"(c0), "=v"(c1)
: "v"(p_a_half2[0]), : "v"(p_a_half2[0]),
"v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers "v"(p_a_half2[1]),
"v"(p_b0_half2[0]), "v"(p_b0_half2[0]),
"v"(p_b0_half2[1]), "v"(p_b0_half2[1]),
"v"(p_b1_half2[0]), "v"(p_b1_half2[0]),
"v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers "v"(p_b1_half2[1]),
"0"(c0), "0"(c0),
"1"(c1)); // 3rd Src Acc registers for 2 half2 registers "1"(c1));
} }
// outer-product: c[i,j] += inner_product(a[i], b[j]) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half2_t a, __device__ void amd_assembly_outer_product_1x4(half2_t a,
half2_t b0, half2_t b0,
half2_t b1, half2_t b1,
...@@ -109,19 +114,14 @@ __device__ void amd_assembly_outer_product_1x4(half2_t a, ...@@ -109,19 +114,14 @@ __device__ void amd_assembly_outer_product_1x4(half2_t a,
v_dot2_f32_f16 %2, %4, %7, %2\n \ v_dot2_f32_f16 %2, %4, %7, %2\n \
v_dot2_f32_f16 %3, %4, %8, %3\n \ v_dot2_f32_f16 %3, %4, %8, %3\n \
" "
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), // 1st Src register for 1 half2 registers : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
"v"(b0), // 2nd Src register
"v"(b1),
"v"(b2),
"v"(b3),
"0"(c0), // 3rd Src register
"1"(c1),
"2"(c2),
"3"(c3));
} }
// outer-product: c[i,j] += inner_product(a[i], b[j]) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half4_t a, __device__ void amd_assembly_outer_product_1x4(half4_t a,
half4_t b0, half4_t b0,
half4_t b1, half4_t b1,
...@@ -149,21 +149,70 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a, ...@@ -149,21 +149,70 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
v_dot2_f32_f16 %2, %5, %11, %2\n \ v_dot2_f32_f16 %2, %5, %11, %2\n \
v_dot2_f32_f16 %3, %5, %13, %3\n \ v_dot2_f32_f16 %3, %5, %13, %3\n \
" "
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(p_a_half2[0]), : "v"(p_a_half2[0]),
"v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers "v"(p_a_half2[1]),
"v"(p_b0_half2[0]), "v"(p_b0_half2[0]),
"v"(p_b0_half2[1]), "v"(p_b0_half2[1]),
"v"(p_b1_half2[0]), "v"(p_b1_half2[0]),
"v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers "v"(p_b1_half2[1]),
"v"(p_b2_half2[0]), "v"(p_b2_half2[0]),
"v"(p_b2_half2[1]), "v"(p_b2_half2[1]),
"v"(p_b3_half2[0]), "v"(p_b3_half2[0]),
"v"(p_b3_half2[1]), // 2nd Src registers for 2 half2 registers "v"(p_b3_half2[1]),
"0"(c0), "0"(c0),
"1"(c1), "1"(c1),
"2"(c2), "2"(c2),
"3"(c3)); // 3rd Src Acc registers for 2 half2 registers "3"(c3));
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
{
#if 1
asm volatile("\n \
v_dot4_i32_i8 %0, %2, %3, %0\n \
v_dot4_i32_i8 %1, %2, %4, %1\n \
"
: "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#else
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
#endif
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(int8x4_t a,
int8x4_t b0,
int8x4_t b1,
int8x4_t b2,
int8x4_t b3,
int32_t& c0,
int32_t& c1,
int32_t& c2,
int32_t& c3)
{
#if 1
asm volatile("\n \
v_dot4_i32_i8 %0, %4, %5, %0\n \
v_dot4_i32_i8 %1, %4, %6, %1\n \
v_dot4_i32_i8 %2, %4, %7, %2\n \
v_dot4_i32_i8 %3, %4, %8, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
#else
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
c2 = __builtin_amdgcn_sdot4(a, b2, c2, false);
c3 = __builtin_amdgcn_sdot4(a, b3, c3, false);
#endif
} }
} // namespace ck } // namespace ck
......
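The new int8 overloads take one packed int8x4_t (four int8 lanes carried in an int32_t) per operand and accumulate into int32 via v_dot4_i32_i8, or __builtin_amdgcn_sdot4 on the intrinsic path. A hedged usage sketch for one K-step of a 1x4 micro-tile (the function name and arguments are illustrative):

    __device__ void int8_microtile_step_sketch(int8x4_t a,
                                               int8x4_t b0,
                                               int8x4_t b1,
                                               int8x4_t b2,
                                               int8x4_t b3,
                                               int32_t c[4])
    {
        // c[j] += dot(a, bj) over the 4 packed int8 lanes
        amd_assembly_outer_product_1x4(a, b0, b1, b2, b3, c[0], c[1], c[2], c[3]);
    }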
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#endif #endif
#include "bfloat16_dev.hpp" #include "bfloat16_dev.hpp"
#if 0 #if 1
#define CK_AMD_GPU_GFX906 1 #define CK_AMD_GPU_GFX906 1
#elif 0 #elif 0
#define CK_AMD_GPU_GFX908 1 #define CK_AMD_GPU_GFX908 1
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
#endif #endif
#ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM #ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1 #define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 0
#endif #endif
#ifndef CK_USE_AMD_V_FMAC_F32 #ifndef CK_USE_AMD_V_FMAC_F32
...@@ -140,10 +140,5 @@ enum InMemoryDataOperation ...@@ -140,10 +140,5 @@ enum InMemoryDataOperation
// index type // index type
using index_t = int32_t; using index_t = int32_t;
typedef int32_t int32x2_t __attribute__((ext_vector_type(2)));
// int32x4_t use by buffer_load and buffer_store llvm intrinsic
typedef int32_t int32x4_t __attribute__((ext_vector_type(4)));
} // namespace ck } // namespace ck
#endif #endif
...@@ -3,172 +3,6 @@ ...@@ -3,172 +3,6 @@
namespace ck { namespace ck {
// For some reason, HIP compiler need this definition to generate optimal ISA
// fp32
typedef float float2_t __attribute__((ext_vector_type(2)));
typedef float float4_t __attribute__((ext_vector_type(4)));
typedef float float8_t __attribute__((ext_vector_type(8)));
typedef float float16_t __attribute__((ext_vector_type(16)));
typedef float float32_t __attribute__((ext_vector_type(32)));
// fp16
typedef _Float16 half_t;
typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
typedef _Float16 half4_t __attribute__((ext_vector_type(4)));
typedef _Float16 half8_t __attribute__((ext_vector_type(8)));
// bfp16
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
struct c_vec32_4_t
{
union VecType
{
struct
{
float32_t x;
float32_t y;
float32_t z;
float32_t w;
} s;
float n[128];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
c.s.y = 0;
c.s.z = 0;
c.s.w = 0;
return c;
}
};
struct c_vec32_2_t
{
union VecType
{
struct
{
float32_t x;
float32_t y;
} s;
float n[64];
} l;
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
c.s.y = 0;
return c;
}
};
struct c_vec32_2_2_t
{
union VecType
{
struct
{
c_vec32_2_t x;
c_vec32_2_t y;
} s;
float n[128];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x.l.s.x = 0;
c.s.x.l.s.y = 0;
c.s.y.l.s.x = 0;
c.s.y.l.s.y = 0;
return c;
}
};
struct c_vec32_1_t
{
union VecType
{
struct
{
float32_t x;
} s;
float n[32];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
return c;
}
};
struct c_vec16_1_t
{
union VecType
{
struct
{
float16_t x;
} s;
float n[16];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
return c;
}
};
struct c_vec4_2_t
{
union VecType
{
struct
{
float4_t x;
float4_t y;
} s;
float n[8];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
c.s.y = 0;
return c;
}
};
struct c_vec4_1_t
{
union VecType
{
struct
{
float4_t x;
} s;
float n[4];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
return c;
}
};
template <typename T, index_t N> template <typename T, index_t N>
struct vector_type; struct vector_type;
...@@ -183,7 +17,9 @@ struct vector_type<T, 1> ...@@ -183,7 +17,9 @@ struct vector_type<T, 1>
StaticallyIndexedArray<T, 1> d1x1_; StaticallyIndexedArray<T, 1> d1x1_;
} data_; } data_;
__host__ __device__ constexpr vector_type() : data_{T{0}} {} __host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 1; } __host__ __device__ static constexpr index_t Size() { return 1; }
...@@ -215,7 +51,9 @@ struct vector_type<T, 2> ...@@ -215,7 +51,9 @@ struct vector_type<T, 2>
StaticallyIndexedArray<d2_t, 1> d2x1_; StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_; } data_;
__host__ __device__ constexpr vector_type() : data_{d2_t{0}} {} __host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 2; } __host__ __device__ static constexpr index_t Size() { return 2; }
...@@ -253,7 +91,9 @@ struct vector_type<T, 4> ...@@ -253,7 +91,9 @@ struct vector_type<T, 4>
StaticallyIndexedArray<d4_t, 1> d4x1_; StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_; } data_;
__host__ __device__ constexpr vector_type() : data_{d4_t{0}} {} __host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 4; } __host__ __device__ static constexpr index_t Size() { return 4; }
...@@ -297,7 +137,9 @@ struct vector_type<T, 8> ...@@ -297,7 +137,9 @@ struct vector_type<T, 8>
StaticallyIndexedArray<d8_t, 1> d8x1_; StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_; } data_;
__host__ __device__ constexpr vector_type() : data_{d8_t{0}} {} __host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 8; } __host__ __device__ static constexpr index_t Size() { return 8; }
...@@ -326,6 +168,114 @@ struct vector_type<T, 8> ...@@ -326,6 +168,114 @@ struct vector_type<T, 8>
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; } __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
}; };
template <>
struct vector_type<int8_t, 2>
{
using d1_t = int8_t;
typedef int16_t d2_t;
using type = d2_t;
union
{
d2_t d2_;
StaticallyIndexedArray<d1_t, 2> d1x2_;
StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 2; }
__host__ __device__ constexpr const auto& Vector() const { return data_.d2_; }
__host__ __device__ constexpr auto& Vector() { return data_.d2_; }
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x2_; }
__host__ __device__ constexpr auto& Scalars() { return data_.d1x2_; }
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x2_; }
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x1_; }
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x2_; }
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x1_; }
};
template <>
struct vector_type<int8_t, 4>
{
using d1_t = int8_t;
typedef int16_t d2_t;
typedef int32_t d4_t;
using type = d4_t;
union
{
d4_t d4_;
StaticallyIndexedArray<d1_t, 4> d1x4_;
StaticallyIndexedArray<d2_t, 2> d2x2_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 4; }
__host__ __device__ constexpr const auto& Vector() const { return data_.d4_; }
__host__ __device__ constexpr auto& Vector() { return data_.d4_; }
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x4_; }
__host__ __device__ constexpr auto& Scalars() { return data_.d1x4_; }
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x4_; }
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x2_; }
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x1_; }
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x4_; }
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x2_; }
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; }
};
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
// fp16
using half_t = _Float16;
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
// bfp16
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
// i8
// hack for int8x4_t, because compiler does not have native support for int8x4_t
// int8x4_t is defined as int32_t
using int8x4_t = typename vector_type<int8_t, 4>::type;
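Because int8x4_t is only an alias for int32_t, individual int8 lanes are reached through the vector_type<int8_t, 4> union rather than through vector component syntax. A small sketch of packing four lanes (pack_int8x4_sketch is illustrative):

    __host__ __device__ inline int8x4_t pack_int8x4_sketch(int8_t x0, int8_t x1, int8_t x2, int8_t x3)
    {
        vector_type<int8_t, 4> v;
        v.Scalars()(Number<0>{}) = x0;
        v.Scalars()(Number<1>{}) = x1;
        v.Scalars()(Number<2>{}) = x2;
        v.Scalars()(Number<3>{}) = x3;
        return v.Vector(); // the packed int32_t carrier
    }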
// data type conversion // data type conversion
template <typename T> template <typename T>
struct type_convert struct type_convert
...@@ -356,113 +306,35 @@ struct inner_product_with_conversion ...@@ -356,113 +306,35 @@ struct inner_product_with_conversion
{ {
static constexpr auto convert = type_convert<T>(); static constexpr auto convert = type_convert<T>();
__device__ T operator()(float4_t a, float4_t b) const template <typename X, index_t N>
{ __device__ T operator()(typename vector_type<X, N>::type a,
const float* p_a_float = reinterpret_cast<const float*>(&a); typename vector_type<X, N>::type b) const
const float* p_b_float = reinterpret_cast<const float*>(&b);
T acc = 0;
for(index_t v = 0; v < 4; ++v)
{
acc += convert(p_a_float[v]) * convert(p_b_float[v]);
}
return acc;
}
__device__ T operator()(float2_t a, float2_t b) const
{ {
const float* p_a_float = reinterpret_cast<const float*>(&a); const vector_type<X, N> a_vector{a};
const float* p_b_float = reinterpret_cast<const float*>(&b); const vector_type<X, N> b_vector{b};
T acc = 0; T acc = 0;
for(index_t v = 0; v < 2; ++v)
{
acc += convert(p_a_float[v]) * convert(p_b_float[v]);
}
return acc; static_for<0, N, 1>{}([&](auto i) {
} acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
});
__device__ T operator()(float a, float b) const { return convert(a) * convert(b); }
__device__ T operator()(half2_t a, half2_t b) const
{
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
T acc = 0;
for(index_t v = 0; v < 2; ++v)
{
acc += convert(p_a_half[v]) * convert(p_b_half[v]);
}
return acc; return acc;
} }
__device__ T operator()(half4_t a, half4_t b) const // hack for int8x4_t, because compiler does not have native support for int8x4_t
// int8x4_t is defined as int32_t
__device__ T operator()(int8x4_t a, int8x4_t b) const
{ {
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a); const vector_type<int8_t, 4> a_vector{a};
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b); const vector_type<int8_t, 4> b_vector{b};
T acc = 0; T acc = 0;
for(index_t v = 0; v < 4; ++v)
{
acc += convert(p_a_half[v]) * convert(p_b_half[v]);
}
return acc;
}
__device__ T operator()(half8_t a, half8_t b) const static_for<0, 4, 1>{}([&](auto i) {
{ acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a); });
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
T acc = 0;
for(index_t v = 0; v < 8; ++v)
{
acc += convert(p_a_half[v]) * convert(p_b_half[v]);
}
return acc;
}
__device__ T operator()(ushort2_t a, ushort2_t b) const
{
const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
T acc = 0;
for(index_t v = 0; v < 2; ++v)
{
acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
}
return acc;
}
__device__ T operator()(ushort4_t a, ushort4_t b) const
{
const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
T acc = 0;
for(index_t v = 0; v < 4; ++v)
{
acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
}
return acc;
}
__device__ T operator()(ushort8_t a, ushort8_t b) const
{
const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
T acc = 0;
for(index_t v = 0; v < 8; ++v)
{
acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
}
return acc; return acc;
} }
}; };
......
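A hedged usage sketch of the conversion-based inner product for the int8 case (dot_int8x4_sketch is illustrative):

    __device__ int32_t dot_int8x4_sketch(int8x4_t a, int8x4_t b)
    {
        // unpacks both operands through vector_type<int8_t, 4> and accumulates in int32_t
        return inner_product_with_conversion<int32_t>{}(a, b);
    }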
...@@ -39,7 +39,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc ...@@ -39,7 +39,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0 #if 1
// run-time variables // run-time variables
const auto in_n_c_hi_wi_desc = const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths())); make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
...@@ -368,6 +368,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc ...@@ -368,6 +368,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
#endif #endif
<BlockSize, <BlockSize,
TDevice,
TDevice, TDevice,
TDevice, TDevice,
GemmMPerBlock, GemmMPerBlock,
......
...@@ -3,14 +3,17 @@ ...@@ -3,14 +3,17 @@
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" #include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
template <class T, template <class TInWei,
class TAcc,
class TOut,
class InDesc, class InDesc,
class WeiDesc, class WeiDesc,
class OutDesc, class OutDesc,
class ConvStrides, class ConvStrides,
class ConvDilations, class ConvDilations,
class InLeftPads, class InLeftPads,
class InRightPads> class InRightPads,
class T>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc, void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc,
const Tensor<T>& in_nchw, const Tensor<T>& in_nchw,
WeiDesc, WeiDesc,
...@@ -28,8 +31,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -28,8 +31,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
using namespace ck; using namespace ck;
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -76,11 +77,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -76,11 +77,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif #endif
Tensor<float> in_nhwc( Tensor<TInWei> in_nhwc(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{}))); make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{})));
Tensor<float> wei_kyxc( Tensor<TInWei> wei_kyxc(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{}))); make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{})));
Tensor<float> out_nhwk( Tensor<TOut> out_nhwk(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{}))); make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{})));
auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) { auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
...@@ -95,15 +96,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -95,15 +96,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo); out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo);
}; };
make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)();
make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)();
make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)();
std::size_t data_sz = sizeof(T);
DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace()); DeviceMem in_nhwc_device_buf(sizeof(TInWei) * in_nhwc.mDesc.GetElementSpace());
DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace()); DeviceMem wei_kyxc_device_buf(sizeof(TInWei) * wei_kyxc.mDesc.GetElementSpace());
DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace()); DeviceMem out_nhwk_device_buf(sizeof(TOut) * out_nhwk.mDesc.GetElementSpace());
in_nhwc_device_buf.ToDevice(in_nhwc.mData.data()); in_nhwc_device_buf.ToDevice(in_nhwc.mData.data());
wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data()); wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data());
...@@ -378,8 +377,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -378,8 +377,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
#endif #endif
<BlockSize, <BlockSize,
TDevice, TInWei,
TDevice, TAcc,
TOut,
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
...@@ -407,9 +407,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -407,9 +407,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
in_right_pads, in_right_pads,
static_cast<TDevice*>(wei_kyxc_device_buf.GetDeviceBuffer()), static_cast<TInWei*>(wei_kyxc_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(in_nhwc_device_buf.GetDeviceBuffer()), static_cast<TInWei*>(in_nhwc_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(out_nhwk_device_buf.GetDeviceBuffer())); static_cast<TOut*>(out_nhwk_device_buf.GetDeviceBuffer()));
out_nhwk_device_buf.FromDevice(out_nhwk.mData.data()); out_nhwk_device_buf.FromDevice(out_nhwk.mData.data());
...@@ -417,5 +417,5 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -417,5 +417,5 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
out_nkhw(n, k, ho, wo) = out_nhwk(n, ho, wo, k); out_nkhw(n, k, ho, wo) = out_nhwk(n, ho, wo, k);
}; };
make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)();
} }
...@@ -158,7 +158,7 @@ struct ParallelTensorFunctor ...@@ -158,7 +158,7 @@ struct ParallelTensorFunctor
return indices; return indices;
} }
void operator()(std::size_t num_thread) const void operator()(std::size_t num_thread = std::thread::hardware_concurrency()) const
{ {
std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread; std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread;
......
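With the defaulted num_thread parameter, the host-side layout conversions above can simply be invoked as functor(); the explicit form is still available. A hedged host-code sketch (the lambda and the extents N, K are illustrative):

    // assuming make_ParallelTensorFunctor from host_tensor.hpp
    const std::size_t N = 4, K = 8;
    auto functor = make_ParallelTensorFunctor([](auto n, auto k) { /* per-index work */ }, N, K);
    functor();                                    // num_thread defaults to hardware_concurrency()
    functor(std::thread::hardware_concurrency()); // equivalent explicit call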
...@@ -49,7 +49,7 @@ int main(int argc, char* argv[]) ...@@ -49,7 +49,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 4;
constexpr index_t HI = 270; constexpr index_t HI = 270;
...@@ -644,14 +644,15 @@ int main(int argc, char* argv[]) ...@@ -644,14 +644,15 @@ int main(int argc, char* argv[])
std::size_t num_thread = std::thread::hardware_concurrency(); std::size_t num_thread = std::thread::hardware_concurrency();
if(argc != 3) if(argc != 4)
{ {
printf("arg1: do_verification, arg2: nrepeat\n"); printf("arg1: do_verification, arg2: do_log, arg3: nrepeat\n");
exit(1); exit(1);
} }
bool do_verification = atoi(argv[1]); bool do_verification = atoi(argv[1]);
index_t nrepeat = atoi(argv[2]); bool do_log = atoi(argv[2]);
index_t nrepeat = atoi(argv[3]);
if(do_verification) if(do_verification)
{ {
...@@ -660,13 +661,16 @@ int main(int argc, char* argv[]) ...@@ -660,13 +661,16 @@ int main(int argc, char* argv[])
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 0 #elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_3{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 0 #elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1 #elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 1
in_nchw.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread);
#elif 0 #elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
...@@ -726,7 +730,14 @@ int main(int argc, char* argv[]) ...@@ -726,7 +730,14 @@ int main(int argc, char* argv[])
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(in_nchw_desc, #if 0
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<float, float, float>(
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<int8x4_t, int32_t, int32_t>(
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<int8x4_t, int32_t, int8_t>(
#endif
in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
wei_kcyx, wei_kcyx,
...@@ -761,11 +772,12 @@ int main(int argc, char* argv[]) ...@@ -761,11 +772,12 @@ int main(int argc, char* argv[])
} }
check_error(out_nkhw_host, out_nkhw_device); check_error(out_nkhw_host, out_nkhw_device);
#if 0 if(do_log)
{
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl; LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl; LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl; LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl; LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
#endif }
} }
} }
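With the extra do_log argument, the driver now expects three arguments: running it as "driver 1 0 5" (the binary name here is hypothetical) enables verification, disables result logging, and times 5 kernel repeats, matching the printf usage string above.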