Commit 899df971 authored by Chao Liu's avatar Chao Liu
Browse files

refactor buffer resource

parent a915f574
...@@ -81,13 +81,13 @@ struct GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk ...@@ -81,13 +81,13 @@ struct GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk
unfold_tensor_descriptor(wei_k_y_x_c_global_desc, I1, I3), Sequence<1, 0>{}); unfold_tensor_descriptor(wei_k_y_x_c_global_desc, I1, I3), Sequence<1, 0>{});
// input tensor // input tensor
constexpr auto in_n_hip_wip_c_global_desc = transform_tensor_descriptor( constexpr auto in_n_hip_wip_c_global_desc =
in_n_hi_wi_c_global_desc, transform_tensor_descriptor(in_n_hi_wi_c_global_desc,
make_tuple(PassThrough<N>{}, make_tuple(PassThrough<N>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}, Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{},
PassThrough<C>{}), PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr index_t Hip = in_n_hip_wip_c_global_desc.GetLengths()[I1]; constexpr index_t Hip = in_n_hip_wip_c_global_desc.GetLengths()[I1];
constexpr index_t Wip = in_n_hip_wip_c_global_desc.GetLengths()[I2]; constexpr index_t Wip = in_n_hip_wip_c_global_desc.GetLengths()[I2];
...@@ -108,11 +108,11 @@ struct GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk ...@@ -108,11 +108,11 @@ struct GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor // output tensor
constexpr auto out_gemmm_gemmn_global_desc = constexpr auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_ho_wo_k_global_desc, I0, I2), unfold_tensor_descriptor(out_n_ho_wo_k_global_desc, I0, I2),
make_tuple(PassThrough<K>{}, Merge<Sequence<N * Ho * Wo>>{}), make_tuple(PassThrough<K>{}, Merge<Sequence<N * Ho * Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM // GEMM
constexpr auto gridwise_gemm = constexpr auto gridwise_gemm =
......
...@@ -201,8 +201,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -201,8 +201,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
#else #else
if(is_dst_valid) if(is_dst_valid)
{ {
*reinterpret_cast<dst_vector_t*>(&(p_dst[dst_slice_origin_coord_.GetOffset()])) = *reinterpret_cast<dst_vector_t*>(
dst_vector.Vector(); &(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector();
} }
#endif #endif
} }
...@@ -210,8 +210,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -210,8 +210,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
{ {
if(is_dst_valid) if(is_dst_valid)
{ {
*reinterpret_cast<dst_vector_t*>(&(p_dst[dst_slice_origin_coord_.GetOffset()])) = *reinterpret_cast<dst_vector_t*>(
dst_vector.Vector(); &(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector();
} }
} }
......
...@@ -124,14 +124,14 @@ __device__ float amd_buffer_load<float, 1>(const float* p_src_wave, ...@@ -124,14 +124,14 @@ __device__ float amd_buffer_load<float, 1>(const float* p_src_wave,
bool src_thread_data_valid, bool src_thread_data_valid,
index_t src_data_range) index_t src_data_range)
{ {
BufferResourceConstant<float> src_wave_buffer_resource; BufferResource<float> src_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave); src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
// wavewise range (32 bit) // wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(float); src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
// wavewise setting (32 bit) // wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000; src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
...@@ -154,14 +154,14 @@ __device__ float2_t amd_buffer_load<float, 2>(const float* p_src_wave, ...@@ -154,14 +154,14 @@ __device__ float2_t amd_buffer_load<float, 2>(const float* p_src_wave,
bool src_thread_data_valid, bool src_thread_data_valid,
index_t src_data_range) index_t src_data_range)
{ {
BufferResourceConstant<float> src_wave_buffer_resource; BufferResource<float> src_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave); src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
// wavewise range (32 bit) // wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(float); src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
// wavewise setting (32 bit) // wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000; src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
...@@ -184,14 +184,14 @@ __device__ float4_t amd_buffer_load<float, 4>(const float* p_src_wave, ...@@ -184,14 +184,14 @@ __device__ float4_t amd_buffer_load<float, 4>(const float* p_src_wave,
bool src_thread_data_valid, bool src_thread_data_valid,
index_t src_data_range) index_t src_data_range)
{ {
BufferResourceConstant<float> src_wave_buffer_resource; BufferResource<float> src_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave); src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
// wavewise range (32 bit) // wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(float); src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
// wavewise setting (32 bit) // wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000; src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
...@@ -214,14 +214,14 @@ __device__ half_t amd_buffer_load<half_t, 1>(const half_t* p_src_wave, ...@@ -214,14 +214,14 @@ __device__ half_t amd_buffer_load<half_t, 1>(const half_t* p_src_wave,
bool src_thread_data_valid, bool src_thread_data_valid,
index_t src_data_range) index_t src_data_range)
{ {
BufferResourceConstant<half_t> src_wave_buffer_resource; BufferResource<half_t> src_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave); src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
// wavewise range (32 bit) // wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t); src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
// wavewise setting (32 bit) // wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000; src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
...@@ -249,14 +249,14 @@ __device__ half2_t amd_buffer_load<half_t, 2>(const half_t* p_src_wave, ...@@ -249,14 +249,14 @@ __device__ half2_t amd_buffer_load<half_t, 2>(const half_t* p_src_wave,
bool src_thread_data_valid, bool src_thread_data_valid,
index_t src_data_range) index_t src_data_range)
{ {
BufferResourceConstant<half_t> src_wave_buffer_resource; BufferResource<half_t> src_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave); src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
// wavewise range (32 bit) // wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t); src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
// wavewise setting (32 bit) // wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000; src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
...@@ -283,14 +283,14 @@ __device__ half4_t amd_buffer_load<half_t, 4>(const half_t* p_src_wave, ...@@ -283,14 +283,14 @@ __device__ half4_t amd_buffer_load<half_t, 4>(const half_t* p_src_wave,
bool src_thread_data_valid, bool src_thread_data_valid,
index_t src_data_range) index_t src_data_range)
{ {
BufferResourceConstant<half_t> src_wave_buffer_resource; BufferResource<half_t> src_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave); src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
// wavewise range (32 bit) // wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t); src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
// wavewise setting (32 bit) // wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000; src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
...@@ -317,14 +317,14 @@ __device__ half8_t amd_buffer_load<half_t, 8>(const half_t* p_src_wave, ...@@ -317,14 +317,14 @@ __device__ half8_t amd_buffer_load<half_t, 8>(const half_t* p_src_wave,
bool src_thread_data_valid, bool src_thread_data_valid,
index_t src_data_range) index_t src_data_range)
{ {
BufferResourceConstant<half_t> src_wave_buffer_resource; BufferResource<half_t> src_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave); src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
// wavewise range (32 bit) // wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t); src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
// wavewise setting (32 bit) // wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000; src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
...@@ -351,14 +351,14 @@ __device__ ushort amd_buffer_load<ushort, 1>(const ushort* p_src_wave, ...@@ -351,14 +351,14 @@ __device__ ushort amd_buffer_load<ushort, 1>(const ushort* p_src_wave,
bool src_thread_data_valid, bool src_thread_data_valid,
index_t src_data_range) index_t src_data_range)
{ {
BufferResourceConstant<ushort> src_wave_buffer_resource; BufferResource<ushort> src_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave); src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
// wavewise range (32 bit) // wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort); src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
// wavewise setting (32 bit) // wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000; src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
...@@ -386,14 +386,14 @@ __device__ ushort2_t amd_buffer_load<ushort, 2>(const ushort* p_src_wave, ...@@ -386,14 +386,14 @@ __device__ ushort2_t amd_buffer_load<ushort, 2>(const ushort* p_src_wave,
bool src_thread_data_valid, bool src_thread_data_valid,
index_t src_data_range) index_t src_data_range)
{ {
BufferResourceConstant<ushort> src_wave_buffer_resource; BufferResource<ushort> src_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave); src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
// wavewise range (32 bit) // wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort); src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
// wavewise setting (32 bit) // wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000; src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
...@@ -420,14 +420,14 @@ __device__ ushort4_t amd_buffer_load<ushort, 4>(const ushort* p_src_wave, ...@@ -420,14 +420,14 @@ __device__ ushort4_t amd_buffer_load<ushort, 4>(const ushort* p_src_wave,
bool src_thread_data_valid, bool src_thread_data_valid,
index_t src_data_range) index_t src_data_range)
{ {
BufferResourceConstant<ushort> src_wave_buffer_resource; BufferResource<ushort> src_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave); src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
// wavewise range (32 bit) // wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort); src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
// wavewise setting (32 bit) // wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000; src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
...@@ -454,14 +454,14 @@ __device__ ushort8_t amd_buffer_load<ushort, 8>(const ushort* p_src_wave, ...@@ -454,14 +454,14 @@ __device__ ushort8_t amd_buffer_load<ushort, 8>(const ushort* p_src_wave,
bool src_thread_data_valid, bool src_thread_data_valid,
index_t src_data_range) index_t src_data_range)
{ {
BufferResourceConstant<ushort> src_wave_buffer_resource; BufferResource<ushort> src_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave); src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
// wavewise range (32 bit) // wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort); src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
// wavewise setting (32 bit) // wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000; src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
...@@ -489,14 +489,14 @@ __device__ void amd_buffer_store<float, 1>(const float* p_src_thread, ...@@ -489,14 +489,14 @@ __device__ void amd_buffer_store<float, 1>(const float* p_src_thread,
bool dst_thread_data_valid, bool dst_thread_data_valid,
index_t dst_data_range) index_t dst_data_range)
{ {
BufferResourceConstant<float> dst_wave_buffer_resource; BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave; dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit) // wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float); dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit) // wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000; dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
...@@ -525,14 +525,14 @@ __device__ void amd_buffer_store<float, 2>(const float* p_src_thread, ...@@ -525,14 +525,14 @@ __device__ void amd_buffer_store<float, 2>(const float* p_src_thread,
bool dst_thread_data_valid, bool dst_thread_data_valid,
index_t dst_data_range) index_t dst_data_range)
{ {
BufferResourceConstant<float> dst_wave_buffer_resource; BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave; dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit) // wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float); dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit) // wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000; dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
...@@ -565,14 +565,14 @@ __device__ void amd_buffer_store<float, 4>(const float* p_src_thread, ...@@ -565,14 +565,14 @@ __device__ void amd_buffer_store<float, 4>(const float* p_src_thread,
bool dst_thread_data_valid, bool dst_thread_data_valid,
index_t dst_data_range) index_t dst_data_range)
{ {
BufferResourceConstant<float> dst_wave_buffer_resource; BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave; dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit) // wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float); dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit) // wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000; dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
...@@ -605,14 +605,14 @@ __device__ void amd_buffer_store<half_t, 1>(const half_t* p_src_thread, ...@@ -605,14 +605,14 @@ __device__ void amd_buffer_store<half_t, 1>(const half_t* p_src_thread,
bool dst_thread_data_valid, bool dst_thread_data_valid,
index_t dst_data_range) index_t dst_data_range)
{ {
BufferResourceConstant<half_t> dst_wave_buffer_resource; BufferResource<half_t> dst_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave; dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit) // wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t); dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
// wavewise setting (32 bit) // wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000; dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
...@@ -644,14 +644,14 @@ __device__ void amd_buffer_store<half_t, 2>(const half_t* p_src_thread, ...@@ -644,14 +644,14 @@ __device__ void amd_buffer_store<half_t, 2>(const half_t* p_src_thread,
bool dst_thread_data_valid, bool dst_thread_data_valid,
index_t dst_data_range) index_t dst_data_range)
{ {
BufferResourceConstant<half_t> dst_wave_buffer_resource; BufferResource<half_t> dst_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave; dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit) // wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t); dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
// wavewise setting (32 bit) // wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000; dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
...@@ -682,14 +682,14 @@ __device__ void amd_buffer_store<half_t, 4>(const half_t* p_src_thread, ...@@ -682,14 +682,14 @@ __device__ void amd_buffer_store<half_t, 4>(const half_t* p_src_thread,
bool dst_thread_data_valid, bool dst_thread_data_valid,
index_t dst_data_range) index_t dst_data_range)
{ {
BufferResourceConstant<half_t> dst_wave_buffer_resource; BufferResource<half_t> dst_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave; dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit) // wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t); dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
// wavewise setting (32 bit) // wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000; dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
...@@ -720,14 +720,14 @@ __device__ void amd_buffer_store<half_t, 8>(const half_t* p_src_thread, ...@@ -720,14 +720,14 @@ __device__ void amd_buffer_store<half_t, 8>(const half_t* p_src_thread,
bool dst_thread_data_valid, bool dst_thread_data_valid,
index_t dst_data_range) index_t dst_data_range)
{ {
BufferResourceConstant<half_t> dst_wave_buffer_resource; BufferResource<half_t> dst_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave; dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit) // wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t); dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
// wavewise setting (32 bit) // wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000; dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
...@@ -758,14 +758,14 @@ __device__ void amd_buffer_store<ushort, 1>(const ushort* p_src_thread, ...@@ -758,14 +758,14 @@ __device__ void amd_buffer_store<ushort, 1>(const ushort* p_src_thread,
bool dst_thread_data_valid, bool dst_thread_data_valid,
index_t dst_data_range) index_t dst_data_range)
{ {
BufferResourceConstant<ushort> dst_wave_buffer_resource; BufferResource<ushort> dst_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave; dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit) // wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort); dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
// wavewise setting (32 bit) // wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000; dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
...@@ -793,14 +793,14 @@ __device__ void amd_buffer_store<ushort, 2>(const ushort* p_src_thread, ...@@ -793,14 +793,14 @@ __device__ void amd_buffer_store<ushort, 2>(const ushort* p_src_thread,
bool dst_thread_data_valid, bool dst_thread_data_valid,
index_t dst_data_range) index_t dst_data_range)
{ {
BufferResourceConstant<ushort> dst_wave_buffer_resource; BufferResource<ushort> dst_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave; dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit) // wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort); dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
// wavewise setting (32 bit) // wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000; dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
...@@ -831,14 +831,14 @@ __device__ void amd_buffer_store<ushort, 4>(const ushort* p_src_thread, ...@@ -831,14 +831,14 @@ __device__ void amd_buffer_store<ushort, 4>(const ushort* p_src_thread,
bool dst_thread_data_valid, bool dst_thread_data_valid,
index_t dst_data_range) index_t dst_data_range)
{ {
BufferResourceConstant<ushort> dst_wave_buffer_resource; BufferResource<ushort> dst_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave; dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit) // wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort); dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
// wavewise setting (32 bit) // wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000; dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
...@@ -869,14 +869,14 @@ __device__ void amd_buffer_store<ushort, 8>(const ushort* p_src_thread, ...@@ -869,14 +869,14 @@ __device__ void amd_buffer_store<ushort, 8>(const ushort* p_src_thread,
bool dst_thread_data_valid, bool dst_thread_data_valid,
index_t dst_data_range) index_t dst_data_range)
{ {
BufferResourceConstant<ushort> dst_wave_buffer_resource; BufferResource<ushort> dst_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave; dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit) // wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort); dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
// wavewise setting (32 bit) // wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000; dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
...@@ -908,14 +908,14 @@ __device__ void amd_buffer_atomic_add<float, 1>(const float* p_src_thread, ...@@ -908,14 +908,14 @@ __device__ void amd_buffer_atomic_add<float, 1>(const float* p_src_thread,
bool dst_thread_data_valid, bool dst_thread_data_valid,
index_t dst_data_range) index_t dst_data_range)
{ {
BufferResourceConstant<float> dst_wave_buffer_resource; BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave; dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit) // wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float); dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit) // wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000; dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
...@@ -943,14 +943,14 @@ __device__ void amd_buffer_atomic_add<float, 2>(const float* p_src_thread, ...@@ -943,14 +943,14 @@ __device__ void amd_buffer_atomic_add<float, 2>(const float* p_src_thread,
bool dst_thread_data_valid, bool dst_thread_data_valid,
index_t dst_data_range) index_t dst_data_range)
{ {
BufferResourceConstant<float> dst_wave_buffer_resource; BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave; dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit) // wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range; dst_wave_buffer_resource.range[2] = dst_data_range;
// wavewise setting (32 bit) // wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000; dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
...@@ -988,14 +988,14 @@ __device__ void amd_buffer_atomic_add<float, 4>(const float* p_src_thread, ...@@ -988,14 +988,14 @@ __device__ void amd_buffer_atomic_add<float, 4>(const float* p_src_thread,
bool dst_thread_data_valid, bool dst_thread_data_valid,
index_t dst_data_range) index_t dst_data_range)
{ {
BufferResourceConstant<float> dst_wave_buffer_resource; BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave; dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit) // wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float); dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit) // wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000; dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
......
...@@ -5,17 +5,32 @@ ...@@ -5,17 +5,32 @@
namespace ck { namespace ck {
// For 128 bit SGPRs to supply resource constant in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
template <typename T> template <typename T>
union BufferResourceConstant union BufferResource
{ {
// 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
int32x4_t data; int32x4_t data;
T* address[2]; T* address[2];
int32_t range[4]; int32_t range[4];
int32_t config[4]; int32_t config[4];
}; };
template <typename T>
__device__ auto make_wave_buffer_resource(T* p_wave, index_t data_space_size)
{
BufferResource<T> wave_buffer_resource;
// wavewise base address (64 bit)
wave_buffer_resource.address[0] = const_cast<remove_cv_t<T>*>(p_wave);
// wavewise range (32 bit)
wave_buffer_resource.range[2] = data_space_size * sizeof(T);
// wavewise setting (32 bit)
wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
return wave_buffer_resource.data;
}
__device__ float __device__ float
__llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, __llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset, index_t voffset,
...@@ -83,18 +98,8 @@ __device__ float amd_buffer_load_v2<float, 1>(const float* p_src_wave, ...@@ -83,18 +98,8 @@ __device__ float amd_buffer_load_v2<float, 1>(const float* p_src_wave,
bool src_thread_data_valid, bool src_thread_data_valid,
index_t src_data_range) index_t src_data_range)
{ {
BufferResourceConstant<float> src_wave_buffer_resource; const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_data_range);
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
// wavewise setting (32 bit)
#if 0
src_wave_buffer_resource.config[3] = 0x00027000;
#else
src_wave_buffer_resource.config[3] = 0x31014000;
#endif
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
...@@ -102,10 +107,10 @@ __device__ float amd_buffer_load_v2<float, 1>(const float* p_src_wave, ...@@ -102,10 +107,10 @@ __device__ float amd_buffer_load_v2<float, 1>(const float* p_src_wave,
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return __llvm_amdgcn_raw_buffer_load_fp32( return __llvm_amdgcn_raw_buffer_load_fp32(
src_wave_buffer_resource.data, src_addr_shift + src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0);
#else #else
float tmp = __llvm_amdgcn_raw_buffer_load_fp32( float tmp =
src_wave_buffer_resource.data, src_thread_addr_offset, 0, 0); __llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource, src_thread_addr_offset, 0, 0);
return src_thread_data_valid ? tmp : float(0); return src_thread_data_valid ? tmp : float(0);
#endif #endif
...@@ -117,18 +122,8 @@ __device__ float2_t amd_buffer_load_v2<float, 2>(const float* p_src_wave, ...@@ -117,18 +122,8 @@ __device__ float2_t amd_buffer_load_v2<float, 2>(const float* p_src_wave,
bool src_thread_data_valid, bool src_thread_data_valid,
index_t src_data_range) index_t src_data_range)
{ {
BufferResourceConstant<float> src_wave_buffer_resource; const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_data_range);
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
// wavewise setting (32 bit)
#if 0
src_wave_buffer_resource.config[3] = 0x00027000;
#else
src_wave_buffer_resource.config[3] = 0x31014000;
#endif
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
...@@ -136,10 +131,10 @@ __device__ float2_t amd_buffer_load_v2<float, 2>(const float* p_src_wave, ...@@ -136,10 +131,10 @@ __device__ float2_t amd_buffer_load_v2<float, 2>(const float* p_src_wave,
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return __llvm_amdgcn_raw_buffer_load_fp32x2( return __llvm_amdgcn_raw_buffer_load_fp32x2(
src_wave_buffer_resource.data, src_addr_shift + src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0);
#else #else
float2_t tmp = __llvm_amdgcn_raw_buffer_load_fp32x2( float2_t tmp = __llvm_amdgcn_raw_buffer_load_fp32x2(
src_wave_buffer_resource.data, src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_thread_addr_offset, 0, 0);
return src_thread_data_valid ? tmp : float2_t(0); return src_thread_data_valid ? tmp : float2_t(0);
#endif #endif
...@@ -151,18 +146,8 @@ __device__ float4_t amd_buffer_load_v2<float, 4>(const float* p_src_wave, ...@@ -151,18 +146,8 @@ __device__ float4_t amd_buffer_load_v2<float, 4>(const float* p_src_wave,
bool src_thread_data_valid, bool src_thread_data_valid,
index_t src_data_range) index_t src_data_range)
{ {
BufferResourceConstant<float> src_wave_buffer_resource; const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_data_range);
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
// wavewise setting (32 bit)
#if 0
src_wave_buffer_resource.config[3] = 0x00027000;
#else
src_wave_buffer_resource.config[3] = 0x31014000;
#endif
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
...@@ -170,10 +155,10 @@ __device__ float4_t amd_buffer_load_v2<float, 4>(const float* p_src_wave, ...@@ -170,10 +155,10 @@ __device__ float4_t amd_buffer_load_v2<float, 4>(const float* p_src_wave,
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return __llvm_amdgcn_raw_buffer_load_fp32x4( return __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource.data, src_addr_shift + src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0);
#else #else
float4_t tmp = __llvm_amdgcn_raw_buffer_load_fp32x4( float4_t tmp = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource.data, src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_thread_addr_offset, 0, 0);
return src_thread_data_valid ? tmp : float4_t(0); return src_thread_data_valid ? tmp : float4_t(0);
#endif #endif
...@@ -185,18 +170,8 @@ __device__ float8_t amd_buffer_load_v2<float, 8>(const float* p_src_wave, ...@@ -185,18 +170,8 @@ __device__ float8_t amd_buffer_load_v2<float, 8>(const float* p_src_wave,
bool src_thread_data_valid, bool src_thread_data_valid,
index_t src_data_range) index_t src_data_range)
{ {
BufferResourceConstant<float> src_wave_buffer_resource; const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_data_range);
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
// wavewise setting (32 bit)
#if 0
src_wave_buffer_resource.config[3] = 0x00027000;
#else
src_wave_buffer_resource.config[3] = 0x31014000;
#endif
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
...@@ -206,10 +181,10 @@ __device__ float8_t amd_buffer_load_v2<float, 8>(const float* p_src_wave, ...@@ -206,10 +181,10 @@ __device__ float8_t amd_buffer_load_v2<float, 8>(const float* p_src_wave,
vector_type<float, 8> tmp; vector_type<float, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource.data, src_addr_shift + src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource.data, src_wave_buffer_resource,
src_addr_shift + src_thread_addr_offset + 4 * sizeof(float), src_addr_shift + src_thread_addr_offset + 4 * sizeof(float),
0, 0,
0); 0);
...@@ -219,10 +194,10 @@ __device__ float8_t amd_buffer_load_v2<float, 8>(const float* p_src_wave, ...@@ -219,10 +194,10 @@ __device__ float8_t amd_buffer_load_v2<float, 8>(const float* p_src_wave,
vector_type<float, 8> tmp; vector_type<float, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource.data, src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_thread_addr_offset, 0, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource.data, src_thread_addr_offset + 4 * sizeof(float), 0, 0); src_wave_buffer_resource, src_thread_addr_offset + 4 * sizeof(float), 0, 0);
return src_thread_data_valid ? tmp.Vector() : float8_t(0); return src_thread_data_valid ? tmp.Vector() : float8_t(0);
#endif #endif
...@@ -235,34 +210,21 @@ __device__ void amd_buffer_store_v2<float, 1>(const float src_thread_data, ...@@ -235,34 +210,21 @@ __device__ void amd_buffer_store_v2<float, 1>(const float src_thread_data,
const bool dst_thread_data_valid, const bool dst_thread_data_valid,
const index_t dst_data_range) const index_t dst_data_range)
{ {
BufferResourceConstant<float> dst_wave_buffer_resource; const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_data_range);
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit)
#if 0
dst_wave_buffer_resource.config[3] = 0x00027000;
#else
dst_wave_buffer_resource.config[3] = 0x31014000;
#endif
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_raw_buffer_store_fp32(src_thread_data, __llvm_amdgcn_raw_buffer_store_fp32(
dst_wave_buffer_resource.data, src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0, 0);
dst_addr_shift + dst_thread_addr_offset,
0,
0);
#else #else
if(dst_thread_data_valid) if(dst_thread_data_valid)
{ {
__llvm_amdgcn_buffer_store_fp32( __llvm_amdgcn_buffer_store_fp32(
src_thread_data, dst_wave_buffer_resource.data, dst_thread_addr_offset, 0, 0); src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0, 0);
} }
#endif #endif
} }
...@@ -274,34 +236,21 @@ __device__ void amd_buffer_store_v2<float, 2>(const float2_t src_thread_data, ...@@ -274,34 +236,21 @@ __device__ void amd_buffer_store_v2<float, 2>(const float2_t src_thread_data,
const bool dst_thread_data_valid, const bool dst_thread_data_valid,
const index_t dst_data_range) const index_t dst_data_range)
{ {
BufferResourceConstant<float> dst_wave_buffer_resource; const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_data_range);
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit)
#if 0
dst_wave_buffer_resource.config[3] = 0x00027000;
#else
dst_wave_buffer_resource.config[3] = 0x31014000;
#endif
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data, __llvm_amdgcn_raw_buffer_store_fp32x2(
dst_wave_buffer_resource.data, src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0, 0);
dst_addr_shift + dst_thread_addr_offset,
0,
0);
#else #else
if(dst_thread_data_valid) if(dst_thread_data_valid)
{ {
__llvm_amdgcn_raw_buffer_store_fp32x2( __llvm_amdgcn_raw_buffer_store_fp32x2(
src_thread_data, dst_wave_buffer_resource.data, dst_thread_addr_offset, 0, 0); src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0, 0);
} }
#endif #endif
} }
...@@ -313,34 +262,21 @@ __device__ void amd_buffer_store_v2<float, 4>(const float4_t src_thread_data, ...@@ -313,34 +262,21 @@ __device__ void amd_buffer_store_v2<float, 4>(const float4_t src_thread_data,
const bool dst_thread_data_valid, const bool dst_thread_data_valid,
const index_t dst_data_range) const index_t dst_data_range)
{ {
BufferResourceConstant<float> dst_wave_buffer_resource; const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_data_range);
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit)
#if 0
dst_wave_buffer_resource.config[3] = 0x00027000;
#else
dst_wave_buffer_resource.config[3] = 0x31014000;
#endif
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data, __llvm_amdgcn_raw_buffer_store_fp32x4(
dst_wave_buffer_resource.data, src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0, 0);
dst_addr_shift + dst_thread_addr_offset,
0,
0);
#else #else
if(dst_thread_data_valid) if(dst_thread_data_valid)
{ {
__llvm_amdgcn_raw_buffer_store_fp32x4( __llvm_amdgcn_raw_buffer_store_fp32x4(
src_thread_data, dst_wave_buffer_resource.data, dst_thread_addr_offset, 0, 0); src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0, 0);
} }
#endif #endif
} }
......
...@@ -7,6 +7,20 @@ ...@@ -7,6 +7,20 @@
#endif #endif
#include "bfloat16_dev.hpp" #include "bfloat16_dev.hpp"
#if 1
#define CK_AMD_GPU_GFX906 1
#elif 0
#define CK_AMD_GPU_GFX908 1
#else
#define CK_AMD_GPU_GFX906 1
#endif
#if defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(CK_AMD_GPU_GFX_1030)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#endif
#ifndef CK_HIP_VERSION_FLAT #ifndef CK_HIP_VERSION_FLAT
#define CK_HIP_VERSION_FLAT 0 #define CK_HIP_VERSION_FLAT 0
#endif #endif
......
...@@ -56,7 +56,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -56,7 +56,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1 #if 0
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
...@@ -167,14 +167,14 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -167,14 +167,14 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 16; constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4; constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4; constexpr index_t ThreadGemmDataPerReadN = 4;
......
...@@ -51,8 +51,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc, ...@@ -51,8 +51,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc,
// compile-time variables // compile-time variables
constexpr auto in_n_hi_wi_c_desc = constexpr auto in_n_hi_wi_c_desc =
make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{}); make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{});
constexpr auto wei_k_y_x_c_desc = constexpr auto wei_k_y_x_c_desc = make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{});
make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{});
constexpr auto out_n_ho_wo_k_desc = constexpr auto out_n_ho_wo_k_desc =
make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{}); make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{});
......
...@@ -39,7 +39,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc ...@@ -39,7 +39,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1 #if 0
// run-time variables // run-time variables
const auto in_n_c_hi_wi_desc = const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths())); make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
......
...@@ -109,7 +109,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -109,7 +109,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data()); wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data());
out_nhwk_device_buf.ToDevice(out_nhwk.mData.data()); out_nhwk_device_buf.ToDevice(out_nhwk.mData.data());
#if 0 #if 1
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
......
...@@ -63,7 +63,7 @@ int main(int argc, char* argv[]) ...@@ -63,7 +63,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 4;
constexpr index_t HI = 1080; constexpr index_t HI = 1080;
...@@ -77,7 +77,7 @@ int main(int argc, char* argv[]) ...@@ -77,7 +77,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 1 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 4;
constexpr index_t HI = 1080; constexpr index_t HI = 1080;
...@@ -689,7 +689,7 @@ int main(int argc, char* argv[]) ...@@ -689,7 +689,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 0
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment