Commit 198593d5 authored by Jing Zhang's avatar Jing Zhang
Browse files

inline_asm refactor

parent 34321734
......@@ -272,21 +272,28 @@ __device__ void amd_assembly_outer_product_1x4(int8x8_t a,
int32_t& c2,
int32_t& c3)
{
amd_assembly_outer_product_1x4(a.Vectors(Number<4>{})[Number<0>{}],
b0.Vectors(Number<4>{})[Number<0>{}],
b1.Vectors(Number<4>{})[Number<0>{}],
b2.Vectors(Number<4>{})[Number<0>{}],
b3.Vectors(Number<4>{})[Number<0>{}],
const int8x4_t* p_a_int8x4_t = reinterpret_cast<const int8x4_t*>(&a);
const int8x4_t* p_b0_int8x4_t = reinterpret_cast<const int8x4_t*>(&b0);
const int8x4_t* p_b1_int8x4_t = reinterpret_cast<const int8x4_t*>(&b1);
const int8x4_t* p_b2_int8x4_t = reinterpret_cast<const int8x4_t*>(&b2);
const int8x4_t* p_b3_int8x4_t = reinterpret_cast<const int8x4_t*>(&b3);
amd_assembly_outer_product_1x4(p_a_int8x4_t[0],
p_b0_int8x4_t[0],
p_b1_int8x4_t[0],
p_b2_int8x4_t[0],
p_b3_int8x4_t[0],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(a.Vectors(Number<4>{})[Number<1>{}],
b0.Vectors(Number<4>{})[Number<1>{}],
b1.Vectors(Number<4>{})[Number<1>{}],
b2.Vectors(Number<4>{})[Number<1>{}],
b3.Vectors(Number<4>{})[Number<1>{}],
amd_assembly_outer_product_1x4(p_a_int8x4_t[1],
p_b0_int8x4_t[1],
p_b1_int8x4_t[1],
p_b2_int8x4_t[1],
p_b3_int8x4_t[1],
c0,
c1,
c2,
......@@ -302,22 +309,30 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
int32_t& c1,
int32_t& c2,
int32_t& c3)
{
amd_assembly_outer_product_1x4(a.Vectors(Number<8>{})[Number<0>{}],
b0.Vectors(Number<8>{})[Number<0>{}],
b1.Vectors(Number<8>{})[Number<0>{}],
b2.Vectors(Number<8>{})[Number<0>{}],
b3.Vectors(Number<8>{})[Number<0>{}],
const int8x8_t* p_a_int8x8_t = reinterpret_cast<const int8x8_t*>(&a);
const int8x8_t* p_b0_int8x8_t = reinterpret_cast<const int8x8_t*>(&b0);
const int8x8_t* p_b1_int8x8_t = reinterpret_cast<const int8x8_t*>(&b1);
const int8x8_t* p_b2_int8x8_t = reinterpret_cast<const int8x8_t*>(&b2);
const int8x8_t* p_b3_int8x8_t = reinterpret_cast<const int8x8_t*>(&b3);
amd_assembly_outer_product_1x4(p_a_int8x8_t[0],
p_b0_int8x8_t[0],
p_b1_int8x8_t[0],
p_b2_int8x8_t[0],
p_b3_int8x8_t[0],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(a.Vectors(Number<8>{})[Number<1>{}],
b0.Vectors(Number<8>{})[Number<1>{}],
b1.Vectors(Number<8>{})[Number<1>{}],
b2.Vectors(Number<8>{})[Number<1>{}],
b3.Vectors(Number<8>{})[Number<1>{}],
amd_assembly_outer_product_1x4(p_a_int8x8_t[1],
p_b0_int8x8_t[1],
p_b1_int8x8_t[1],
p_b2_int8x8_t[1],
p_b3_int8x8_t[1],
c0,
c1,
c2,
......
......@@ -224,17 +224,16 @@ struct vector_type<T, 16>
__host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; }
};
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
// fp16
using half_t = _Float16;
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
using half_t = _Float16;
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
// bfp16
......@@ -439,8 +438,8 @@ struct vector_type<int8_t, 16>
// hack for int8x4_t, because compiler does not have native support for int8x4_t
// int8x4_t is defined as int32_t
using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = vector_type<int8_t, 8>;
using int8x16_t = vector_type<int8_t, 16>;
using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
// data type conversion
template <typename T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment