Commit 86c5f995 authored by Jing Zhang

add fp16 by 16

parent 1d011fef
@@ -190,6 +190,29 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a,
         p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
 }
 
+__device__ void amd_assembly_outer_product_1x4(half16_t a,
+                                               half16_t b0,
+                                               half16_t b1,
+                                               half16_t b2,
+                                               half16_t b3,
+                                               float& c0,
+                                               float& c1,
+                                               float& c2,
+                                               float& c3)
+{
+    const half8_t* p_a_half8  = reinterpret_cast<const half8_t*>(&a);
+    const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0);
+    const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1);
+    const half8_t* p_b2_half8 = reinterpret_cast<const half8_t*>(&b2);
+    const half8_t* p_b3_half8 = reinterpret_cast<const half8_t*>(&b3);
+
+    amd_assembly_outer_product_1x4(
+        p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
+
+    amd_assembly_outer_product_1x4(
+        p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3);
+}
+
 // c0 += inner_product(a, b0)
 // c1 += inner_product(a, b1)
 __device__ void
...
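Note: the new half16_t overload adds no new assembly; it reinterprets each 16-wide vector as two half8_t halves and forwards to the existing half8_t overload twice. A minimal, hypothetical call site (not part of this commit) would be:

    // Hypothetical usage: accumulate four fp16 inner products into float
    // accumulators via the new 16-wide overload.
    __device__ void outer_product_16(half16_t a,
                                     half16_t b0, half16_t b1, half16_t b2, half16_t b3,
                                     float& c0, float& c1, float& c2, float& c3)
    {
        // Internally splits each operand into two half8_t sub-vectors.
        amd_assembly_outer_product_1x4(a, b0, b1, b2, b3, c0, c1, c2, c3);
    }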
@@ -11,9 +11,9 @@
 #define CK_DEVICE_BACKEND_AMD 1
 
 // GPU ID
-#define CK_AMD_GPU_GFX906 1
+#define CK_AMD_GPU_GFX906 0
 #define CK_AMD_GPU_GFX908 0
-#define CK_AMD_GPU_GFX1030 0
+#define CK_AMD_GPU_GFX1030 1
 
 // HIP version
 #ifndef CK_HIP_VERSION_FLAT
@@ -53,7 +53,7 @@
 // AMD buffer addressing
 #ifndef CK_USE_AMD_BUFFER_ADDRESSING
-#define CK_USE_AMD_BUFFER_ADDRESSING 1
+#define CK_USE_AMD_BUFFER_ADDRESSING 0
 #endif
 
 // only gfx908 support native floating point atomic add
...
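These macros are compile-time switches for target-specific code paths; this change retargets the build from gfx906 to gfx1030 and disables AMD buffer addressing there. A sketch of the assumed gating pattern (illustrative only; the gfx908 comment above is the only atomic-add guarantee stated in this header):

    // Assumed usage of the GPU ID flags set above.
    #if CK_AMD_GPU_GFX908
        // gfx908: native float atomic add is available per the comment above.
    #elif CK_AMD_GPU_GFX1030
        // gfx1030: use a fallback accumulation path instead of float atomics.
    #endif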
@@ -168,6 +168,63 @@ struct vector_type<T, 8>
     __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
 };
 
+template <typename T>
+struct vector_type<T, 16>
+{
+    using d1_t = T;
+    typedef T d2_t __attribute__((ext_vector_type(2)));
+    typedef T d4_t __attribute__((ext_vector_type(4)));
+    typedef T d8_t __attribute__((ext_vector_type(8)));
+    typedef T d16_t __attribute__((ext_vector_type(16)));
+
+    using type = d16_t;
+
+    union
+    {
+        d16_t d16_;
+        StaticallyIndexedArray<d1_t, 16> d1x16_;
+        StaticallyIndexedArray<d2_t, 8> d2x8_;
+        StaticallyIndexedArray<d4_t, 4> d4x4_;
+        StaticallyIndexedArray<d8_t, 2> d8x2_;
+        StaticallyIndexedArray<d16_t, 1> d16x1_;
+    } data_;
+
+    __host__ __device__ constexpr vector_type() : data_{type{0}} {}
+
+    __host__ __device__ constexpr vector_type(type v) : data_{v} {}
+
+    __host__ __device__ static constexpr index_t Size() { return 16; }
+
+    __host__ __device__ constexpr const auto& Vector() const { return data_.d16_; }
+    __host__ __device__ constexpr auto& Vector() { return data_.d16_; }
+
+    __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x16_; }
+    __host__ __device__ constexpr auto& Scalars() { return data_.d1x16_; }
+
+    __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x16_; }
+    __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x8_; }
+    __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x4_; }
+    __host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x2_; }
+    __host__ __device__ constexpr const auto& Vectors(Number<16>) const { return data_.d16x1_; }
+
+    __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x16_; }
+    __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x8_; }
+    __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x4_; }
+    __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x2_; }
+    __host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; }
+};
+
 // fp32
 using float2_t = typename vector_type<float, 2>::type;
 using float4_t = typename vector_type<float, 4>::type;
@@ -178,6 +235,7 @@ using half_t = _Float16;
 using half2_t = typename vector_type<half_t, 2>::type;
 using half4_t = typename vector_type<half_t, 4>::type;
 using half8_t = typename vector_type<half_t, 8>::type;
+using half16_t = typename vector_type<half_t, 16>::type;
 
 // bfp16
 using ushort2_t = typename vector_type<ushort, 2>::type;
...
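The vector_type<T, 16> specialization exposes the same sliced views as the narrower specializations. A hedged sketch of slicing a half16_t into its two half8_t halves, assuming StaticallyIndexedArray supports Number<>-indexed operator[] as its use elsewhere in the library suggests:

    // Hypothetical example: view a half16_t as two half8_t sub-vectors.
    __device__ void slice_half16(half16_t v)
    {
        vector_type<half_t, 16> vec{v};
        const auto& halves = vec.Vectors(Number<8>{}); // StaticallyIndexedArray<half8_t, 2>
        half8_t lo = halves[Number<0>{}]; // elements 0..7
        half8_t hi = halves[Number<1>{}]; // elements 8..15
        (void)lo;
        (void)hi;
    }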
@@ -116,18 +116,18 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
     // cdata = 64, BlockSize = 64, 16x8x32x4
     constexpr index_t BlockSize = 64;
 
-    constexpr index_t KPerBlock = 16;
+    constexpr index_t KPerBlock = K;
     constexpr index_t HoPerBlock = 8;
     constexpr index_t WoPerBlock = 32;
-    constexpr index_t EPerBlock = 2;
+    constexpr index_t EPerBlock = C0;
 
     constexpr index_t KPerThread = KPerBlock;
     constexpr index_t HoPerThread = 2;
     constexpr index_t WoPerThread = 2;
     constexpr index_t EPerThread = EPerBlock;
 
-    using ABlockTransferThreadSliceLengths_E_K   = Sequence<9, 1>;
-    using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, KPerBlock>;
+    using ABlockTransferThreadSliceLengths_E_K   = Sequence<3, 1>;
+    using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;
 
     constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
     constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
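Assuming, as the parameter names suggest, that per-thread slice lengths times thread-cluster lengths give the per-block A tile, this retuning leaves the tile shape unchanged and only redistributes it across threads:

    // Before: Sequence<9, 1> slice x Sequence<EPerBlock, KPerBlock> cluster
    //         => tile of (9 * EPerBlock) x KPerBlock
    // After:  Sequence<3, 1> slice x Sequence<3 * EPerBlock, KPerBlock> cluster
    //         => tile of (3 * 3 * EPerBlock) x KPerBlock = (9 * EPerBlock) x KPerBlock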
@@ -164,7 +164,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
 #endif
 
     constexpr auto conv_driver =
-        // DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
+        //DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
         DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad<
             BlockSize,
             typename vector_type<TInWei, InWeiVectorSize>::type,
...
@@ -50,7 +50,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
     constexpr index_t N = 1;
     constexpr index_t C = 16;
     constexpr index_t HI = 270;
@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<1, 1>;
     using RightPads = Sequence<1, 1>;
-#elif 1
+#elif 0
     constexpr index_t N = 1;
     constexpr index_t C = 16;
     constexpr index_t HI = 1080;
@@ -665,7 +665,7 @@ int main(int argc, char* argv[])
     using out_data_t = float;
 #elif 1
     using in_data_t = half_t;
-    constexpr index_t in_vector_size = 8;
+    constexpr index_t in_vector_size = 16;
     using acc_data_t = float;
     using out_data_t = half_t;
 #elif 0
...
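Setting in_vector_size = 16 with in_data_t = half_t is what exercises the new 16-wide path end to end: the driver's vector type is derived from these values (see the TInWei/InWeiVectorSize parameters in the driver hunk above). A minimal sketch of that derivation:

    // Sketch of the type derivation (names taken from the hunks above).
    using TInWei = half_t;
    constexpr index_t InWeiVectorSize = 16;
    // Resolves to half16_t via the new vector_type<half_t, 16> specialization.
    using in_wei_vector_t = typename vector_type<TInWei, InWeiVectorSize>::type;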