Commit 86c5f995 authored by Jing Zhang

add fp16 by 16

parent 1d011fef
@@ -190,6 +190,29 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a,
         p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
 }
 
+__device__ void amd_assembly_outer_product_1x4(half16_t a,
+                                               half16_t b0,
+                                               half16_t b1,
+                                               half16_t b2,
+                                               half16_t b3,
+                                               float& c0,
+                                               float& c1,
+                                               float& c2,
+                                               float& c3)
+{
+    const half8_t* p_a_half8  = reinterpret_cast<const half8_t*>(&a);
+    const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0);
+    const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1);
+    const half8_t* p_b2_half8 = reinterpret_cast<const half8_t*>(&b2);
+    const half8_t* p_b3_half8 = reinterpret_cast<const half8_t*>(&b3);
+
+    amd_assembly_outer_product_1x4(
+        p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
+
+    amd_assembly_outer_product_1x4(
+        p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3);
+}
+
 // c0 += inner_product(a, b0)
 // c1 += inner_product(a, b1)
 __device__ void
......
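For orientation, a minimal usage sketch of the overload added above (the wrapper name is hypothetical; only amd_assembly_outer_product_1x4 and the half16_t/float types come from the diff). It accumulates four fp16 dot products into float, with the overload internally splitting each 16-wide input into two half8_t halves:

// Illustrative device code; assumes the composable_kernel headers that define
// half16_t and amd_assembly_outer_product_1x4 are included.
__device__ void thread_outer_product_1x4_fp16x16(half16_t a,
                                                  half16_t b0,
                                                  half16_t b1,
                                                  half16_t b2,
                                                  half16_t b3,
                                                  float& c0,
                                                  float& c1,
                                                  float& c2,
                                                  float& c3)
{
    // c_j += sum over the 16 fp16 lanes of a[i] * b_j[i], for j = 0..3
    amd_assembly_outer_product_1x4(a, b0, b1, b2, b3, c0, c1, c2, c3);
}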
@@ -11,9 +11,9 @@
 #define CK_DEVICE_BACKEND_AMD 1
 
 // GPU ID
-#define CK_AMD_GPU_GFX906 1
+#define CK_AMD_GPU_GFX906 0
 #define CK_AMD_GPU_GFX908 0
-#define CK_AMD_GPU_GFX1030 0
+#define CK_AMD_GPU_GFX1030 1
 
 // HIP version
 #ifndef CK_HIP_VERSION_FLAT
@@ -53,7 +53,7 @@
 
 // AMD buffer addressing
 #ifndef CK_USE_AMD_BUFFER_ADDRESSING
-#define CK_USE_AMD_BUFFER_ADDRESSING 1
+#define CK_USE_AMD_BUFFER_ADDRESSING 0
 #endif
 
 // only gfx908 support native floating point atomic add
......
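Note on the two config hunks above: source files select GPU-specific paths by branching on these macros, so this commit retargets the build from gfx906 to gfx1030 and turns buffer addressing off for it. A purely illustrative sketch of that kind of guard (the comments describe an assumption about why the flag is flipped, not a documented gfx1030 limitation):

// Illustrative guard only; the actual guarded code lives elsewhere in the library.
#if CK_USE_AMD_BUFFER_ADDRESSING
// emit buffer_load / buffer_store based global memory access
#else
// fall back to plain pointer-based global memory access
// (presumably the path intended for the gfx1030 target in this commit)
#endif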
@@ -168,6 +168,63 @@ struct vector_type<T, 8>
     __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
 };
 
+template <typename T>
+struct vector_type<T, 16>
+{
+    using d1_t = T;
+    typedef T d2_t __attribute__((ext_vector_type(2)));
+    typedef T d4_t __attribute__((ext_vector_type(4)));
+    typedef T d8_t __attribute__((ext_vector_type(8)));
+    typedef T d16_t __attribute__((ext_vector_type(16)));
+
+    using type = d16_t;
+
+    union
+    {
+        d16_t d16_;
+        StaticallyIndexedArray<d1_t, 16> d1x16_;
+        StaticallyIndexedArray<d2_t, 8> d2x8_;
+        StaticallyIndexedArray<d4_t, 4> d4x4_;
+        StaticallyIndexedArray<d8_t, 2> d8x2_;
+        StaticallyIndexedArray<d16_t, 1> d16x1_;
+    } data_;
+
+    __host__ __device__ constexpr vector_type() : data_{type{0}} {}
+
+    __host__ __device__ constexpr vector_type(type v) : data_{v} {}
+
+    __host__ __device__ static constexpr index_t Size() { return 16; }
+
+    __host__ __device__ constexpr const auto& Vector() const { return data_.d16_; }
+
+    __host__ __device__ constexpr auto& Vector() { return data_.d16_; }
+
+    __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x16_; }
+
+    __host__ __device__ constexpr auto& Scalars() { return data_.d1x16_; }
+
+    __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x16_; }
+    __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x8_; }
+    __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x4_; }
+    __host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x2_; }
+    __host__ __device__ constexpr const auto& Vectors(Number<16>) const { return data_.d16x1_; }
+
+    __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x16_; }
+    __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x8_; }
+    __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x4_; }
+    __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x2_; }
+    __host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; }
+};
+
 // fp32
 using float2_t = typename vector_type<float, 2>::type;
 using float4_t = typename vector_type<float, 4>::type;
@@ -178,6 +235,7 @@ using half_t = _Float16;
 using half2_t = typename vector_type<half_t, 2>::type;
 using half4_t = typename vector_type<half_t, 4>::type;
 using half8_t = typename vector_type<half_t, 8>::type;
+using half16_t = typename vector_type<half_t, 16>::type;
 
 // bfp16
 using ushort2_t = typename vector_type<ushort, 2>::type;
......
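Side note on the new specialization above: like the narrower vector_type specializations, the union lets the same 16 elements be reinterpreted as 1-, 2-, 4-, 8-, or 16-wide sub-vectors, which is what the half16_t outer-product overload relies on. A minimal sketch (the helper name is hypothetical; the cast mirrors the pattern used in that overload):

// Illustrative: view a 16-wide fp16 vector as its two 8-wide halves.
__device__ void split_half16(const half16_t& v, half8_t& lo, half8_t& hi)
{
    const half8_t* p = reinterpret_cast<const half8_t*>(&v);
    lo = p[0]; // lanes 0..7
    hi = p[1]; // lanes 8..15
}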
@@ -116,18 +116,18 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
     // cdata = 64, BlockSize = 64, 16x8x32x4
     constexpr index_t BlockSize = 64;
 
-    constexpr index_t KPerBlock  = 16;
+    constexpr index_t KPerBlock  = K;
     constexpr index_t HoPerBlock = 8;
     constexpr index_t WoPerBlock = 32;
-    constexpr index_t EPerBlock  = 2;
+    constexpr index_t EPerBlock  = C0;
 
     constexpr index_t KPerThread  = KPerBlock;
     constexpr index_t HoPerThread = 2;
     constexpr index_t WoPerThread = 2;
     constexpr index_t EPerThread  = EPerBlock;
 
-    using ABlockTransferThreadSliceLengths_E_K   = Sequence<9, 1>;
-    using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, KPerBlock>;
+    using ABlockTransferThreadSliceLengths_E_K   = Sequence<3, 1>;
+    using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;
 
     constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
     constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
@@ -164,7 +164,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
 #endif
 
     constexpr auto conv_driver =
-        // DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
+        //DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
         DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad<
             BlockSize,
             typename vector_type<TInWei, InWeiVectorSize>::type,
......
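On the block-transfer change in the first hunk of this file: relative to EPerBlock, the per-block A copy footprint along E is unchanged; the new settings only spread it over three times as many threads along E, each moving a third as much. Illustrative arithmetic (the EPerBlock value below is a hypothetical stand-in; the commit sets it to C0):

// Sanity-check sketch, not from the commit.
constexpr int EPerBlock = 4; // hypothetical stand-in for C0

// before: Sequence<9, 1> per-thread slice x Sequence<EPerBlock, KPerBlock> thread cluster
constexpr int e_footprint_before = 9 * EPerBlock;
// after:  Sequence<3, 1> per-thread slice x Sequence<3 * EPerBlock, KPerBlock> thread cluster
constexpr int e_footprint_after = 3 * (3 * EPerBlock);

static_assert(e_footprint_before == e_footprint_after,
              "same A-tile footprint along E per block, different thread partitioning");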
@@ -50,7 +50,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
 
-#elif 0
+#elif 1
     constexpr index_t N  = 1;
     constexpr index_t C  = 16;
     constexpr index_t HI = 270;
@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<1, 1>;
     using RightPads = Sequence<1, 1>;
 
-#elif 1
+#elif 0
     constexpr index_t N  = 1;
     constexpr index_t C  = 16;
     constexpr index_t HI = 1080;
@@ -665,7 +665,7 @@ int main(int argc, char* argv[])
     using out_data_t = float;
 #elif 1
     using in_data_t                  = half_t;
-    constexpr index_t in_vector_size = 8;
+    constexpr index_t in_vector_size = 16;
     using acc_data_t = float;
     using out_data_t = half_t;
 #elif 0
......
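Combined with the driver instantiation above (typename vector_type<TInWei, InWeiVectorSize>::type), raising in_vector_size from 8 to 16 with in_data_t = half_t is what routes the fp16 path onto the new half16_t type. A small sketch of that mapping (TInWei and InWeiVectorSize are the parameter names from the diff, presumably fed by in_data_t and in_vector_size; the static_assert is illustrative and assumes the composable_kernel data-type header plus <type_traits>):

#include <type_traits>

using TInWei = half_t;                   // in_data_t selected in the #elif 1 branch
constexpr index_t InWeiVectorSize = 16;  // in_vector_size after this commit

// vector_type<half_t, 16>::type is exactly the half16_t alias added in this commit.
using InWeiVector = typename vector_type<TInWei, InWeiVectorSize>::type;
static_assert(std::is_same<InWeiVector, half16_t>::value,
              "16 fp16 values move per vectorized load/compute step");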