Commit 3321471c authored by Jing Zhang's avatar Jing Zhang
Browse files

add fp16

parent 0c883faa
......@@ -59,7 +59,26 @@ __llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
// half
__device__ half_t
__llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
__device__ half2_t
__llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
__device__ half4_t
__llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
// float
__device__ float
__llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset,
......@@ -114,6 +133,28 @@ __llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
// half
__device__ void
__llvm_amdgcn_raw_buffer_store_fp16(half_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
__device__ void
__llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
__device__ void
__llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
// float
__device__ void
__llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc,
......@@ -142,6 +183,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
index_t src_wave_addr_offset)
{
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, half2_t>::value && (N == 1)) ||
(is_same<T, half4_t>::value && (N == 1)) ||
(is_same<T, half8_t>::value && (N == 1)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32x2_t>::value && (N == 1)) ||
(is_same<T, int32x4_t>::value && (N == 1)),
......@@ -177,6 +222,55 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
return tmp.Vector();
}
}
else if constexpr(is_same<T, half_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_fp16(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
return __llvm_amdgcn_raw_buffer_load_fp16x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 4)
{
return __llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
else if constexpr(is_same<T, half2_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_fp16x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
else if constexpr(is_same<T, half4_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
else if constexpr(is_same<T, half8_t>::value)
{
if constexpr(N == 1)
{
vector_type<half_t, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return tmp.Vector();
}
}
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
......@@ -234,7 +328,8 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
static_assert(
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
"wrong! not implemented");
if constexpr(is_same<T, float>::value)
......@@ -334,6 +429,50 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
0);
}
}
else if constexpr(is_same<T, half_t>::value)
{
if constexpr(N == 1)
{
__llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
__llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 8)
{
vector_type<half_t, 8> tmp;
tmp.Vector() = src_thread_data;
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
}
// buffer_load requires:
......
......@@ -166,6 +166,30 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
"3"(c3));
}
__device__ void amd_assembly_outer_product_1x4(half8_t a,
half8_t b0,
half8_t b1,
half8_t b2,
half8_t b3,
float& c0,
float& c1,
float& c2,
float& c3)
{
const half4_t* p_a_half4 = reinterpret_cast<const half4_t*>(&a);
const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0);
const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1);
const half4_t* p_b2_half4 = reinterpret_cast<const half4_t*>(&b2);
const half4_t* p_b3_half4 = reinterpret_cast<const half4_t*>(&b3);
amd_assembly_outer_product_1x4(
p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
amd_assembly_outer_product_1x4(
p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
......
......@@ -53,7 +53,7 @@
// AMD buffer addressing
#ifndef CK_USE_AMD_BUFFER_ADDRESSING
#define CK_USE_AMD_BUFFER_ADDRESSING 0
#define CK_USE_AMD_BUFFER_ADDRESSING 1
#endif
// only gfx908 support native floating point atomic add
......
......@@ -118,16 +118,16 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 64;
constexpr index_t EPerBlock = 1;
constexpr index_t WoPerBlock = 32;
constexpr index_t EPerBlock = 2;
constexpr index_t KPerThread = KPerBlock;
constexpr index_t HoPerThread = 4;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock;
using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, KPerBlock>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
......@@ -164,7 +164,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
#endif
constexpr auto conv_driver =
//DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
// DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad<
BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type,
......
......@@ -78,12 +78,12 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
#elif 1
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 1080;
constexpr index_t WI = 1920;
constexpr index_t K = 4;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
......@@ -106,7 +106,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 540;
......@@ -663,12 +663,17 @@ int main(int argc, char* argv[])
constexpr index_t in_vector_size = 1;
using acc_data_t = float;
using out_data_t = float;
#elif 1
using in_data_t = half_t;
constexpr index_t in_vector_size = 8;
using acc_data_t = float;
using out_data_t = half_t;
#elif 0
using in_data_t = float;
constexpr index_t in_vector_size = 1;
using acc_data_t = float;
using out_data_t = int8_t;
#elif 1
#elif 0
using in_data_t = int8_t;
constexpr index_t in_vector_size = 16;
using acc_data_t = int32_t;
......@@ -816,6 +821,7 @@ int main(int argc, char* argv[])
check_error(out_nkhw_host, out_nkhw_device);
#if 0
if(do_log)
{
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
......@@ -823,5 +829,6 @@ int main(int argc, char* argv[])
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
}
#endif
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment