Commit b1f7f365 authored by Jing Zhang's avatar Jing Zhang
Browse files

add int8x8_t

parent 8753e615
...@@ -142,7 +142,8 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -142,7 +142,8 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
index_t src_wave_addr_offset) index_t src_wave_addr_offset)
{ {
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)), (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32x2_t>::value && (N == 1)),
"wrong! not implemented"); "wrong! not implemented");
if constexpr(is_same<T, float>::value) if constexpr(is_same<T, float>::value)
...@@ -205,6 +206,14 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -205,6 +206,14 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
return tmp.Vector(); return tmp.Vector();
} }
} }
else if constexpr(is_same<T, int32x2_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_i32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
} }
template <typename T, index_t N> template <typename T, index_t N>
......
...@@ -215,5 +215,36 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a, ...@@ -215,5 +215,36 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
#endif #endif
} }
__device__ void amd_assembly_outer_product_1x4(int8x8_t a,
int8x8_t b0,
int8x8_t b1,
int8x8_t b2,
int8x8_t b3,
int32_t& c0,
int32_t& c1,
int32_t& c2,
int32_t& c3)
{
amd_assembly_outer_product_1x4(a.Vectors(Number<4>{})[Number<0>{}],
b0.Vectors(Number<4>{})[Number<0>{}],
b1.Vectors(Number<4>{})[Number<0>{}],
b2.Vectors(Number<4>{})[Number<0>{}],
b3.Vectors(Number<4>{})[Number<0>{}],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(a.Vectors(Number<4>{})[Number<1>{}],
b0.Vectors(Number<4>{})[Number<1>{}],
b1.Vectors(Number<4>{})[Number<1>{}],
b2.Vectors(Number<4>{})[Number<1>{}],
b3.Vectors(Number<4>{})[Number<1>{}],
c0,
c1,
c2,
c3);
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -168,6 +168,27 @@ struct vector_type<T, 8> ...@@ -168,6 +168,27 @@ struct vector_type<T, 8>
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; } __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
}; };
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
// fp16
using half_t = _Float16;
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
// bfp16
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
template <> template <>
struct vector_type<int8_t, 2> struct vector_type<int8_t, 2>
{ {
...@@ -250,31 +271,61 @@ struct vector_type<int8_t, 4> ...@@ -250,31 +271,61 @@ struct vector_type<int8_t, 4>
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; } __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; }
}; };
// fp32 template <>
using float2_t = typename vector_type<float, 2>::type; struct vector_type<int8_t, 8>
using float4_t = typename vector_type<float, 4>::type; {
using float8_t = typename vector_type<float, 8>::type; using d1_t = int8_t;
typedef int16_t d2_t;
typedef int32_t d4_t;
typedef int32x2_t d8_t;
// fp16 using type = d8_t;
using half_t = _Float16;
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
// bfp16 union
using ushort2_t = typename vector_type<ushort, 2>::type; {
using ushort4_t = typename vector_type<ushort, 4>::type; d8_t d8_;
using ushort8_t = typename vector_type<ushort, 8>::type; StaticallyIndexedArray<d1_t, 8> d1x8_;
StaticallyIndexedArray<d2_t, 4> d2x4_;
StaticallyIndexedArray<d4_t, 2> d4x2_;
StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_;
// i32 __host__ __device__ constexpr vector_type() : data_{type{0}} {}
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type; __host__ __device__ constexpr vector_type(type v) : data_{v} {}
using int32x8_t = typename vector_type<int32_t, 8>::type;
__host__ __device__ static constexpr index_t Size() { return 8; }
__host__ __device__ constexpr const auto& Vector() const { return data_.d8_; }
__host__ __device__ constexpr auto& Vector() { return data_.d8_; }
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x8_; }
__host__ __device__ constexpr auto& Scalars() { return data_.d1x8_; }
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x8_; }
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x4_; }
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x2_; }
__host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x1_; }
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x8_; }
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x4_; }
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x2_; }
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
};
// i8 // i8
// hack for int8x4_t, because compiler does not have native support for int8x4_t // hack for int8x4_t, because compiler does not have native support for int8x4_t
// int8x4_t is defined as int32_t // int8x4_t is defined as int32_t
using int8x4_t = typename vector_type<int8_t, 4>::type; using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = vector_type<int8_t, 8>;
// data type conversion // data type conversion
template <typename T> template <typename T>
...@@ -339,6 +390,20 @@ struct inner_product_with_conversion ...@@ -339,6 +390,20 @@ struct inner_product_with_conversion
return acc; return acc;
} }
__device__ T operator()(int8x8_t a, int8x8_t b) const
{
const vector_type<int8_t, 8> a_vector{a};
const vector_type<int8_t, 8> b_vector{b};
T acc = 0;
static_for<0, 8, 1>{}([&](auto i) {
acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
});
return acc;
}
}; };
} // namespace ck } // namespace ck
......
...@@ -113,7 +113,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -113,7 +113,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr index_t KPerBlock = 16; constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 8; constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 32; constexpr index_t WoPerBlock = 32;
constexpr index_t EPerBlock = 4; constexpr index_t EPerBlock = 2;
constexpr index_t KPerThread = 16; constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 2; constexpr index_t HoPerThread = 2;
......
...@@ -642,7 +642,7 @@ int main(int argc, char* argv[]) ...@@ -642,7 +642,7 @@ int main(int argc, char* argv[])
using out_data_t = int8_t; using out_data_t = int8_t;
#elif 1 #elif 1
using in_data_t = int8_t; using in_data_t = int8_t;
constexpr index_t in_vector_size = 4; constexpr index_t in_vector_size = 8;
using acc_data_t = int32_t; using acc_data_t = int32_t;
using out_data_t = int8_t; using out_data_t = int8_t;
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment