Commit 7267c0b3 authored by rocking's avatar rocking
Browse files

Comment v_dot4_i32_i8

parent 4a93c836
......@@ -205,155 +205,155 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a,
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
{
#if 1
asm volatile("\n \
v_dot4_i32_i8 %0, %2, %3, %0\n \
v_dot4_i32_i8 %1, %2, %4, %1\n \
"
: "=v"(c0), "=v"(c1)
: "v"(bit_cast<int32_t>(a)),
"v"(bit_cast<int32_t>(b0)),
"v"(bit_cast<int32_t>(b1)),
"0"(c0),
"1"(c1));
#else
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
#endif
}
// __device__ void
// amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
// {
// #if 1
// asm volatile("\n \
// v_dot4_i32_i8 %0, %2, %3, %0\n \
// v_dot4_i32_i8 %1, %2, %4, %1\n \
// "
// : "=v"(c0), "=v"(c1)
// : "v"(bit_cast<int32_t>(a)),
// "v"(bit_cast<int32_t>(b0)),
// "v"(bit_cast<int32_t>(b1)),
// "0"(c0),
// "1"(c1));
// #else
// c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
// c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
// #endif
// }
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(int8x4_t a,
int8x4_t b0,
int8x4_t b1,
int8x4_t b2,
int8x4_t b3,
int32_t& c0,
int32_t& c1,
int32_t& c2,
int32_t& c3)
{
#if 1
asm volatile("\n \
v_dot4_i32_i8 %0, %4, %5, %0\n \
v_dot4_i32_i8 %1, %4, %6, %1\n \
v_dot4_i32_i8 %2, %4, %7, %2\n \
v_dot4_i32_i8 %3, %4, %8, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(bit_cast<int32_t>(a)),
"v"(bit_cast<int32_t>(b0)),
"v"(bit_cast<int32_t>(b1)),
"v"(bit_cast<int32_t>(b2)),
"v"(bit_cast<int32_t>(b3)),
"0"(c0),
"1"(c1),
"2"(c2),
"3"(c3));
#else
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
#endif
}
__device__ void amd_assembly_outer_product_1x4(int8x8_t a,
int8x8_t b0,
int8x8_t b1,
int8x8_t b2,
int8x8_t b3,
int32_t& c0,
int32_t& c1,
int32_t& c2,
int32_t& c3)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
c0,
c1,
c2,
c3);
}
__device__ void amd_assembly_outer_product_1x4(int8x16_t a,
int8x16_t b0,
int8x16_t b1,
int8x16_t b2,
int8x16_t b3,
int32_t& c0,
int32_t& c1,
int32_t& c2,
int32_t& c3)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
c0,
c1,
c2,
c3);
}
// __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
// int8x4_t b0,
// int8x4_t b1,
// int8x4_t b2,
// int8x4_t b3,
// int32_t& c0,
// int32_t& c1,
// int32_t& c2,
// int32_t& c3)
// {
// #if 1
// asm volatile("\n \
// v_dot4_i32_i8 %0, %4, %5, %0\n \
// v_dot4_i32_i8 %1, %4, %6, %1\n \
// v_dot4_i32_i8 %2, %4, %7, %2\n \
// v_dot4_i32_i8 %3, %4, %8, %3\n \
// "
// : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
// : "v"(bit_cast<int32_t>(a)),
// "v"(bit_cast<int32_t>(b0)),
// "v"(bit_cast<int32_t>(b1)),
// "v"(bit_cast<int32_t>(b2)),
// "v"(bit_cast<int32_t>(b3)),
// "0"(c0),
// "1"(c1),
// "2"(c2),
// "3"(c3));
// #else
// c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
// c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
// c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
// c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
// #endif
// }
// __device__ void amd_assembly_outer_product_1x4(int8x8_t a,
// int8x8_t b0,
// int8x8_t b1,
// int8x8_t b2,
// int8x8_t b3,
// int32_t& c0,
// int32_t& c1,
// int32_t& c2,
// int32_t& c3)
// {
// constexpr auto I0 = Number<0>{};
// constexpr auto I1 = Number<1>{};
// amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
// vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
// vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
// vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
// vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
// c0,
// c1,
// c2,
// c3);
// amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
// vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
// vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
// vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
// vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
// c0,
// c1,
// c2,
// c3);
// }
// __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
// int8x16_t b0,
// int8x16_t b1,
// int8x16_t b2,
// int8x16_t b3,
// int32_t& c0,
// int32_t& c1,
// int32_t& c2,
// int32_t& c3)
// {
// constexpr auto I0 = Number<0>{};
// constexpr auto I1 = Number<1>{};
// constexpr auto I2 = Number<2>{};
// constexpr auto I3 = Number<3>{};
// amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
// vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
// vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
// vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
// vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
// c0,
// c1,
// c2,
// c3);
// amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
// vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
// vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
// vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
// vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
// c0,
// c1,
// c2,
// c3);
// amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
// vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
// vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
// vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
// vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
// c0,
// c1,
// c2,
// c3);
// amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
// vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
// vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
// vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
// vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
// c0,
// c1,
// c2,
// c3);
// }
// Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
......
......@@ -161,17 +161,17 @@ template <>
__device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{
#if defined(CK_USE_AMD_V_DOT4_I32_I8)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
asm volatile("\n \
v_dot4_i32_i8 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
#else
c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
#endif
#else
// #if defined(CK_USE_AMD_V_DOT4_I32_I8)
// #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
// asm volatile("\n \
// v_dot4_i32_i8 %0, %1, %2, %0\n \
// "
// : "=v"(c)
// : "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
// #else
// c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
// #endif
// #else
const vector_type<int8_t, 4> a_vector{a};
const vector_type<int8_t, 4> b_vector{b};
......@@ -179,9 +179,10 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
c += type_convert<int32_t>(a_vector.AsType<int8_t>()[i]) *
type_convert<int32_t>(b_vector.AsType<int8_t>()[i]);
});
#endif
// #endif
}
template <>
__device__ void
inner_product<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a, const int8x8_t& b, int32_t& c)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment