Commit 198593d5 authored by Jing Zhang's avatar Jing Zhang
Browse files

inline_asm refactor

parent 34321734
...@@ -272,21 +272,28 @@ __device__ void amd_assembly_outer_product_1x4(int8x8_t a, ...@@ -272,21 +272,28 @@ __device__ void amd_assembly_outer_product_1x4(int8x8_t a,
int32_t& c2, int32_t& c2,
int32_t& c3) int32_t& c3)
{ {
amd_assembly_outer_product_1x4(a.Vectors(Number<4>{})[Number<0>{}],
b0.Vectors(Number<4>{})[Number<0>{}], const int8x4_t* p_a_int8x4_t = reinterpret_cast<const int8x4_t*>(&a);
b1.Vectors(Number<4>{})[Number<0>{}], const int8x4_t* p_b0_int8x4_t = reinterpret_cast<const int8x4_t*>(&b0);
b2.Vectors(Number<4>{})[Number<0>{}], const int8x4_t* p_b1_int8x4_t = reinterpret_cast<const int8x4_t*>(&b1);
b3.Vectors(Number<4>{})[Number<0>{}], const int8x4_t* p_b2_int8x4_t = reinterpret_cast<const int8x4_t*>(&b2);
const int8x4_t* p_b3_int8x4_t = reinterpret_cast<const int8x4_t*>(&b3);
amd_assembly_outer_product_1x4(p_a_int8x4_t[0],
p_b0_int8x4_t[0],
p_b1_int8x4_t[0],
p_b2_int8x4_t[0],
p_b3_int8x4_t[0],
c0, c0,
c1, c1,
c2, c2,
c3); c3);
amd_assembly_outer_product_1x4(a.Vectors(Number<4>{})[Number<1>{}], amd_assembly_outer_product_1x4(p_a_int8x4_t[1],
b0.Vectors(Number<4>{})[Number<1>{}], p_b0_int8x4_t[1],
b1.Vectors(Number<4>{})[Number<1>{}], p_b1_int8x4_t[1],
b2.Vectors(Number<4>{})[Number<1>{}], p_b2_int8x4_t[1],
b3.Vectors(Number<4>{})[Number<1>{}], p_b3_int8x4_t[1],
c0, c0,
c1, c1,
c2, c2,
...@@ -302,22 +309,30 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a, ...@@ -302,22 +309,30 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
int32_t& c1, int32_t& c1,
int32_t& c2, int32_t& c2,
int32_t& c3) int32_t& c3)
{ {
amd_assembly_outer_product_1x4(a.Vectors(Number<8>{})[Number<0>{}],
b0.Vectors(Number<8>{})[Number<0>{}], const int8x8_t* p_a_int8x8_t = reinterpret_cast<const int8x8_t*>(&a);
b1.Vectors(Number<8>{})[Number<0>{}], const int8x8_t* p_b0_int8x8_t = reinterpret_cast<const int8x8_t*>(&b0);
b2.Vectors(Number<8>{})[Number<0>{}], const int8x8_t* p_b1_int8x8_t = reinterpret_cast<const int8x8_t*>(&b1);
b3.Vectors(Number<8>{})[Number<0>{}], const int8x8_t* p_b2_int8x8_t = reinterpret_cast<const int8x8_t*>(&b2);
const int8x8_t* p_b3_int8x8_t = reinterpret_cast<const int8x8_t*>(&b3);
amd_assembly_outer_product_1x4(p_a_int8x8_t[0],
p_b0_int8x8_t[0],
p_b1_int8x8_t[0],
p_b2_int8x8_t[0],
p_b3_int8x8_t[0],
c0, c0,
c1, c1,
c2, c2,
c3); c3);
amd_assembly_outer_product_1x4(a.Vectors(Number<8>{})[Number<1>{}], amd_assembly_outer_product_1x4(p_a_int8x8_t[1],
b0.Vectors(Number<8>{})[Number<1>{}], p_b0_int8x8_t[1],
b1.Vectors(Number<8>{})[Number<1>{}], p_b1_int8x8_t[1],
b2.Vectors(Number<8>{})[Number<1>{}], p_b2_int8x8_t[1],
b3.Vectors(Number<8>{})[Number<1>{}], p_b3_int8x8_t[1],
c0, c0,
c1, c1,
c2, c2,
......
...@@ -224,17 +224,16 @@ struct vector_type<T, 16> ...@@ -224,17 +224,16 @@ struct vector_type<T, 16>
__host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; } __host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; }
}; };
// fp32 // fp32
using float2_t = typename vector_type<float, 2>::type; using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type; using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type; using float8_t = typename vector_type<float, 8>::type;
// fp16 // fp16
using half_t = _Float16; using half_t = _Float16;
using half2_t = typename vector_type<half_t, 2>::type; using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type; using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type; using half8_t = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type; using half16_t = typename vector_type<half_t, 16>::type;
// bfp16 // bfp16
...@@ -439,8 +438,8 @@ struct vector_type<int8_t, 16> ...@@ -439,8 +438,8 @@ struct vector_type<int8_t, 16>
// hack for int8x4_t, because compiler does not have native support for int8x4_t // hack for int8x4_t, because compiler does not have native support for int8x4_t
// int8x4_t is defined as int32_t // int8x4_t is defined as int32_t
using int8x4_t = typename vector_type<int8_t, 4>::type; using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = vector_type<int8_t, 8>; using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = vector_type<int8_t, 16>; using int8x16_t = typename vector_type<int8_t, 16>::type;
// data type conversion // data type conversion
template <typename T> template <typename T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment