Commit 5602817f authored by Chao Liu
Browse files

add back buffer addressing changes

parent 95593106
...@@ -6,6 +6,17 @@ ...@@ -6,6 +6,17 @@
namespace ck { namespace ck {
// Union giving several views of one buffer resource descriptor, so that
// individual fields can be written through one member and the packed value
// read back through `data`.
// NOTE(review): the sibling BufferResource_v2 in this change is filled in as
// address[0] = base pointer, range[2] = size in bytes, config[3] =
// CK_BUFFER_RESOURCE_3RD_DWORD, and then `data` is returned; presumably this
// union is meant to be used the same way — confirm against its callers.
template <typename T>
union BufferResource
{
// 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
// packed descriptor; int32x4_t is the type the buffer intrinsics take as their resource argument
int32x4_t data;
// pointer view of the descriptor (two T* slots)
T* address[2];
// 32-bit word view of the descriptor, named for the range field
int32_t range[4];
// 32-bit word view of the descriptor, named for the config field
int32_t config[4];
};
__device__ float __llvm_amdgcn_buffer_load_f32(int32x4_t srsrc, __device__ float __llvm_amdgcn_buffer_load_f32(int32x4_t srsrc,
index_t vindex, index_t vindex,
index_t offset, index_t offset,
......
...@@ -6,27 +6,27 @@ ...@@ -6,27 +6,27 @@
namespace ck { namespace ck {
template <typename T> template <typename T>
union BufferResource union BufferResource_v2
{ {
// 128 bit SGPRs to supply buffer resource in buffer instructions // 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
int32x4_t data; int32x4_t data;
T* address[2]; StaticallyIndexedArray<T*, 2> address;
int32_t range[4]; StaticallyIndexedArray<int32_t, 4> range;
int32_t config[4]; StaticallyIndexedArray<int32_t, 4> config;
}; };
template <typename T> template <typename T>
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size) __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size)
{ {
BufferResource<T> wave_buffer_resource; BufferResource_v2<T> wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
wave_buffer_resource.address[0] = const_cast<remove_cv_t<T>*>(p_wave); wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
// wavewise range (32 bit) // wavewise range (32 bit)
wave_buffer_resource.range[2] = data_space_size * sizeof(T); wave_buffer_resource.range(Number<2>{}) = data_space_size * sizeof(T);
// wavewise setting (32 bit) // wavewise setting (32 bit)
wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
return wave_buffer_resource.data; return wave_buffer_resource.data;
} }
...@@ -37,6 +37,19 @@ __llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, ...@@ -37,6 +37,19 @@ __llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
// Declaration of a raw buffer load returning two int8 values (int8x2_t).
// The __asm label names the symbol "llvm.amdgcn.raw.buffer.load.v2i8",
// presumably so the compiler resolves the call to that LLVM intrinsic.
// As called from amd_buffer_load_impl_v2 in this change: srsrc is the wave
// buffer resource, voffset the per-thread address offset, soffset the wave
// address offset, and glc_slc is passed as 0.
__device__ int8x2_t
__llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
// Declaration of a raw buffer load returning four int8 values (int8x4_t).
// The __asm label names the symbol "llvm.amdgcn.raw.buffer.load.v4i8",
// presumably so the compiler resolves the call to that LLVM intrinsic.
// As called from amd_buffer_load_impl_v2 in this change: srsrc is the wave
// buffer resource, voffset the per-thread address offset, soffset the wave
// address offset, and glc_slc is passed as 0. The N == 8 case issues this
// load twice, the second time with soffset advanced by 4 * sizeof(int8_t).
// NOTE(review): a comment elsewhere in this change says int8x4_t is defined
// as int32_t; confirm that this return type is compatible with the v4i8
// intrinsic signature.
__device__ int8x4_t
__llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
__device__ int16_t __device__ int16_t
__llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, __llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t voffset, index_t voffset,
...@@ -105,6 +118,20 @@ __llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, ...@@ -105,6 +118,20 @@ __llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
// Declaration of a raw buffer store taking two int8 values (int8x2_t vdata).
// The __asm label names the symbol "llvm.amdgcn.raw.buffer.store.v2i8",
// presumably so the compiler resolves the call to that LLVM intrinsic.
// NOTE(review): no call site is visible here; rsrc/voffset/soffset/glc_slc
// presumably mirror the matching load declarations — confirm.
__device__ void
__llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
// Declaration of a raw buffer store taking four int8 values (int8x4_t vdata).
// The __asm label names the symbol "llvm.amdgcn.raw.buffer.store.v4i8",
// presumably so the compiler resolves the call to that LLVM intrinsic.
// NOTE(review): no call site is visible here; rsrc/voffset/soffset/glc_slc
// presumably mirror the matching load declarations — confirm. A comment
// elsewhere in this change says int8x4_t is defined as int32_t; confirm that
// the vdata type is compatible with the v4i8 intrinsic signature.
__device__ void
__llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_i16(int16_t vdata, __llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
...@@ -183,6 +210,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -183,6 +210,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
index_t src_wave_addr_offset) index_t src_wave_addr_offset)
{ {
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, half2_t>::value && (N == 1)) || (is_same<T, half2_t>::value && (N == 1)) ||
(is_same<T, half4_t>::value && (N == 1)) || (is_same<T, half4_t>::value && (N == 1)) ||
...@@ -306,6 +334,38 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -306,6 +334,38 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
src_thread_addr_offset, src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t), src_wave_addr_offset + 4 * sizeof(int32_t),
0); 0);
return tmp.Vector();
}
}
else if constexpr(is_same<T, int8_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_i8(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
return __llvm_amdgcn_raw_buffer_load_i8x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 4)
{
return __llvm_amdgcn_raw_buffer_load_i8x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 8)
{
vector_type<int8_t, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) =
__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int8_t),
0);
return tmp.Vector(); return tmp.Vector();
} }
......
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
#define CK_DEVICE_BACKEND_AMD 1 #define CK_DEVICE_BACKEND_AMD 1
// GPU ID // GPU ID
#if 0 #if 1
#define CK_AMD_GPU_GFX906 1 #define CK_AMD_GPU_GFX906 1
#elif 0 #elif 0
#define CK_AMD_GPU_GFX908 1 #define CK_AMD_GPU_GFX908 1
#elif 1 #elif 0
#define CK_AMD_GPU_GFX1030 1 #define CK_AMD_GPU_GFX1030 1
#endif #endif
...@@ -88,7 +88,7 @@ ...@@ -88,7 +88,7 @@
// experimental implementation // experimental implementation
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif #endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
......
...@@ -437,6 +437,7 @@ struct vector_type<int8_t, 16> ...@@ -437,6 +437,7 @@ struct vector_type<int8_t, 16>
// i8 // i8
// hack for int8x4_t, because compiler does not have native support for int8x4_t // hack for int8x4_t, because compiler does not have native support for int8x4_t
// int8x4_t is defined as int32_t // int8x4_t is defined as int32_t
using int8x2_t = typename vector_type<int8_t, 2>::type;
using int8x4_t = typename vector_type<int8_t, 4>::type; using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = typename vector_type<int8_t, 8>::type; using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type; using int8x16_t = typename vector_type<int8_t, 16>::type;
......
...@@ -40,7 +40,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( ...@@ -40,7 +40,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
#if 0 #if 1
// run-time variables // run-time variables
const auto in_n_c_hi_wi_desc = const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths())); make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
...@@ -167,7 +167,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( ...@@ -167,7 +167,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 1 #elif 0
// cdata = 64, BlockSize 64, 16x256x4 // cdata = 64, BlockSize 64, 16x256x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
......
...@@ -53,7 +53,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -53,7 +53,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr auto C0 = C / Number<InWeiVectorSize>{}; constexpr auto C0 = C / Number<InWeiVectorSize>{};
constexpr auto C1 = Number<InWeiVectorSize>{}; constexpr auto C1 = Number<InWeiVectorSize>{};
#if 0 #if 1
// run-time variables // run-time variables
constexpr auto in_n_hi_wi_c0_desc = constexpr auto in_n_hi_wi_c0_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0)); make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
...@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
#if 1 #if 0
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
...@@ -211,7 +211,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -211,7 +211,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
#elif 1 #elif 0
// cdata = 64, BlockSize = 64, 16x256x4 // cdata = 64, BlockSize = 64, 16x256x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
...@@ -310,7 +310,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -310,7 +310,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
#elif 0 #elif 1
// cdata = 64, BlockSize = 256, 128x128x8 // cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
......
...@@ -83,10 +83,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -83,10 +83,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
const auto out_n_k0_ho_wo_k1_desc = const auto out_n_k0_ho_wo_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{}); const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif #endif
Tensor<TInWei> in_n_c0_hi_wi_c1(make_HostTensorDescriptor( Tensor<TInWei> in_n_c0_hi_wi_c1(make_HostTensorDescriptor(
......
...@@ -48,8 +48,8 @@ int main(int argc, char* argv[]) ...@@ -48,8 +48,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
...@@ -62,9 +62,9 @@ int main(int argc, char* argv[]) ...@@ -62,9 +62,9 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
constexpr index_t HI = 1080; constexpr index_t HI = 1080;
...@@ -150,7 +150,7 @@ int main(int argc, char* argv[]) ...@@ -150,7 +150,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
// 3x3, 71x71 // 3x3, 71x71
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 192; constexpr index_t C = 192;
...@@ -630,7 +630,7 @@ int main(int argc, char* argv[]) ...@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
print_array("ConvStrides", to_multi_index(ConvStrides{})); print_array("ConvStrides", to_multi_index(ConvStrides{}));
print_array("ConvDilations", to_multi_index(ConvDilations{})); print_array("ConvDilations", to_multi_index(ConvDilations{}));
#if 0 #if 1
using in_data_t = float; using in_data_t = float;
constexpr index_t in_vector_size = 1; constexpr index_t in_vector_size = 1;
using acc_data_t = float; using acc_data_t = float;
...@@ -724,7 +724,7 @@ int main(int argc, char* argv[]) ...@@ -724,7 +724,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 0 #elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
in_vector_size, in_vector_size,
acc_data_t, acc_data_t,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment