Commit 59968d8d authored by Chao Liu

adding int8x4

parent 4f31669f
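// Editor's minimal sketch (not part of this commit): the packed int8x4 dot product the
// commit targets. Four int8 values are packed into one int32_t, multiplied lane-wise,
// and accumulated into an int32 accumulator -- the operation gfx906+ exposes as
// v_dot4_i32_i8 / __builtin_amdgcn_sdot4, used further down in this diff.
#include <cstdint>

inline int32_t dot4_i32_i8_reference(int32_t a_packed, int32_t b_packed, int32_t c)
{
    for(int i = 0; i < 4; ++i)
    {
        const auto a = static_cast<int8_t>((a_packed >> (8 * i)) & 0xff);
        const auto b = static_cast<int8_t>((b_packed >> (8 * i)) & 0xff);
        c += static_cast<int32_t>(a) * static_cast<int32_t>(b);
    }
    return c;
}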
......@@ -12,8 +12,9 @@
namespace ck {
template <index_t BlockSize,
typename Float,
typename AccFloat,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
......@@ -52,7 +53,7 @@ template <index_t BlockSize,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_mn_v1
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
......@@ -78,17 +79,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
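// Editor's note (illustration, not in the diff): the old single Float/AccFloat pair is
// split so the A/B input, accumulator, and C output types may differ. The int8 path
// enabled later in this commit instantiates the gridwise GEMM roughly as:
//   FloatAB  = int8x4_t   // 4 x int8 packed along K
//   FloatAcc = int32_t    // v_dot4_i32_i8 accumulates into int32
//   FloatC   = int32_t    // C written back as int32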
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const Float* __restrict__ p_a_global,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc,
const Float* __restrict__ p_b_global,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block,
FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
......@@ -144,8 +145,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
Float,
Float,
FloatAB,
FloatAB,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
ABlockTransferSrcAccessOrder,
......@@ -173,8 +174,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
Float,
Float,
FloatAB,
FloatAB,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
BBlockTransferSrcAccessOrder,
......@@ -235,11 +236,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
// zero out threadwise output
threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
......@@ -269,11 +270,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
if constexpr(HasMainKBlockLoop)
{
Float* p_a_block_even = p_a_block_double;
Float* p_b_block_even = p_b_block_double;
FloatAB* p_a_block_even = p_a_block_double;
FloatAB* p_b_block_even = p_b_block_double;
Float* p_a_block_odd = p_a_block_double + a_block_space_size;
Float* p_b_block_odd = p_b_block_double + b_block_space_size;
FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
index_t k_block_data_begin = 0;
......@@ -400,8 +401,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
ThreadwiseDynamicTensorSliceTransfer_v1r3<
AccFloat,
Float,
FloatAcc,
FloatC,
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
......@@ -429,17 +430,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// pass tensor descriptors by reference
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const Float* __restrict__ p_a_global,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc,
const Float* __restrict__ p_b_global,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ Float p_shared_block[shared_block_size];
__shared__ FloatAB p_shared_block[shared_block_size];
Run(a_k_m_global_desc,
p_a_global,
......@@ -452,14 +453,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
// pass tensor descriptors by their pointers
// pass tensor descriptors by pointers
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_k_m_global_desc,
const Float* __restrict__ p_a_global,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_k_n_global_desc,
const Float* __restrict__ p_b_global,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc* p_c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
......@@ -480,11 +481,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// pass tensor descriptors by void*
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_k_m_global_desc,
const Float* __restrict__ p_a_global,
const FloatAB* __restrict__ p_a_global,
const void* p_b_k_n_global_desc,
const Float* __restrict__ p_b_global,
const FloatAB* __restrict__ p_b_global,
const void* p_c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
......
......@@ -182,7 +182,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
i * dst_scalar_step_in_vector);
dst_vector.Scalars()(i) = p_src[Number<src_offset>{}];
dst_vector.Scalars()(i) = type_convert<DstData>{}(p_src[Number<src_offset>{}]);
});
const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
......
......@@ -75,9 +75,9 @@ struct ThreadwiseGemm_km_kn_mn_v1
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto M = CDesc{}[I0];
constexpr auto N = CDesc{}[I1];
constexpr auto K = ADesc{}[I0];
constexpr auto M = CDesc{}.GetLength(I0);
constexpr auto N = CDesc{}.GetLength(I1);
constexpr auto K = ADesc{}.GetLength(I0);
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) {
......@@ -161,19 +161,7 @@ struct ThreadwiseGemm_km_kn_mn_v1
__device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
constexpr bool has_amd_asm = is_same<FloatC, float>{} &&
((is_same<FloatA, float>{} && is_same<FloatB, float>{}) ||
(is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
(is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));
if constexpr(has_amd_asm)
{
Run_amd_asm(p_a, p_b, p_c);
}
else
{
Run_source(p_a, p_b, p_c);
}
#else
Run_source(p_a, p_b, p_c);
#endif
......
......@@ -31,6 +31,7 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_siz
return wave_buffer_resource.data;
}
// fp32 load
__device__ float
__llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset,
......@@ -49,6 +50,7 @@ __llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
// fp32 store
__device__ void
__llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc,
......@@ -70,213 +72,247 @@ __llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
// buffer_load requires:
// 1) p_src_wave must be in global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
__device__ typename vector_type<T, VectorSize>::type
amd_buffer_load_v2(const T* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_elemenst_space);
// buffer_store requires:
// 1) p_dst_wave must be in global memory space
// 2) p_dst_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
__device__ void amd_buffer_store_v2(const typename vector_type<T, VectorSize>::type src_thread_data,
T* p_dst_wave,
const index_t dst_thread_data_offset,
const bool dst_thread_data_valid,
const index_t dst_data_range);
template <>
__device__ float amd_buffer_load_v2<float, 1>(const float* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_data_range);
// i32 load
__device__ int32_t
__llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
__device__ int32x2_t
__llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
__device__ int32x4_t
__llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
return __llvm_amdgcn_raw_buffer_load_fp32(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0);
#else
float tmp =
__llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource, src_thread_addr_offset, 0, 0);
// i32 store
__device__ void
__llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
return src_thread_data_valid ? tmp : float(0);
#endif
}
// i16 store
__device__ void
__llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
template <>
__device__ float2_t amd_buffer_load_v2<float, 2>(const float* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
template <typename T, index_t N>
__device__ typename vector_type<T, N>::type
amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_data_range);
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
"wrong! not implemented");
if constexpr(is_same<T, float>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_fp32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
return __llvm_amdgcn_raw_buffer_load_fp32x2(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0);
#else
float2_t tmp = __llvm_amdgcn_raw_buffer_load_fp32x2(
src_wave_buffer_resource, src_thread_addr_offset, 0, 0);
return src_thread_data_valid ? tmp : float2_t(0);
#endif
}
template <>
__device__ float4_t amd_buffer_load_v2<float, 4>(const float* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_data_range);
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 4)
{
return __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0);
#else
float4_t tmp = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, 0, 0);
return src_thread_data_valid ? tmp : float4_t(0);
#endif
}
template <>
__device__ float8_t amd_buffer_load_v2<float, 8>(const float* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_data_range);
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 8)
{
vector_type<float, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0);
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource,
src_addr_shift + src_thread_addr_offset + 4 * sizeof(float),
0,
0);
src_wave_buffer_resource, src_thread_addr_offset, 4 * sizeof(float), 0);
return tmp.Vector();
#else
vector_type<float, 8> tmp;
}
}
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_i32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
return __llvm_amdgcn_raw_buffer_load_i32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 4)
{
return __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 8)
{
vector_type<int32_t, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, 0, 0);
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset + 4 * sizeof(float), 0, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, 4 * sizeof(int32_t), 0);
return src_thread_data_valid ? tmp.Vector() : float8_t(0);
#endif
return tmp.Vector();
}
}
}
template <>
__device__ void amd_buffer_store_v2<float, 1>(const float src_thread_data,
float* p_dst_wave,
const index_t dst_thread_data_offset,
const bool dst_thread_data_valid,
const index_t dst_data_range)
template <typename T, index_t N>
__device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_data_range);
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 2)),
"wrong! not implemented");
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_raw_buffer_store_fp32(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0, 0);
#else
if(dst_thread_data_valid)
if constexpr(is_same<T, float>::value)
{
__llvm_amdgcn_buffer_store_fp32(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0, 0);
if constexpr(N == 1)
{
__llvm_amdgcn_raw_buffer_store_fp32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
__llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{
__llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
__llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
else if constexpr(is_same<T, int8_t>::value)
{
if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
#endif
}
template <>
__device__ void amd_buffer_store_v2<float, 2>(const float2_t src_thread_data,
float* p_dst_wave,
const index_t dst_thread_data_offset,
const bool dst_thread_data_valid,
const index_t dst_data_range)
// buffer_load requires:
// 1) p_src_wave must be in global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ typename vector_type<T, N>::type amd_buffer_load_v2(const T* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_element_space)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_data_range);
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space);
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_raw_buffer_store_fp32x2(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0, 0);
return amd_buffer_load_impl_v2<T, N>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_raw_buffer_store_fp32x2(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0, 0);
}
using vector_t = typename vector_type<T, N>::type;
vector_t tmp =
amd_buffer_load_impl_v2<T, N>(src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_data_valid ? tmp : vector_t(0);
#endif
}
template <>
__device__ void amd_buffer_store_v2<float, 4>(const float4_t src_thread_data,
float* p_dst_wave,
// buffer_store requires:
// 1) p_dst_wave must be in global memory space
// 2) p_dst_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ void amd_buffer_store_v2(const typename vector_type<T, N>::type src_thread_data,
T* p_dst_wave,
const index_t dst_thread_data_offset,
const bool dst_thread_data_valid,
const index_t dst_data_range)
const index_t dst_element_space)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_data_range);
make_wave_buffer_resource(p_dst_wave, dst_element_space);
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_raw_buffer_store_fp32x4(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0, 0);
amd_buffer_store_impl_v2<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_raw_buffer_store_fp32x4(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0, 0);
amd_buffer_store_impl_v2<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
}
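// Editor's usage sketch (assumed helper, not in the diff): move N elements per thread
// through the generic wrappers above. With the OOB-check-offset trick enabled, an
// invalid thread shifts its address by 0x7fffffff so the buffer access falls outside
// the descriptor's data range and the hardware returns 0 (load) or drops the access
// (store), avoiding an explicit branch.
template <typename T, index_t N>
__device__ void copy_through_buffer(const T* p_src_wave,
                                    T* p_dst_wave,
                                    index_t thread_data_offset,
                                    bool valid,
                                    index_t element_space)
{
    auto v = amd_buffer_load_v2<T, N>(p_src_wave, thread_data_offset, valid, element_space);
    amd_buffer_store_v2<T, N>(v, p_dst_wave, thread_data_offset, valid, element_space);
}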
......
......@@ -5,7 +5,8 @@
namespace ck {
// outer-product: c[i,j] += inner_product(a[i], b[j])
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
{
#if CK_USE_AMD_V_FMAC_F32
......@@ -25,7 +26,10 @@ __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, floa
#endif
}
// outer-product: c[i,j] += inner_product(a[i], b[j])
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(
float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
{
......@@ -50,7 +54,8 @@ __device__ void amd_assembly_outer_product_1x4(
#endif
}
// outer-product: c[i,j] += inner_product(a[i], b[j])
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
{
......@@ -58,15 +63,12 @@ amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, flo
v_dot2_f32_f16 %0, %2, %3, %0\n \
v_dot2_f32_f16 %1, %2, %4, %1\n \
"
: "=v"(c0), "=v"(c1) // Dest registers
: "v"(a), // 1st Src register for 1 half2 registers
"v"(b0), // 2nd Src register
"v"(b1),
"0"(c0), // 3rd Src register
"1"(c1));
: "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
}
// outer-product: c[i,j] += inner_product(a[i], b[j])
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{
......@@ -81,18 +83,21 @@ amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, flo
v_dot2_f32_f16 %0, %3, %5, %0\n \
v_dot2_f32_f16 %1, %3, %7, %1\n \
"
: "=v"(c0), "=v"(c1) // Dest registers
: "=v"(c0), "=v"(c1)
: "v"(p_a_half2[0]),
"v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers
"v"(p_a_half2[1]),
"v"(p_b0_half2[0]),
"v"(p_b0_half2[1]),
"v"(p_b1_half2[0]),
"v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers
"v"(p_b1_half2[1]),
"0"(c0),
"1"(c1)); // 3rd Src Acc registers for 2 half2 registers
"1"(c1));
}
// outer-product: c[i,j] += inner_product(a[i], b[j])
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half2_t a,
half2_t b0,
half2_t b1,
......@@ -109,19 +114,14 @@ __device__ void amd_assembly_outer_product_1x4(half2_t a,
v_dot2_f32_f16 %2, %4, %7, %2\n \
v_dot2_f32_f16 %3, %4, %8, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers
: "v"(a), // 1st Src register for 1 half2 registers
"v"(b0), // 2nd Src register
"v"(b1),
"v"(b2),
"v"(b3),
"0"(c0), // 3rd Src register
"1"(c1),
"2"(c2),
"3"(c3));
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
}
// outer-product: c[i,j] += inner_product(a[i], b[j])
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half4_t a,
half4_t b0,
half4_t b1,
......@@ -149,21 +149,70 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
v_dot2_f32_f16 %2, %5, %11, %2\n \
v_dot2_f32_f16 %3, %5, %13, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(p_a_half2[0]),
"v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers
"v"(p_a_half2[1]),
"v"(p_b0_half2[0]),
"v"(p_b0_half2[1]),
"v"(p_b1_half2[0]),
"v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers
"v"(p_b1_half2[1]),
"v"(p_b2_half2[0]),
"v"(p_b2_half2[1]),
"v"(p_b3_half2[0]),
"v"(p_b3_half2[1]), // 2nd Src registers for 2 half2 registers
"v"(p_b3_half2[1]),
"0"(c0),
"1"(c1),
"2"(c2),
"3"(c3)); // 3rd Src Acc registers for 2 half2 registers
"3"(c3));
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
{
#if 1
asm volatile("\n \
v_dot4_i32_i8 %0, %2, %3, %0\n \
v_dot4_i32_i8 %1, %2, %4, %1\n \
"
: "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#else
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
#endif
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(int8x4_t a,
int8x4_t b0,
int8x4_t b1,
int8x4_t b2,
int8x4_t b3,
int32_t& c0,
int32_t& c1,
int32_t& c2,
int32_t& c3)
{
#if 1
asm volatile("\n \
v_dot4_i32_i8 %0, %4, %5, %0\n \
v_dot4_i32_i8 %1, %4, %6, %1\n \
v_dot4_i32_i8 %2, %4, %7, %2\n \
v_dot4_i32_i8 %3, %4, %8, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
#else
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
c2 = __builtin_amdgcn_sdot4(a, b2, c2, false);
c3 = __builtin_amdgcn_sdot4(a, b3, c3, false);
#endif
}
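// Editor's sketch (illustrative loop, not in the diff): a threadwise int8 GEMM fragment
// built on the 1x4 outer product above. Each int8x4_t carries 4 K-steps packed into one
// int32 register, so one call accumulates 4 K-steps of a 1x4 output tile into the four
// int32 accumulators.
__device__ void threadwise_gemm_1x4_int8_sketch(const int8x4_t* p_a, // [k_pack] packed A
                                                const int8x4_t* p_b, // [k_pack][4] packed B
                                                int32_t* p_c,        // [4] accumulators
                                                index_t k_pack)
{
    for(index_t k = 0; k < k_pack; ++k)
    {
        amd_assembly_outer_product_1x4(p_a[k],
                                       p_b[4 * k + 0],
                                       p_b[4 * k + 1],
                                       p_b[4 * k + 2],
                                       p_b[4 * k + 3],
                                       p_c[0],
                                       p_c[1],
                                       p_c[2],
                                       p_c[3]);
    }
}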
} // namespace ck
......
......@@ -7,7 +7,7 @@
#endif
#include "bfloat16_dev.hpp"
#if 0
#if 1
#define CK_AMD_GPU_GFX906 1
#elif 0
#define CK_AMD_GPU_GFX908 1
......@@ -37,7 +37,7 @@
#endif
#ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 0
#endif
#ifndef CK_USE_AMD_V_FMAC_F32
......@@ -140,10 +140,5 @@ enum InMemoryDataOperation
// index type
using index_t = int32_t;
typedef int32_t int32x2_t __attribute__((ext_vector_type(2)));
// int32x4_t use by buffer_load and buffer_store llvm intrinsic
typedef int32_t int32x4_t __attribute__((ext_vector_type(4)));
} // namespace ck
#endif
......@@ -3,172 +3,6 @@
namespace ck {
// For some reason, the HIP compiler needs this definition to generate optimal ISA
// fp32
typedef float float2_t __attribute__((ext_vector_type(2)));
typedef float float4_t __attribute__((ext_vector_type(4)));
typedef float float8_t __attribute__((ext_vector_type(8)));
typedef float float16_t __attribute__((ext_vector_type(16)));
typedef float float32_t __attribute__((ext_vector_type(32)));
// fp16
typedef _Float16 half_t;
typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
typedef _Float16 half4_t __attribute__((ext_vector_type(4)));
typedef _Float16 half8_t __attribute__((ext_vector_type(8)));
// bfp16
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
struct c_vec32_4_t
{
union VecType
{
struct
{
float32_t x;
float32_t y;
float32_t z;
float32_t w;
} s;
float n[128];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
c.s.y = 0;
c.s.z = 0;
c.s.w = 0;
return c;
}
};
struct c_vec32_2_t
{
union VecType
{
struct
{
float32_t x;
float32_t y;
} s;
float n[64];
} l;
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
c.s.y = 0;
return c;
}
};
struct c_vec32_2_2_t
{
union VecType
{
struct
{
c_vec32_2_t x;
c_vec32_2_t y;
} s;
float n[128];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x.l.s.x = 0;
c.s.x.l.s.y = 0;
c.s.y.l.s.x = 0;
c.s.y.l.s.y = 0;
return c;
}
};
struct c_vec32_1_t
{
union VecType
{
struct
{
float32_t x;
} s;
float n[32];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
return c;
}
};
struct c_vec16_1_t
{
union VecType
{
struct
{
float16_t x;
} s;
float n[16];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
return c;
}
};
struct c_vec4_2_t
{
union VecType
{
struct
{
float4_t x;
float4_t y;
} s;
float n[8];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
c.s.y = 0;
return c;
}
};
struct c_vec4_1_t
{
union VecType
{
struct
{
float4_t x;
} s;
float n[4];
};
__host__ __device__ static VecType CreateVecZero()
{
VecType c;
c.s.x = 0;
return c;
}
};
template <typename T, index_t N>
struct vector_type;
......@@ -183,7 +17,9 @@ struct vector_type<T, 1>
StaticallyIndexedArray<T, 1> d1x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{T{0}} {}
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 1; }
......@@ -215,7 +51,9 @@ struct vector_type<T, 2>
StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{d2_t{0}} {}
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 2; }
......@@ -253,7 +91,9 @@ struct vector_type<T, 4>
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{d4_t{0}} {}
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 4; }
......@@ -297,7 +137,9 @@ struct vector_type<T, 8>
StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{d8_t{0}} {}
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 8; }
......@@ -326,6 +168,114 @@ struct vector_type<T, 8>
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
};
template <>
struct vector_type<int8_t, 2>
{
using d1_t = int8_t;
typedef int16_t d2_t;
using type = d2_t;
union
{
d2_t d2_;
StaticallyIndexedArray<d1_t, 2> d1x2_;
StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 2; }
__host__ __device__ constexpr const auto& Vector() const { return data_.d2_; }
__host__ __device__ constexpr auto& Vector() { return data_.d2_; }
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x2_; }
__host__ __device__ constexpr auto& Scalars() { return data_.d1x2_; }
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x2_; }
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x1_; }
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x2_; }
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x1_; }
};
template <>
struct vector_type<int8_t, 4>
{
using d1_t = int8_t;
typedef int16_t d2_t;
typedef int32_t d4_t;
using type = d4_t;
union
{
d4_t d4_;
StaticallyIndexedArray<d1_t, 4> d1x4_;
StaticallyIndexedArray<d2_t, 2> d2x2_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 4; }
__host__ __device__ constexpr const auto& Vector() const { return data_.d4_; }
__host__ __device__ constexpr auto& Vector() { return data_.d4_; }
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x4_; }
__host__ __device__ constexpr auto& Scalars() { return data_.d1x4_; }
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x4_; }
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x2_; }
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x1_; }
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x4_; }
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x2_; }
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; }
};
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
// fp16
using half_t = _Float16;
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
// bfp16
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
// i8
// hack for int8x4_t, because the compiler does not have native support for int8x4_t;
// int8x4_t is defined as int32_t
using int8x4_t = typename vector_type<int8_t, 4>::type;
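// Editor's illustration of the hack above (not in the diff): the "vector" really is a
// plain int32_t, and per-byte access goes through vector_type<int8_t, 4>.
__device__ int8x4_t pack_int8x4(int8_t x0, int8_t x1, int8_t x2, int8_t x3)
{
    vector_type<int8_t, 4> v;
    v.Scalars()(Number<0>{}) = x0;
    v.Scalars()(Number<1>{}) = x1;
    v.Scalars()(Number<2>{}) = x2;
    v.Scalars()(Number<3>{}) = x3;
    return v.Vector(); // the packed int32_t that the dot4 path consumes
}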
// data type conversion
template <typename T>
struct type_convert
......@@ -356,113 +306,35 @@ struct inner_product_with_conversion
{
static constexpr auto convert = type_convert<T>();
__device__ T operator()(float4_t a, float4_t b) const
{
const float* p_a_float = reinterpret_cast<const float*>(&a);
const float* p_b_float = reinterpret_cast<const float*>(&b);
T acc = 0;
for(index_t v = 0; v < 4; ++v)
{
acc += convert(p_a_float[v]) * convert(p_b_float[v]);
}
return acc;
}
__device__ T operator()(float2_t a, float2_t b) const
template <typename X, index_t N>
__device__ T operator()(typename vector_type<X, N>::type a,
typename vector_type<X, N>::type b) const
{
const float* p_a_float = reinterpret_cast<const float*>(&a);
const float* p_b_float = reinterpret_cast<const float*>(&b);
const vector_type<X, N> a_vector{a};
const vector_type<X, N> b_vector{b};
T acc = 0;
for(index_t v = 0; v < 2; ++v)
{
acc += convert(p_a_float[v]) * convert(p_b_float[v]);
}
return acc;
}
__device__ T operator()(float a, float b) const { return convert(a) * convert(b); }
__device__ T operator()(half2_t a, half2_t b) const
{
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
T acc = 0;
for(index_t v = 0; v < 2; ++v)
{
acc += convert(p_a_half[v]) * convert(p_b_half[v]);
}
static_for<0, N, 1>{}([&](auto i) {
acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
});
return acc;
}
__device__ T operator()(half4_t a, half4_t b) const
// hack for int8x4_t, because the compiler does not have native support for int8x4_t;
// int8x4_t is defined as int32_t
__device__ T operator()(int8x4_t a, int8x4_t b) const
{
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
const vector_type<int8_t, 4> a_vector{a};
const vector_type<int8_t, 4> b_vector{b};
T acc = 0;
for(index_t v = 0; v < 4; ++v)
{
acc += convert(p_a_half[v]) * convert(p_b_half[v]);
}
return acc;
}
__device__ T operator()(half8_t a, half8_t b) const
{
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
static_for<0, 4, 1>{}([&](auto i) {
acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
});
T acc = 0;
for(index_t v = 0; v < 8; ++v)
{
acc += convert(p_a_half[v]) * convert(p_b_half[v]);
}
return acc;
}
__device__ T operator()(ushort2_t a, ushort2_t b) const
{
const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
T acc = 0;
for(index_t v = 0; v < 2; ++v)
{
acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
}
return acc;
}
__device__ T operator()(ushort4_t a, ushort4_t b) const
{
const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
T acc = 0;
for(index_t v = 0; v < 4; ++v)
{
acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
}
return acc;
}
__device__ T operator()(ushort8_t a, ushort8_t b) const
{
const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
T acc = 0;
for(index_t v = 0; v < 8; ++v)
{
acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
}
return acc;
}
};
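// Editor's usage sketch (not in the diff): the non-dot4 fallback path, converting each
// int8 lane and accumulating in int32 via the int8x4_t overload above.
__device__ int32_t inner_product_int8x4_fallback(int8x4_t a, int8x4_t b)
{
    return inner_product_with_conversion<int32_t>{}(a, b);
}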
......
......@@ -39,7 +39,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0
#if 1
// run-time variables
const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
......@@ -368,6 +368,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
#endif
<BlockSize,
TDevice,
TDevice,
TDevice,
GemmMPerBlock,
......
......@@ -3,14 +3,17 @@
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
template <class T,
template <class TInWei,
class TAcc,
class TOut,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations,
class InLeftPads,
class InRightPads>
class InRightPads,
class T>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
......@@ -28,8 +31,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
using namespace ck;
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
......@@ -76,11 +77,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif
Tensor<float> in_nhwc(
Tensor<TInWei> in_nhwc(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{})));
Tensor<float> wei_kyxc(
Tensor<TInWei> wei_kyxc(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{})));
Tensor<float> out_nhwk(
Tensor<TOut> out_nhwk(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{})));
auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
......@@ -95,15 +96,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo);
};
make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency());
std::size_t data_sz = sizeof(T);
make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)();
make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)();
make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)();
DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace());
DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace());
DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace());
DeviceMem in_nhwc_device_buf(sizeof(TInWei) * in_nhwc.mDesc.GetElementSpace());
DeviceMem wei_kyxc_device_buf(sizeof(TInWei) * wei_kyxc.mDesc.GetElementSpace());
DeviceMem out_nhwk_device_buf(sizeof(TOut) * out_nhwk.mDesc.GetElementSpace());
in_nhwc_device_buf.ToDevice(in_nhwc.mData.data());
wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data());
......@@ -378,8 +377,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
#endif
<BlockSize,
TDevice,
TDevice,
TInWei,
TAcc,
TOut,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
......@@ -407,9 +407,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
conv_dilations,
in_left_pads,
in_right_pads,
static_cast<TDevice*>(wei_kyxc_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(in_nhwc_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(out_nhwk_device_buf.GetDeviceBuffer()));
static_cast<TInWei*>(wei_kyxc_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_nhwc_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_nhwk_device_buf.GetDeviceBuffer()));
out_nhwk_device_buf.FromDevice(out_nhwk.mData.data());
......@@ -417,5 +417,5 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
out_nkhw(n, k, ho, wo) = out_nhwk(n, ho, wo, k);
};
make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)();
}
......@@ -158,7 +158,7 @@ struct ParallelTensorFunctor
return indices;
}
void operator()(std::size_t num_thread) const
void operator()(std::size_t num_thread = std::thread::hardware_concurrency()) const
{
std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread;
......
......@@ -49,7 +49,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
#elif 1
constexpr index_t N = 1;
constexpr index_t C = 4;
constexpr index_t HI = 270;
......@@ -644,14 +644,15 @@ int main(int argc, char* argv[])
std::size_t num_thread = std::thread::hardware_concurrency();
if(argc != 3)
if(argc != 4)
{
printf("arg1: do_verification, arg2: nrepeat\n");
printf("arg1: do_verification, arg2: do_log, arg3: nrepeat\n");
exit(1);
}
bool do_verification = atoi(argv[1]);
index_t nrepeat = atoi(argv[2]);
bool do_log = atoi(argv[2]);
index_t nrepeat = atoi(argv[3]);
if(do_verification)
{
......@@ -660,13 +661,16 @@ int main(int argc, char* argv[])
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 1
in_nchw.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread);
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
......@@ -726,7 +730,14 @@ int main(int argc, char* argv[])
RightPads{},
nrepeat);
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(in_nchw_desc,
#if 0
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<float, float, float>(
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<int8x4_t, int32_t, int32_t>(
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<int8x4_t, int32_t, int8_t>(
#endif
in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
......@@ -761,11 +772,12 @@ int main(int argc, char* argv[])
}
check_error(out_nkhw_host, out_nkhw_device);
#if 0
if(do_log)
{
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
#endif
}
}
}