Unverified Commit 64350aff authored by Chao Liu, committed by GitHub

Use __builtin_memcpy to implement bit_cast and for accessing vector from pointer of scalars (#53)

* reworking vector_type

* use __builtin_memcpy for bit_cast and vector access of scalar pointer

* clean up
parent 970fa3e9
@@ -268,14 +268,14 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-return as_type<double>(tmp);
+return bit_cast<double>(tmp);
}
else if constexpr(N == 2)
{
const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-return as_type<double2_t>(tmp);
+return bit_cast<double2_t>(tmp);
}
else if constexpr(N == 4)
{
@@ -289,8 +289,8 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
0);
vector_type<double, 4> tmp;
-tmp.AsType<double2_t>()(Number<0>{}) = as_type<double2_t>(f32_0);
-tmp.AsType<double2_t>()(Number<1>{}) = as_type<double2_t>(f32_1);
+tmp.AsType<double2_t>()(Number<0>{}) = bit_cast<double2_t>(f32_0);
+tmp.AsType<double2_t>()(Number<1>{}) = bit_cast<double2_t>(f32_1);
return tmp.AsType<double4_t>()(Number<0>{});
}
@@ -351,7 +351,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-return as_type<half8_t>(tmp);
+return bit_cast<half8_t>(tmp);
}
}
else if constexpr(is_same<T, ushort>::value)
@@ -376,7 +376,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-return as_type<ushort8_t>(tmp);
+return bit_cast<ushort8_t>(tmp);
}
}
else if constexpr(is_same<T, int32_t>::value)
@@ -427,7 +427,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-return as_type<int8x2_t>(tmp);
+return bit_cast<int8x2_t>(tmp);
#endif
}
else if constexpr(N == 4)
@@ -439,7 +439,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-return as_type<int8x4_t>(tmp);
+return bit_cast<int8x4_t>(tmp);
#endif
}
else if constexpr(N == 8)
@@ -461,7 +461,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-return as_type<int8x8_t>(tmp);
+return bit_cast<int8x8_t>(tmp);
#endif
}
else if constexpr(N == 16)
@@ -495,7 +495,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-return as_type<int8x16_t>(tmp);
+return bit_cast<int8x16_t>(tmp);
#endif
}
}
@@ -521,7 +521,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
// use fp32 store to mimic fp64 store
if constexpr(N == 1)
{
-llvm_amdgcn_raw_buffer_store_fp32x2(as_type<float2_t>(src_thread_data),
+llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast<float2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -529,7 +529,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
}
else if constexpr(N == 2)
{
-llvm_amdgcn_raw_buffer_store_fp32x4(as_type<float4_t>(src_thread_data),
+llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -606,7 +606,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_addr_offset + 4 * sizeof(half_t),
0);
#else
-llvm_amdgcn_raw_buffer_store_fp32x4(as_type<float4_t>(src_thread_data),
+llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -703,7 +703,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_addr_offset,
0);
#else
-llvm_amdgcn_raw_buffer_store_i16(as_type<int16_t>(src_thread_data),
+llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -719,7 +719,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_addr_offset,
0);
#else
-llvm_amdgcn_raw_buffer_store_i32(as_type<int32_t>(src_thread_data),
+llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -728,7 +728,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
}
else if constexpr(N == 8)
{
-llvm_amdgcn_raw_buffer_store_i32x2(as_type<int32x2_t>(src_thread_data),
+llvm_amdgcn_raw_buffer_store_i32x2(bit_cast<int32x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -736,7 +736,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
}
else if constexpr(N == 16)
{
-llvm_amdgcn_raw_buffer_store_i32x4(as_type<int32x4_t>(src_thread_data),
+llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
......
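For context on the "fp32 store to mimic fp64 store" path above: there is no fp64 raw-buffer-store intrinsic used here, so the 8 bytes of a double are handed to the fp32x2 store after a pure bit reinterpretation, which is exactly what bit_cast<float2_t> does. A stand-alone sketch of that reinterpretation (the ext_vector_type typedef is an assumption matching CK's usual vector definitions, for illustration only):

    // Hypothetical illustration of bit_cast<float2_t>(double):
    // only the bit pattern moves, no numeric conversion happens.
    typedef float float2_sketch __attribute__((ext_vector_type(2)));

    __device__ inline float2_sketch double_as_float2(double d)
    {
        float2_sketch v;
        __builtin_memcpy(&v, &d, sizeof(double)); // 8 bytes of the double become two fp32 lanes
        return v;
    }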
@@ -211,14 +211,14 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0
v_dot4_i32_i8 %1, %2, %4, %1\n \
"
: "=v"(c0), "=v"(c1)
-: "v"(as_type<int32_t>(a)),
-"v"(as_type<int32_t>(b0)),
-"v"(as_type<int32_t>(b1)),
+: "v"(bit_cast<int32_t>(a)),
+"v"(bit_cast<int32_t>(b0)),
+"v"(bit_cast<int32_t>(b1)),
"0"(c0),
"1"(c1));
#else
-c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
-c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
+c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
+c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
#endif
}
@@ -244,20 +244,20 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
v_dot4_i32_i8 %3, %4, %8, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
-: "v"(as_type<int32_t>(a)),
-"v"(as_type<int32_t>(b0)),
-"v"(as_type<int32_t>(b1)),
-"v"(as_type<int32_t>(b2)),
-"v"(as_type<int32_t>(b3)),
+: "v"(bit_cast<int32_t>(a)),
+"v"(bit_cast<int32_t>(b0)),
+"v"(bit_cast<int32_t>(b1)),
+"v"(bit_cast<int32_t>(b2)),
+"v"(bit_cast<int32_t>(b3)),
"0"(c0),
"1"(c1),
"2"(c2),
"3"(c3));
#else
-c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
-c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
-c2 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b2), c2, false);
-c3 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b3), c3, false);
+c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
+c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
+c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
+c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
#endif
}
......
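The v_dot4_i32_i8 / __builtin_amdgcn_sdot4 calls above accumulate a dot product of four packed int8 lanes into a 32-bit accumulator; bit_cast<int32_t> only reinterprets the packed int8x4_t as the int32_t operand the intrinsic expects. A scalar reference of what one such call computes (a sketch for illustration, not CK code):

    #include <cstdint>
    #include <cstring>

    // Reference semantics of c += dot(a, b) over four packed int8 lanes,
    // i.e. what a single sdot4 / v_dot4_i32_i8 call (clamp = false) accumulates.
    int32_t dot4_i8_reference(int32_t a_packed, int32_t b_packed, int32_t c)
    {
        int8_t a[4], b[4];
        std::memcpy(a, &a_packed, 4); // unpack the four int8 lanes
        std::memcpy(b, &b_packed, 4);

        for(int i = 0; i < 4; ++i)
        {
            c += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
        }
        return c;
    }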
@@ -340,8 +340,8 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x16_t>()(Number<0>{}) =
-llvm_intrin_amdgcn_mfma_i32_32x32x8i8(as_type<int>(reg_a),
-as_type<int>(reg_b),
+llvm_intrin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int>(reg_a),
+bit_cast<int>(reg_b),
reg_c.template AsType<int32x16_t>()[Number<0>{}],
0,
0,
@@ -359,8 +359,8 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
-llvm_intrin_amdgcn_mfma_i32_16x16x16i8(as_type<int>(reg_a),
-as_type<int>(reg_b),
+llvm_intrin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int>(reg_a),
+bit_cast<int>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}],
0,
0,
......
...@@ -99,7 +99,19 @@ ...@@ -99,7 +99,19 @@
#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0 #define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0
// merge transformation use magic number division // merge transformation use magic number division
#ifndef CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1 #define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1
#endif
// use __builtin_memcpy instead of pointer cast to access a vector from pointer of scalar
#ifndef CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0
#endif
// use __builtin_memcpy instead of union to do bit_cast
#ifndef CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1
#endif
// hack: have underlying assumption that need to be satsified, otherwise it's a bug // hack: have underlying assumption that need to be satsified, otherwise it's a bug
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
...@@ -119,7 +131,7 @@ ...@@ -119,7 +131,7 @@
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1 #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1
#endif #endif
// workaround for compiler crash when using buffer load/store for i8 // workaround for compiler gnerating inefficient ds_write instructions
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE #ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif #endif
......
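Because the two new switches are now wrapped in #ifndef guards, they can be overridden without editing the header, either with -D on the compiler command line or by defining them before the config header is included. A minimal sketch of the second option (the config.hpp include name is an assumption here):

    // Hypothetical user translation unit: fall back to the union-based bit_cast
    // and keep pointer-cast vector access, overriding the defaults above.
    #define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 0
    #define CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0
    #include "config.hpp"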
@@ -1081,11 +1081,11 @@ struct NumericLimits<half_t>
static constexpr unsigned short binary_max = 0x7BFF;
static constexpr unsigned short binary_lowest = 0xFBFF;
-__host__ __device__ static constexpr half_t Min() { return as_type<half_t>(binary_min); }
-__host__ __device__ static constexpr half_t Max() { return as_type<half_t>(binary_max); }
-__host__ __device__ static constexpr half_t Lowest() { return as_type<half_t>(binary_lowest); }
+__host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
+__host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
+__host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
};
} // namespace ck
......
@@ -83,12 +83,28 @@ struct DynamicBuffer
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
+#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
+X tmp;
+__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
+return is_valid_element ? tmp : X{0};
+#else
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
+#endif
}
else
{
+#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
+X tmp;
+__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
+return is_valid_element ? tmp : X{invalid_element_value_};
+#else
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
: X{invalid_element_value_};
+#endif
}
}
}
@@ -117,7 +133,13 @@ struct DynamicBuffer
#else
if(is_valid_element)
{
+#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
+X tmp = x;
+__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
+#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
+#endif
}
#endif
}
@@ -126,7 +148,13 @@ struct DynamicBuffer
if(is_valid_element)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
+#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
+X tmp = x;
+__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
+#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
+#endif
#else
// HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
// inefficient
@@ -201,7 +229,13 @@ struct DynamicBuffer
}
else
{
+#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
+X tmp = x;
+__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
+#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
+#endif
}
#endif
}
@@ -210,7 +244,13 @@ struct DynamicBuffer
{
if(is_valid_element)
{
+#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
+X tmp = x;
+__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
+#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
+#endif
}
}
}
......
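The intent of CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS in the DynamicBuffer paths above is to replace the pointer-cast read/write of a vector X through a scalar pointer with __builtin_memcpy, which avoids the strict-aliasing and alignment assumptions of the cast while typically still compiling down to the same wide load or store. A stand-alone sketch of the two read paths (the function names and the 4-wide vector typedef are illustrative assumptions):

    typedef float float4_sketch __attribute__((ext_vector_type(4)));

    // memcpy-based path: type punning without aliasing UB
    __device__ inline float4_sketch load4_memcpy(const float* p, int i)
    {
        float4_sketch tmp;
        __builtin_memcpy(&tmp, &p[i], sizeof(float4_sketch));
        return tmp;
    }

    // pointer-cast path: roughly what *c_style_pointer_cast<const X*>(&p_data_[i]) does
    __device__ inline float4_sketch load4_cast(const float* p, int i)
    {
        return *reinterpret_cast<const float4_sketch*>(&p[i]);
    }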
@@ -144,9 +144,9 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
v_dot4_i32_i8 %0, %1, %2, %0\n \
"
: "=v"(c)
-: "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
+: "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
#else
-c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
+c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
#endif
#else
const vector_type<int8_t, 4> a_vector{a};
......
@@ -125,7 +125,7 @@ struct MagicDivision
__host__ __device__ static constexpr int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
-uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
+uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
}
......
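DoMagicDivision above divides by a precomputed magic (multiplier, shift) pair using only a high multiply, an add, and a shift; bit_cast<uint32_t> reinterprets the signed dividend as unsigned before the arithmetic. A host-side sketch of the same arithmetic, with __umulhi replaced by a 64-bit product (an illustration, not CK's API):

    #include <cstdint>

    // High 32 bits of the 32x32 -> 64 bit product, the host equivalent of __umulhi.
    static uint32_t umulhi32(uint32_t a, uint32_t b)
    {
        return static_cast<uint32_t>((static_cast<uint64_t>(a) * b) >> 32);
    }

    // Same arithmetic as DoMagicDivision: for a (multiplier, shift) pair precomputed
    // for some divisor d, this returns dividend / d without an integer divide.
    static uint32_t magic_divide(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = umulhi32(dividend, multiplier);
        return (tmp + dividend) >> shift;
    }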
@@ -54,5 +54,49 @@ __host__ __device__ constexpr auto make_statically_indexed_array()
return StaticallyIndexedArray<X, 0>();
}
+template <typename T, index_t N>
+struct StaticallyIndexedArray_v2
+{
+    __host__ __device__ constexpr StaticallyIndexedArray_v2() = default;
+
+    __host__ __device__ static constexpr index_t Size() { return N; }
+
+    // read access
+    template <index_t I>
+    __host__ __device__ constexpr const auto& At(Number<I>) const
+    {
+        static_assert(I < N, "wrong! out of range");
+        return data_[I];
+    }
+
+    // write access
+    template <index_t I>
+    __host__ __device__ constexpr auto& At(Number<I>)
+    {
+        static_assert(I < N, "wrong! out of range");
+        return data_[I];
+    }
+
+    // read access
+    template <index_t I>
+    __host__ __device__ constexpr const auto& operator[](Number<I> i) const
+    {
+        return At(i);
+    }
+
+    // write access
+    template <index_t I>
+    __host__ __device__ constexpr auto& operator()(Number<I> i)
+    {
+        return At(i);
+    }
+
+    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
+
+    T data_[N];
+};
} // namespace ck
#endif
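A hypothetical usage sketch of the new StaticallyIndexedArray_v2 (Number<I> is CK's compile-time integer wrapper; an out-of-range index fails the static_assert in At at compile time):

    __host__ __device__ void statically_indexed_array_v2_example()
    {
        StaticallyIndexedArray_v2<float, 4> arr;

        arr(Number<0>{}) = 1.0f;      // write access via operator()
        arr(Number<1>{}) = 2.0f;

        float x = arr[Number<0>{}];   // read access via operator[]
        (void)x;

        static_assert(StaticallyIndexedArray_v2<float, 4>::Size() == 4, "");
        // arr[Number<4>{}];          // would not compile: static_assert(I < N) fails
    }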
@@ -32,8 +32,15 @@ template <typename T>
inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
-__host__ __device__ constexpr Y as_type(X x)
+__host__ __device__ constexpr Y bit_cast(const X& x)
{
+#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
+Y y;
+__builtin_memcpy(&y, &x, sizeof(X));
+return y;
+#else
union AsType
{
X x;
@@ -41,6 +48,7 @@ __host__ __device__ constexpr Y as_type(X x)
};
return AsType{x}.y;
+#endif
}
} // namespace ck
......
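With CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST enabled, the renamed bit_cast relies on __builtin_memcpy for the type pun, which compilers typically lower to a plain register move; the union read it replaces is the older idiom with weaker aliasing guarantees. A stand-alone sketch of the memcpy form against the standard library, plus a sample use (illustration only, not the library's code):

    #include <cstdint>
    #include <cstring>

    // Same shape as the bit_cast above.
    template <typename Y, typename X>
    Y bit_cast_sketch(const X& x)
    {
        static_assert(sizeof(X) == sizeof(Y), "bit_cast requires equally sized types");

        Y y;
        std::memcpy(&y, &x, sizeof(X)); // the type pun; no object is read through the wrong type
        return y;
    }

    // e.g. inspecting the IEEE-754 encoding of 1.0f:
    //   bit_cast_sketch<uint32_t>(1.0f) == 0x3F800000u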
...@@ -9,7 +9,6 @@ ...@@ -9,7 +9,6 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "gemm_common.hpp"
#include "host_gemm.hpp" #include "host_gemm.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_base.hpp" #include "device_base.hpp"
...@@ -139,12 +138,12 @@ int main(int argc, char* argv[]) ...@@ -139,12 +138,12 @@ int main(int argc, char* argv[])
{ {
case 0: break; case 0: break;
case 1: case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
} }
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
......
@@ -258,7 +258,7 @@ int main(int argc, char* argv[])
using in_data_t = half_t;
using acc_data_t = float;
using out_data_t = half_t;
-#elif 1
+#elif 0
using in_data_t = ushort;
using acc_data_t = float;
using out_data_t = ushort;
......