Commit 3321471c authored by Jing Zhang's avatar Jing Zhang
Browse files

add fp16

parent 0c883faa
...@@ -59,7 +59,26 @@ __llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, ...@@ -59,7 +59,26 @@ __llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
// half
__device__ half_t
__llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
__device__ half2_t
__llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
__device__ half4_t
__llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
// float
__device__ float __device__ float
__llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, __llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset, index_t voffset,
...@@ -114,6 +133,28 @@ __llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, ...@@ -114,6 +133,28 @@ __llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
// half
__device__ void
__llvm_amdgcn_raw_buffer_store_fp16(half_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
__device__ void
__llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
__device__ void
__llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
// float
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_fp32(float vdata, __llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc, int32x4_t rsrc,
...@@ -142,6 +183,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -142,6 +183,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
index_t src_wave_addr_offset) index_t src_wave_addr_offset)
{ {
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, half2_t>::value && (N == 1)) ||
(is_same<T, half4_t>::value && (N == 1)) ||
(is_same<T, half8_t>::value && (N == 1)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32x2_t>::value && (N == 1)) || (is_same<T, int32x2_t>::value && (N == 1)) ||
(is_same<T, int32x4_t>::value && (N == 1)), (is_same<T, int32x4_t>::value && (N == 1)),
...@@ -177,6 +222,55 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -177,6 +222,55 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
return tmp.Vector(); return tmp.Vector();
} }
} }
else if constexpr(is_same<T, half_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_fp16(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
return __llvm_amdgcn_raw_buffer_load_fp16x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 4)
{
return __llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
else if constexpr(is_same<T, half2_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_fp16x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
else if constexpr(is_same<T, half4_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
else if constexpr(is_same<T, half8_t>::value)
{
if constexpr(N == 1)
{
vector_type<half_t, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return tmp.Vector();
}
}
else if constexpr(is_same<T, int32_t>::value) else if constexpr(is_same<T, int32_t>::value)
{ {
if constexpr(N == 1) if constexpr(N == 1)
...@@ -234,7 +328,8 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type ...@@ -234,7 +328,8 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
static_assert( static_assert(
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
"wrong! not implemented"); "wrong! not implemented");
if constexpr(is_same<T, float>::value) if constexpr(is_same<T, float>::value)
...@@ -334,6 +429,50 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type ...@@ -334,6 +429,50 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
0); 0);
} }
} }
else if constexpr(is_same<T, half_t>::value)
{
if constexpr(N == 1)
{
__llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
__llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 8)
{
vector_type<half_t, 8> tmp;
tmp.Vector() = src_thread_data;
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
} }
// buffer_load requires: // buffer_load requires:
......
...@@ -166,6 +166,30 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a, ...@@ -166,6 +166,30 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
"3"(c3)); "3"(c3));
} }
__device__ void amd_assembly_outer_product_1x4(half8_t a,
half8_t b0,
half8_t b1,
half8_t b2,
half8_t b3,
float& c0,
float& c1,
float& c2,
float& c3)
{
const half4_t* p_a_half4 = reinterpret_cast<const half4_t*>(&a);
const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0);
const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1);
const half4_t* p_b2_half4 = reinterpret_cast<const half4_t*>(&b2);
const half4_t* p_b3_half4 = reinterpret_cast<const half4_t*>(&b3);
amd_assembly_outer_product_1x4(
p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
amd_assembly_outer_product_1x4(
p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
}
// c0 += inner_product(a, b0) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1) // c1 += inner_product(a, b1)
__device__ void __device__ void
......
...@@ -53,7 +53,7 @@ ...@@ -53,7 +53,7 @@
// AMD buffer addressing // AMD buffer addressing
#ifndef CK_USE_AMD_BUFFER_ADDRESSING #ifndef CK_USE_AMD_BUFFER_ADDRESSING
#define CK_USE_AMD_BUFFER_ADDRESSING 0 #define CK_USE_AMD_BUFFER_ADDRESSING 1
#endif #endif
// only gfx908 support native floating point atomic add // only gfx908 support native floating point atomic add
......
...@@ -118,16 +118,16 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -118,16 +118,16 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr index_t KPerBlock = 16; constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 8; constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 64; constexpr index_t WoPerBlock = 32;
constexpr index_t EPerBlock = 1; constexpr index_t EPerBlock = 2;
constexpr index_t KPerThread = KPerBlock; constexpr index_t KPerThread = KPerBlock;
constexpr index_t HoPerThread = 4; constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock; constexpr index_t EPerThread = EPerBlock;
using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>; using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>; using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, KPerBlock>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1; constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1; constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
...@@ -164,7 +164,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -164,7 +164,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
#endif #endif
constexpr auto conv_driver = constexpr auto conv_driver =
//DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad< // DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad< DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad<
BlockSize, BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type, typename vector_type<TInWei, InWeiVectorSize>::type,
......
...@@ -76,14 +76,14 @@ int main(int argc, char* argv[]) ...@@ -76,14 +76,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 1
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
constexpr index_t HI = 1080; constexpr index_t HI = 1080;
constexpr index_t WI = 1920; constexpr index_t WI = 1920;
constexpr index_t K = 4; constexpr index_t K = 16;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
...@@ -106,7 +106,7 @@ int main(int argc, char* argv[]) ...@@ -106,7 +106,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 1 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
constexpr index_t HI = 540; constexpr index_t HI = 540;
...@@ -663,12 +663,17 @@ int main(int argc, char* argv[]) ...@@ -663,12 +663,17 @@ int main(int argc, char* argv[])
constexpr index_t in_vector_size = 1; constexpr index_t in_vector_size = 1;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = float; using out_data_t = float;
#elif 1
using in_data_t = half_t;
constexpr index_t in_vector_size = 8;
using acc_data_t = float;
using out_data_t = half_t;
#elif 0 #elif 0
using in_data_t = float; using in_data_t = float;
constexpr index_t in_vector_size = 1; constexpr index_t in_vector_size = 1;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = int8_t; using out_data_t = int8_t;
#elif 1 #elif 0
using in_data_t = int8_t; using in_data_t = int8_t;
constexpr index_t in_vector_size = 16; constexpr index_t in_vector_size = 16;
using acc_data_t = int32_t; using acc_data_t = int32_t;
...@@ -816,6 +821,7 @@ int main(int argc, char* argv[]) ...@@ -816,6 +821,7 @@ int main(int argc, char* argv[])
check_error(out_nkhw_host, out_nkhw_device); check_error(out_nkhw_host, out_nkhw_device);
#if 0
if(do_log) if(do_log)
{ {
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl; LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
...@@ -823,5 +829,6 @@ int main(int argc, char* argv[]) ...@@ -823,5 +829,6 @@ int main(int argc, char* argv[])
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl; LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl; LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
} }
#endif
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment