Commit 7267c0b3 authored by rocking's avatar rocking
Browse files

Comment v_dot4_i32_i8

parent 4a93c836
...@@ -205,155 +205,155 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a, ...@@ -205,155 +205,155 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a,
// c0 += inner_product(a, b0) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1) // c1 += inner_product(a, b1)
__device__ void // __device__ void
amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1) // amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
{ // {
#if 1 // #if 1
asm volatile("\n \ // asm volatile("\n \
v_dot4_i32_i8 %0, %2, %3, %0\n \ // v_dot4_i32_i8 %0, %2, %3, %0\n \
v_dot4_i32_i8 %1, %2, %4, %1\n \ // v_dot4_i32_i8 %1, %2, %4, %1\n \
" // "
: "=v"(c0), "=v"(c1) // : "=v"(c0), "=v"(c1)
: "v"(bit_cast<int32_t>(a)), // : "v"(bit_cast<int32_t>(a)),
"v"(bit_cast<int32_t>(b0)), // "v"(bit_cast<int32_t>(b0)),
"v"(bit_cast<int32_t>(b1)), // "v"(bit_cast<int32_t>(b1)),
"0"(c0), // "0"(c0),
"1"(c1)); // "1"(c1));
#else // #else
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false); // c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false); // c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
#endif // #endif
} // }
// c0 += inner_product(a, b0) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1) // c1 += inner_product(a, b1)
// c2 += inner_product(a, b2) // c2 += inner_product(a, b2)
// c3 += inner_product(a, b3) // c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(int8x4_t a, // __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
int8x4_t b0, // int8x4_t b0,
int8x4_t b1, // int8x4_t b1,
int8x4_t b2, // int8x4_t b2,
int8x4_t b3, // int8x4_t b3,
int32_t& c0, // int32_t& c0,
int32_t& c1, // int32_t& c1,
int32_t& c2, // int32_t& c2,
int32_t& c3) // int32_t& c3)
{ // {
#if 1 // #if 1
asm volatile("\n \ // asm volatile("\n \
v_dot4_i32_i8 %0, %4, %5, %0\n \ // v_dot4_i32_i8 %0, %4, %5, %0\n \
v_dot4_i32_i8 %1, %4, %6, %1\n \ // v_dot4_i32_i8 %1, %4, %6, %1\n \
v_dot4_i32_i8 %2, %4, %7, %2\n \ // v_dot4_i32_i8 %2, %4, %7, %2\n \
v_dot4_i32_i8 %3, %4, %8, %3\n \ // v_dot4_i32_i8 %3, %4, %8, %3\n \
" // "
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(bit_cast<int32_t>(a)), // : "v"(bit_cast<int32_t>(a)),
"v"(bit_cast<int32_t>(b0)), // "v"(bit_cast<int32_t>(b0)),
"v"(bit_cast<int32_t>(b1)), // "v"(bit_cast<int32_t>(b1)),
"v"(bit_cast<int32_t>(b2)), // "v"(bit_cast<int32_t>(b2)),
"v"(bit_cast<int32_t>(b3)), // "v"(bit_cast<int32_t>(b3)),
"0"(c0), // "0"(c0),
"1"(c1), // "1"(c1),
"2"(c2), // "2"(c2),
"3"(c3)); // "3"(c3));
#else // #else
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false); // c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false); // c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false); // c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false); // c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
#endif // #endif
} // }
__device__ void amd_assembly_outer_product_1x4(int8x8_t a, // __device__ void amd_assembly_outer_product_1x4(int8x8_t a,
int8x8_t b0, // int8x8_t b0,
int8x8_t b1, // int8x8_t b1,
int8x8_t b2, // int8x8_t b2,
int8x8_t b3, // int8x8_t b3,
int32_t& c0, // int32_t& c0,
int32_t& c1, // int32_t& c1,
int32_t& c2, // int32_t& c2,
int32_t& c3) // int32_t& c3)
{ // {
constexpr auto I0 = Number<0>{}; // constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; // constexpr auto I1 = Number<1>{};
amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0], // amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0], // vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0], // vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0], // vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0], // vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
c0, // c0,
c1, // c1,
c2, // c2,
c3); // c3);
amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1], // amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1], // vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1], // vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1], // vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1], // vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
c0, // c0,
c1, // c1,
c2, // c2,
c3); // c3);
} // }
__device__ void amd_assembly_outer_product_1x4(int8x16_t a, // __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
int8x16_t b0, // int8x16_t b0,
int8x16_t b1, // int8x16_t b1,
int8x16_t b2, // int8x16_t b2,
int8x16_t b3, // int8x16_t b3,
int32_t& c0, // int32_t& c0,
int32_t& c1, // int32_t& c1,
int32_t& c2, // int32_t& c2,
int32_t& c3) // int32_t& c3)
{ // {
constexpr auto I0 = Number<0>{}; // constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; // constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; // constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; // constexpr auto I3 = Number<3>{};
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0], // amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0], // vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0], // vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0], // vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0], // vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
c0, // c0,
c1, // c1,
c2, // c2,
c3); // c3);
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1], // amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1], // vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1], // vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1], // vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1], // vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
c0, // c0,
c1, // c1,
c2, // c2,
c3); // c3);
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2], // amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2], // vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2], // vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2], // vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2], // vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
c0, // c0,
c1, // c1,
c2, // c2,
c3); // c3);
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3], // amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3], // vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3], // vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3], // vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3], // vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
c0, // c0,
c1, // c1,
c2, // c2,
c3); // c3);
} // }
// Ranged input operand // Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c) __device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
......
...@@ -161,17 +161,17 @@ template <> ...@@ -161,17 +161,17 @@ template <>
__device__ void __device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c) inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{ {
#if defined(CK_USE_AMD_V_DOT4_I32_I8) // #if defined(CK_USE_AMD_V_DOT4_I32_I8)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM // #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
asm volatile("\n \ // asm volatile("\n \
v_dot4_i32_i8 %0, %1, %2, %0\n \ // v_dot4_i32_i8 %0, %1, %2, %0\n \
" // "
: "=v"(c) // : "=v"(c)
: "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c)); // : "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
#else // #else
c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false); // c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
#endif // #endif
#else // #else
const vector_type<int8_t, 4> a_vector{a}; const vector_type<int8_t, 4> a_vector{a};
const vector_type<int8_t, 4> b_vector{b}; const vector_type<int8_t, 4> b_vector{b};
...@@ -179,9 +179,10 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, ...@@ -179,9 +179,10 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
c += type_convert<int32_t>(a_vector.AsType<int8_t>()[i]) * c += type_convert<int32_t>(a_vector.AsType<int8_t>()[i]) *
type_convert<int32_t>(b_vector.AsType<int8_t>()[i]); type_convert<int32_t>(b_vector.AsType<int8_t>()[i]);
}); });
#endif // #endif
} }
template <> template <>
__device__ void __device__ void
inner_product<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a, const int8x8_t& b, int32_t& c) inner_product<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a, const int8x8_t& b, int32_t& c)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment