Commit 8eaed8b3 authored by carlushuang

Merge remote-tracking branch 'origin/support_amd_buffer_glc_slc' into stream-k-initial-impl

parents ad82c377 790467d6
@@ -286,7 +286,18 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
                                       int soffset, // dst_wave_addr_offset
                                       int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");

// memory coherency bit for buffer store/load instruction
enum struct amd_buffer_coherence_bits
{
    default_coherence = 0, // default value
    glc               = 1,
    slc               = 2,
    glc_slc           = 3,
};

template <typename T,
          index_t N,
          amd_buffer_coherence_bits coherence = amd_buffer_coherence_bits::default_coherence>
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
                                                                 index_t src_thread_addr_offset,
                                                                 index_t src_wave_addr_offset)
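The enum value is simply cast to index_t and forwarded as the final glc_slc operand of the raw buffer intrinsics. As a rough, hedged reading of those bits on GCN-class hardware, glc requests device-scope coherence (loads re-read past the CU-local L1) and slc marks the access as streaming with respect to L2; glc_slc sets both. A minimal sketch of a call site that opts out of the default policy (the resource descriptor and offsets are assumed to come from the surrounding code):

// Hypothetical call: load four floats with both coherence bits set.
// The fourth intrinsic argument becomes static_cast<index_t>(coherence) == 3.
float4_t v = amd_buffer_load_impl<float, 4, amd_buffer_coherence_bits::glc_slc>(
    src_wave_buffer_resource, // 128-bit buffer resource descriptor (wave-uniform)
    src_thread_addr_offset,   // per-thread byte offset
    src_wave_addr_offset);    // per-wave byte offset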
@@ -305,28 +316,37 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
        // use fp32 load to mimic fp64 load
        if constexpr(N == 1)
        {
            const float2_t tmp =
                llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource,
                                                   src_thread_addr_offset,
                                                   src_wave_addr_offset,
                                                   static_cast<index_t>(coherence));

            return bit_cast<double>(tmp);
        }
        else if constexpr(N == 2)
        {
            const float4_t tmp =
                llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
                                                   src_thread_addr_offset,
                                                   src_wave_addr_offset,
                                                   static_cast<index_t>(coherence));

            return bit_cast<double2_t>(tmp);
        }
        else if constexpr(N == 4)
        {
            const float4_t f32_0 =
                llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
                                                   src_thread_addr_offset,
                                                   src_wave_addr_offset,
                                                   static_cast<index_t>(coherence));

            const float4_t f32_1 =
                llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
                                                   src_thread_addr_offset,
                                                   src_wave_addr_offset + 4 * sizeof(float),
                                                   static_cast<index_t>(coherence));

            vector_type<double, 4> tmp;

            tmp.AsType<double2_t>()(Number<0>{}) = bit_cast<double2_t>(f32_0);
@@ -339,31 +359,40 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
    {
        if constexpr(N == 1)
        {
            return llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource,
                                                    src_thread_addr_offset,
                                                    src_wave_addr_offset,
                                                    static_cast<index_t>(coherence));
        }
        else if constexpr(N == 2)
        {
            return llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource,
                                                      src_thread_addr_offset,
                                                      src_wave_addr_offset,
                                                      static_cast<index_t>(coherence));
        }
        else if constexpr(N == 4)
        {
            return llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
                                                      src_thread_addr_offset,
                                                      src_wave_addr_offset,
                                                      static_cast<index_t>(coherence));
        }
        else if constexpr(N == 8)
        {
            vector_type<float, 8> tmp;

            tmp.AsType<float4_t>()(Number<0>{}) =
                llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
                                                   src_thread_addr_offset,
                                                   src_wave_addr_offset,
                                                   static_cast<index_t>(coherence));

            tmp.AsType<float4_t>()(Number<1>{}) =
                llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
                                                   src_thread_addr_offset,
                                                   src_wave_addr_offset + 4 * sizeof(float),
                                                   static_cast<index_t>(coherence));

            return tmp.AsType<float8_t>()(Number<0>{});
        }
@@ -372,24 +401,32 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
    {
        if constexpr(N == 1)
        {
            return llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource,
                                                    src_thread_addr_offset,
                                                    src_wave_addr_offset,
                                                    static_cast<index_t>(coherence));
        }
        else if constexpr(N == 2)
        {
            return llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource,
                                                      src_thread_addr_offset,
                                                      src_wave_addr_offset,
                                                      static_cast<index_t>(coherence));
        }
        else if constexpr(N == 4)
        {
            return llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
                                                      src_thread_addr_offset,
                                                      src_wave_addr_offset,
                                                      static_cast<index_t>(coherence));
        }
        else if constexpr(N == 8)
        {
            // use fp32 load to mimic fp16 load
            float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
                                                              src_thread_addr_offset,
                                                              src_wave_addr_offset,
                                                              static_cast<index_t>(coherence));

            return bit_cast<half8_t>(tmp);
        }
@@ -398,23 +435,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
    {
        if constexpr(N == 1)
        {
            return llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
                                                   src_thread_addr_offset,
                                                   src_wave_addr_offset,
                                                   static_cast<index_t>(coherence));
        }
        else if constexpr(N == 2)
        {
            return llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource,
                                                     src_thread_addr_offset,
                                                     src_wave_addr_offset,
                                                     static_cast<index_t>(coherence));
        }
        else if constexpr(N == 4)
        {
            return llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource,
                                                     src_thread_addr_offset,
                                                     src_wave_addr_offset,
                                                     static_cast<index_t>(coherence));
        }
        else if constexpr(N == 8)
        {
            int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                              src_thread_addr_offset,
                                                              src_wave_addr_offset,
                                                              static_cast<index_t>(coherence));

            return bit_cast<bhalf8_t>(tmp);
        }
@@ -423,31 +468,40 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
    {
        if constexpr(N == 1)
        {
            return llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
                                                   src_thread_addr_offset,
                                                   src_wave_addr_offset,
                                                   static_cast<index_t>(coherence));
        }
        else if constexpr(N == 2)
        {
            return llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
                                                     src_thread_addr_offset,
                                                     src_wave_addr_offset,
                                                     static_cast<index_t>(coherence));
        }
        else if constexpr(N == 4)
        {
            return llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                     src_thread_addr_offset,
                                                     src_wave_addr_offset,
                                                     static_cast<index_t>(coherence));
        }
        else if constexpr(N == 8)
        {
            vector_type<int32_t, 8> tmp;

            tmp.AsType<int32x4_t>()(Number<0>{}) =
                llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                  src_thread_addr_offset,
                                                  src_wave_addr_offset,
                                                  static_cast<index_t>(coherence));

            tmp.AsType<int32x4_t>()(Number<1>{}) =
                llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                  src_thread_addr_offset,
                                                  src_wave_addr_offset + 4 * sizeof(int32_t),
                                                  static_cast<index_t>(coherence));

            return tmp.AsType<int32x8_t>()(Number<0>{});
        }
    }
@@ -455,17 +509,23 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
    {
        if constexpr(N == 1)
        {
            return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource,
                                                  src_thread_addr_offset,
                                                  src_wave_addr_offset,
                                                  static_cast<index_t>(coherence));
        }
        else if constexpr(N == 2)
        {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
            return llvm_amdgcn_raw_buffer_load_i8x2(src_wave_buffer_resource,
                                                    src_thread_addr_offset,
                                                    src_wave_addr_offset,
                                                    static_cast<index_t>(coherence));
#else
            int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
                                                          src_thread_addr_offset,
                                                          src_wave_addr_offset,
                                                          static_cast<index_t>(coherence));

            return bit_cast<int8x2_t>(tmp);
#endif
@@ -473,11 +533,15 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
        else if constexpr(N == 4)
        {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
            return llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
                                                    src_thread_addr_offset,
                                                    src_wave_addr_offset,
                                                    static_cast<index_t>(coherence));
#else
            int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
                                                          src_thread_addr_offset,
                                                          src_wave_addr_offset,
                                                          static_cast<index_t>(coherence));

            return bit_cast<int8x4_t>(tmp);
#endif
@@ -487,19 +551,24 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
            vector_type<int8_t, 8> tmp;

            tmp.AsType<int8x4_t>()(Number<0>{}) =
                llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
                                                 src_thread_addr_offset,
                                                 src_wave_addr_offset,
                                                 static_cast<index_t>(coherence));

            tmp.AsType<int8x4_t>()(Number<1>{}) =
                llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
                                                 src_thread_addr_offset,
                                                 src_wave_addr_offset + 4 * sizeof(int8_t),
                                                 static_cast<index_t>(coherence));

            return tmp.AsType<int8x8_t>()(Number<0>{});
#else
            int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
                                                              src_thread_addr_offset,
                                                              src_wave_addr_offset,
                                                              static_cast<index_t>(coherence));

            return bit_cast<int8x8_t>(tmp);
#endif
@@ -509,31 +578,36 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
            vector_type<int8_t, 16> tmp;

            tmp.AsType<int8x4_t>()(Number<0>{}) =
                llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
                                                 src_thread_addr_offset,
                                                 src_wave_addr_offset,
                                                 static_cast<index_t>(coherence));

            tmp.AsType<int8x4_t>()(Number<1>{}) =
                llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
                                                 src_thread_addr_offset,
                                                 src_wave_addr_offset + 4 * sizeof(int8_t),
                                                 static_cast<index_t>(coherence));

            tmp.AsType<int8x4_t>()(Number<2>{}) =
                llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
                                                 src_thread_addr_offset,
                                                 src_wave_addr_offset + 8 * sizeof(int8_t),
                                                 static_cast<index_t>(coherence));

            tmp.AsType<int8x4_t>()(Number<3>{}) =
                llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
                                                 src_thread_addr_offset,
                                                 src_wave_addr_offset + 12 * sizeof(int8_t),
                                                 static_cast<index_t>(coherence));

            return tmp.AsType<int8x16_t>()(Number<0>{});
#else
            int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                              src_thread_addr_offset,
                                                              src_wave_addr_offset,
                                                              static_cast<index_t>(coherence));

            return bit_cast<int8x16_t>(tmp);
#endif
@@ -541,7 +615,9 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
    }
}

template <typename T,
          index_t N,
          amd_buffer_coherence_bits coherence = amd_buffer_coherence_bits::default_coherence>
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
                                      int32x4_t dst_wave_buffer_resource,
                                      index_t dst_thread_addr_offset,
@@ -565,7 +641,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset,
                                                static_cast<index_t>(coherence));
        }
        else if constexpr(N == 2)
        {
@@ -573,7 +649,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset,
                                                static_cast<index_t>(coherence));
        }
    }
    else if constexpr(is_same<T, float>::value)
@@ -584,7 +660,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
                                              static_cast<index_t>(coherence));
        }
        else if constexpr(N == 2)
        {
@@ -592,7 +668,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset,
                                                static_cast<index_t>(coherence));
        }
        else if constexpr(N == 4)
        {
@@ -600,7 +676,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset,
                                                static_cast<index_t>(coherence));
        }
        else if constexpr(N == 8)
        {
@@ -625,7 +701,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
                                              static_cast<index_t>(coherence));
        }
        else if constexpr(N == 2)
        {
@@ -633,7 +709,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset,
                                                static_cast<index_t>(coherence));
        }
        else if constexpr(N == 4)
        {
@@ -641,7 +717,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset,
                                                static_cast<index_t>(coherence));
        }
        else if constexpr(N == 8)
        {
@@ -652,19 +728,19 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset,
                                                static_cast<index_t>(coherence));

            llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset + 4 * sizeof(half_t),
                                                static_cast<index_t>(coherence));
#else
            llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset,
                                                static_cast<index_t>(coherence));
#endif
        }
    }
@@ -676,7 +752,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                             dst_wave_buffer_resource,
                                             dst_thread_addr_offset,
                                             dst_wave_addr_offset,
                                             static_cast<index_t>(coherence));
        }
        else if constexpr(N == 2)
        {
@@ -684,7 +760,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                               dst_wave_buffer_resource,
                                               dst_thread_addr_offset,
                                               dst_wave_addr_offset,
                                               static_cast<index_t>(coherence));
        }
        else if constexpr(N == 4)
        {
@@ -692,7 +768,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                               dst_wave_buffer_resource,
                                               dst_thread_addr_offset,
                                               dst_wave_addr_offset,
                                               static_cast<index_t>(coherence));
        }
        else if constexpr(N == 8)
        {
@@ -702,13 +778,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                               dst_wave_buffer_resource,
                                               dst_thread_addr_offset,
                                               dst_wave_addr_offset,
                                               static_cast<index_t>(coherence));

            llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<1>{}],
                                               dst_wave_buffer_resource,
                                               dst_thread_addr_offset,
                                               dst_wave_addr_offset + 4 * sizeof(bhalf_t),
                                               static_cast<index_t>(coherence));
        }
    }
    else if constexpr(is_same<T, int32_t>::value)
@@ -719,7 +795,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                             dst_wave_buffer_resource,
                                             dst_thread_addr_offset,
                                             dst_wave_addr_offset,
                                             static_cast<index_t>(coherence));
        }
        else if constexpr(N == 2)
        {
@@ -727,7 +803,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                               dst_wave_buffer_resource,
                                               dst_thread_addr_offset,
                                               dst_wave_addr_offset,
                                               static_cast<index_t>(coherence));
        }
        else if constexpr(N == 4)
        {
@@ -735,7 +811,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                               dst_wave_buffer_resource,
                                               dst_thread_addr_offset,
                                               dst_wave_addr_offset,
                                               static_cast<index_t>(coherence));
        }
    }
    else if constexpr(is_same<T, int8_t>::value)
@@ -746,7 +822,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                            dst_wave_buffer_resource,
                                            dst_thread_addr_offset,
                                            dst_wave_addr_offset,
                                            static_cast<index_t>(coherence));
        }
        else if constexpr(N == 2)
        {
@@ -755,13 +831,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
                                              static_cast<index_t>(coherence));
#else
            llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
                                             dst_wave_buffer_resource,
                                             dst_thread_addr_offset,
                                             dst_wave_addr_offset,
                                             static_cast<index_t>(coherence));
#endif
        }
        else if constexpr(N == 4)
@@ -771,13 +847,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
                                              static_cast<index_t>(coherence));
#else
            llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
                                             dst_wave_buffer_resource,
                                             dst_thread_addr_offset,
                                             dst_wave_addr_offset,
                                             static_cast<index_t>(coherence));
#endif
        }
        else if constexpr(N == 8)
@@ -786,7 +862,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
                                              static_cast<index_t>(coherence));
        }
        else if constexpr(N == 16)
        {
@@ -794,7 +870,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
                                              static_cast<index_t>(coherence));
        }
    }
}
@@ -1026,7 +1102,9 @@ __device__ void amd_buffer_atomic_max_impl(const typename vector_type<T, N>::typ
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T,
          index_t N,
          amd_buffer_coherence_bits coherence = amd_buffer_coherence_bits::default_coherence>
__device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
                                            index_t src_thread_element_offset,
@@ -1046,10 +1124,10 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;

    return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
        src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
    vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
        src_wave_buffer_resource, src_thread_addr_offset, 0);

    return src_thread_element_valid ? tmp : vector_t(0);
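For reference, a hedged sketch of how a kernel might call this wrapper with a non-default cache policy; the argument order follows the call made from DynamicBuffer further down, and the names of the last two arguments are illustrative rather than taken from this patch:

// Load four floats, returning 0 for out-of-range elements, with the glc bit set.
auto v = amd_buffer_load_invalid_element_return_zero<float,
                                                     4,
                                                     amd_buffer_coherence_bits::glc>(
    p_src_wave,                 // wavewise pointer into global memory
    src_thread_element_offset,  // per-thread element offset
    src_thread_element_valid,   // assumed name: out-of-bounds predicate
    src_element_space_size);    // assumed name: buffer extent in elements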
@@ -1060,7 +1138,9 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T,
          index_t N,
          amd_buffer_coherence_bits coherence = amd_buffer_coherence_bits::default_coherence>
__device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
                                                        index_t src_thread_element_offset,
@@ -1078,7 +1158,7 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

    vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
        src_wave_buffer_resource, src_thread_addr_offset, 0);

    return src_thread_element_valid ? tmp : vector_t(customized_value);
@@ -1088,7 +1168,9 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T,
          index_t N,
          amd_buffer_coherence_bits coherence = amd_buffer_coherence_bits::default_coherence>
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
                                 T* p_dst_wave,
                                 const index_t dst_thread_element_offset,
@@ -1107,12 +1189,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

    amd_buffer_store_impl<scalar_t, vector_size, coherence>(
        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
    if(dst_thread_element_valid)
    {
        amd_buffer_store_impl<scalar_t, vector_size, coherence>(
            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
    }
#endif
@@ -19,7 +19,8 @@ namespace ck {
template <AddressSpaceEnum BufferAddressSpace,
          typename T,
          typename ElementSpaceSize,
          bool InvalidElementUseNumericalZeroValue,
          amd_buffer_coherence_bits coherence = amd_buffer_coherence_bits::default_coherence>
struct DynamicBuffer
{
    using type = T;
@@ -77,13 +78,16 @@ struct DynamicBuffer
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
                                                                   t_per_x,
                                                                   coherence>(
                    p_data_, i, is_valid_element, element_space_size_);
            }
            else
            {
                return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
                                                                                t_per_x,
                                                                                coherence>(
                    p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
            }
        }
@@ -173,7 +177,7 @@ struct DynamicBuffer
            {
                constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

                amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
                    x, p_data_, i, is_valid_element, element_space_size_);
            }
            else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
@@ -376,14 +380,19 @@ struct DynamicBuffer
    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

template <AddressSpaceEnum BufferAddressSpace,
          amd_buffer_coherence_bits coherence = amd_buffer_coherence_bits::default_coherence,
          typename T,
          typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true, coherence>{
        p, element_space_size};
}

template <
    AddressSpaceEnum BufferAddressSpace,
    amd_buffer_coherence_bits coherence = amd_buffer_coherence_bits::default_coherence,
    typename T,
    typename ElementSpaceSize,
    typename X,
@@ -391,7 +400,7 @@ template <
__host__ __device__ constexpr auto
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false, coherence>{
        p, element_space_size, invalid_element_value};
}
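A short usage sketch, assuming the headers above are included and p_out / element_space_size stand in for whatever the caller has; leaving the coherence argument out keeps default_coherence and therefore the previous behaviour:

// Global-memory buffer whose buffer_load/buffer_store instructions carry glc+slc.
auto buf = make_dynamic_buffer<AddressSpaceEnum::Global,
                               amd_buffer_coherence_bits::glc_slc>(p_out, element_space_size);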