Commit 55c280e4 authored by Jing Zhang
Browse files

add int8x16_t

parent bb580479
...@@ -143,7 +143,8 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -143,7 +143,8 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
{ {
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32x2_t>::value && (N == 1)), (is_same<T, int32x2_t>::value && (N == 1)) ||
(is_same<T, int32x4_t>::value && (N == 1)),
"wrong! not implemented"); "wrong! not implemented");
if constexpr(is_same<T, float>::value) if constexpr(is_same<T, float>::value)
...@@ -214,6 +215,14 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -214,6 +215,14 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
} }
} }
else if constexpr(is_same<T, int32x4_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
} }
template <typename T, index_t N> template <typename T, index_t N>
......
...@@ -246,5 +246,36 @@ __device__ void amd_assembly_outer_product_1x4(int8x8_t a, ...@@ -246,5 +246,36 @@ __device__ void amd_assembly_outer_product_1x4(int8x8_t a,
c3); c3);
} }
// 1x4 outer product for 16-element int8 operands.
//
// Each operand is viewed as two 8-byte halves through Vectors(Number<8>{}).
// Half 0 of a, b0, b1, b2 and b3 is forwarded to the narrower
// amd_assembly_outer_product_1x4 overload first, then half 1. Both calls
// receive the same c0..c3 references (presumably accumulators; confirm
// against the narrower overload).
__device__ void amd_assembly_outer_product_1x4(int8x16_t a,
                                               int8x16_t b0,
                                               int8x16_t b1,
                                               int8x16_t b2,
                                               int8x16_t b3,
                                               int32_t& c0,
                                               int32_t& c1,
                                               int32_t& c2,
                                               int32_t& c3)
{
    // Forwards one 8-byte half (selected by a compile-time index) of every operand.
    auto process_half = [&](auto half_idx) {
        amd_assembly_outer_product_1x4(a.Vectors(Number<8>{})[half_idx],
                                       b0.Vectors(Number<8>{})[half_idx],
                                       b1.Vectors(Number<8>{})[half_idx],
                                       b2.Vectors(Number<8>{})[half_idx],
                                       b3.Vectors(Number<8>{})[half_idx],
                                       c0,
                                       c1,
                                       c2,
                                       c3);
    };

    // Order matters only in that it mirrors the original: lower half, then upper half.
    process_half(Number<0>{});
    process_half(Number<1>{});
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -321,11 +321,68 @@ struct vector_type<int8_t, 8> ...@@ -321,11 +321,68 @@ struct vector_type<int8_t, 8>
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; } __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
}; };
// Specialisation of vector_type for 16 packed int8_t values.
//
// The storage type is int32x4_t (d16_t). The anonymous union exposes the same
// 16 bytes as arrays of progressively wider chunks:
//   d1x16_ : 16 x int8_t
//   d2x8_  :  8 x int16_t
//   d4x4_  :  4 x int32_t
//   d8x2_  :  2 x int32x2_t
//   d16x1_ :  1 x int32x4_t
template <>
struct vector_type<int8_t, 16>
{
    using d1_t  = int8_t;
    using d2_t  = int16_t;
    using d4_t  = int32_t;
    using d8_t  = int32x2_t;
    using d16_t = int32x4_t;

    using type = d16_t;

    union
    {
        d16_t d16_;
        StaticallyIndexedArray<d1_t, 16> d1x16_;
        StaticallyIndexedArray<d2_t, 8> d2x8_;
        StaticallyIndexedArray<d4_t, 4> d4x4_;
        StaticallyIndexedArray<d8_t, 2> d8x2_;
        // Fixed: the single 16-byte view must hold a d16_t. It was declared with
        // d8_t, so Vectors(Number<16>) only exposed the lower 8 bytes.
        StaticallyIndexedArray<d16_t, 1> d16x1_;
    } data_;

    // Default construction initialises the d16_ view (first union member) from type{0}.
    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    // Wraps an existing packed value.
    __host__ __device__ constexpr vector_type(type v) : data_{v} {}

    // Number of int8_t scalars held.
    __host__ __device__ static constexpr index_t Size() { return 16; }

    // Whole-vector access (int32x4_t).
    __host__ __device__ constexpr const auto& Vector() const { return data_.d16_; }

    __host__ __device__ constexpr auto& Vector() { return data_.d16_; }

    // Per-element access (16 x int8_t).
    __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x16_; }

    __host__ __device__ constexpr auto& Scalars() { return data_.d1x16_; }

    // Vectors(Number<K>) returns the view made of 16/K chunks of K bytes each.
    __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x16_; }

    __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x8_; }

    __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x4_; }

    __host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x2_; }

    __host__ __device__ constexpr const auto& Vectors(Number<16>) const { return data_.d16x1_; }

    __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x16_; }

    __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x8_; }

    __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x4_; }

    __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x2_; }

    __host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; }
};
// i8 // i8
// hack for int8x4_t, because compiler does not have native support for int8x4_t // hack for int8x4_t, because compiler does not have native support for int8x4_t
// int8x4_t is defined as int32_t // int8x4_t is defined as int32_t
using int8x4_t = typename vector_type<int8_t, 4>::type; using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = vector_type<int8_t, 8>; using int8x8_t = vector_type<int8_t, 8>;
using int8x16_t = vector_type<int8_t, 16>;
// data type conversion // data type conversion
template <typename T> template <typename T>
...@@ -404,6 +461,20 @@ struct inner_product_with_conversion ...@@ -404,6 +461,20 @@ struct inner_product_with_conversion
return acc; return acc;
} }
// Inner product of two 16-element int8 vectors.
// Every scalar pair is passed through convert(), multiplied, and added to an
// accumulator of type T that starts at 0; the accumulated value is returned.
__device__ T operator()(int8x16_t a, int8x16_t b) const
{
    // Read-only views of the operands, so the const Scalars() accessor is used.
    const int8x16_t& lhs = a;
    const int8x16_t& rhs = b;

    T sum = 0;

    // Compile-time unrolled over the 16 scalar lanes.
    static_for<0, 16, 1>{}([&](auto lane) {
        sum += convert(lhs.Scalars()[lane]) * convert(rhs.Scalars()[lane]);
    });

    return sum;
}
}; };
} // namespace ck } // namespace ck
......
...@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr index_t KPerBlock = 16; constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 8; constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 32; constexpr index_t WoPerBlock = 32;
constexpr index_t EPerBlock = 2; constexpr index_t EPerBlock = 1;
constexpr index_t KPerThread = 16; constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 2; constexpr index_t HoPerThread = 2;
......
...@@ -642,7 +642,7 @@ int main(int argc, char* argv[]) ...@@ -642,7 +642,7 @@ int main(int argc, char* argv[])
using out_data_t = int8_t; using out_data_t = int8_t;
#elif 1 #elif 1
using in_data_t = int8_t; using in_data_t = int8_t;
constexpr index_t in_vector_size = 8; constexpr index_t in_vector_size = 16;
using acc_data_t = int32_t; using acc_data_t = int32_t;
using out_data_t = int8_t; using out_data_t = int8_t;
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment