gaoqiong / composable_kernel · Commits

Commit 7267c0b3, authored Mar 16, 2023 by rocking
Parent: 4a93c836

Comment v_dot4_i32_i8
Showing 2 changed files with 157 additions and 156 deletions (+157 -156)

include/ck/utility/amd_inline_asm.hpp   +144 -144
include/ck/utility/inner_product.hpp    +13 -12
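For context, both code paths touched by this commit -- the v_dot4_i32_i8 inline assembly and the __builtin_amdgcn_sdot4 fallback -- accumulate a four-lane int8 dot product into a 32-bit integer. Below is a minimal host-side sketch of that arithmetic, assuming the usual packing of four int8 lanes into one int32; the helper name is illustrative and not part of the repository.

    #include <cstdint>

    // Reference for one dot4 step: c += sum over lanes of a[lane] * b[lane]
    inline std::int32_t sdot4_reference(std::int32_t a_packed, std::int32_t b_packed, std::int32_t c)
    {
        for(int lane = 0; lane < 4; ++lane)
        {
            const auto a_lane = static_cast<std::int8_t>((a_packed >> (8 * lane)) & 0xff);
            const auto b_lane = static_cast<std::int8_t>((b_packed >> (8 * lane)) & 0xff);
            c += std::int32_t{a_lane} * std::int32_t{b_lane};
        }
        return c;
    }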
include/ck/utility/amd_inline_asm.hpp @ 7267c0b3
...
...
@@ -205,155 +205,155 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a,
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
{
#if 1
    asm volatile("\n \
            v_dot4_i32_i8 %0, %2, %3, %0\n \
            v_dot4_i32_i8 %1, %2, %4, %1\n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(bit_cast<int32_t>(a)),
                   "v"(bit_cast<int32_t>(b0)),
                   "v"(bit_cast<int32_t>(b1)),
                   "0"(c0),
                   "1"(c1));
#else
    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
#endif
}
// __device__ void
// amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
// {
// #if 1
//     asm volatile("\n \
//             v_dot4_i32_i8 %0, %2, %3, %0\n \
//             v_dot4_i32_i8 %1, %2, %4, %1\n \
//             "
//                  : "=v"(c0), "=v"(c1)
//                  : "v"(bit_cast<int32_t>(a)),
//                    "v"(bit_cast<int32_t>(b0)),
//                    "v"(bit_cast<int32_t>(b1)),
//                    "0"(c0),
//                    "1"(c1));
// #else
//     c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
//     c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
// #endif
// }
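Illustratively, the 1x2 helper above performs two accumulating dot products that share the same a operand. A scalar sketch, reusing the hypothetical sdot4_reference helper introduced earlier (names are not from the repository):

    // Illustrative only.
    inline void outer_product_1x2_reference(std::int32_t a, std::int32_t b0, std::int32_t b1,
                                            std::int32_t& c0, std::int32_t& c1)
    {
        c0 = sdot4_reference(a, b0, c0); // c0 += inner_product(a, b0)
        c1 = sdot4_reference(a, b1, c1); // c1 += inner_product(a, b1)
    }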
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(int8x4_t a,
                                               int8x4_t b0,
                                               int8x4_t b1,
                                               int8x4_t b2,
                                               int8x4_t b3,
                                               int32_t& c0,
                                               int32_t& c1,
                                               int32_t& c2,
                                               int32_t& c3)
{
#if 1
    asm volatile("\n \
            v_dot4_i32_i8 %0, %4, %5, %0\n \
            v_dot4_i32_i8 %1, %4, %6, %1\n \
            v_dot4_i32_i8 %2, %4, %7, %2\n \
            v_dot4_i32_i8 %3, %4, %8, %3\n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(bit_cast<int32_t>(a)),
                   "v"(bit_cast<int32_t>(b0)),
                   "v"(bit_cast<int32_t>(b1)),
                   "v"(bit_cast<int32_t>(b2)),
                   "v"(bit_cast<int32_t>(b3)),
                   "0"(c0),
                   "1"(c1),
                   "2"(c2),
                   "3"(c3));
#else
    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
    c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
    c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
#endif
}
__device__ void amd_assembly_outer_product_1x4(int8x8_t a,
                                               int8x8_t b0,
                                               int8x8_t b1,
                                               int8x8_t b2,
                                               int8x8_t b3,
                                               int32_t& c0,
                                               int32_t& c1,
                                               int32_t& c2,
                                               int32_t& c3)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
                                   c0,
                                   c1,
                                   c2,
                                   c3);
}
__device__ void amd_assembly_outer_product_1x4(int8x16_t a,
                                               int8x16_t b0,
                                               int8x16_t b1,
                                               int8x16_t b2,
                                               int8x16_t b3,
                                               int32_t& c0,
                                               int32_t& c1,
                                               int32_t& c2,
                                               int32_t& c3)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
                                   c0,
                                   c1,
                                   c2,
                                   c3);
}
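The int8x8_t and int8x16_t overloads above only split the wider operands into int8x4_t chunks via vector_type<int8_t, N>{x}.AsType<int8x4_t>()[I] and reuse the 4-wide path. A scalar sketch of the same chunking, with plain arrays standing in for ck::vector_type and only the first accumulator shown (illustrative, not from the repository):

    // 16 int8 lanes processed as four int8x4_t chunks (corresponding to Number<0>..Number<3>).
    inline void outer_product_1x4_int8x16_reference(const std::int8_t (&a)[16],
                                                    const std::int8_t (&b0)[16],
                                                    std::int32_t& c0)
    {
        for(int chunk = 0; chunk < 4; ++chunk)  // one int8x4_t chunk per iteration
        {
            for(int lane = 0; lane < 4; ++lane) // lanes within one int8x4_t
            {
                c0 += std::int32_t{a[4 * chunk + lane]} * std::int32_t{b0[4 * chunk + lane]};
            }
        }
    }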
// __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
//                                                int8x4_t b0,
//                                                int8x4_t b1,
//                                                int8x4_t b2,
//                                                int8x4_t b3,
//                                                int32_t& c0,
//                                                int32_t& c1,
//                                                int32_t& c2,
//                                                int32_t& c3)
// {
// #if 1
//     asm volatile("\n \
//             v_dot4_i32_i8 %0, %4, %5, %0\n \
//             v_dot4_i32_i8 %1, %4, %6, %1\n \
//             v_dot4_i32_i8 %2, %4, %7, %2\n \
//             v_dot4_i32_i8 %3, %4, %8, %3\n \
//             "
//                  : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
//                  : "v"(bit_cast<int32_t>(a)),
//                    "v"(bit_cast<int32_t>(b0)),
//                    "v"(bit_cast<int32_t>(b1)),
//                    "v"(bit_cast<int32_t>(b2)),
//                    "v"(bit_cast<int32_t>(b3)),
//                    "0"(c0),
//                    "1"(c1),
//                    "2"(c2),
//                    "3"(c3));
// #else
//     c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
//     c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
//     c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
//     c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
// #endif
// }

// __device__ void amd_assembly_outer_product_1x4(int8x8_t a,
//                                                int8x8_t b0,
//                                                int8x8_t b1,
//                                                int8x8_t b2,
//                                                int8x8_t b3,
//                                                int32_t& c0,
//                                                int32_t& c1,
//                                                int32_t& c2,
//                                                int32_t& c3)
// {
//     constexpr auto I0 = Number<0>{};
//     constexpr auto I1 = Number<1>{};
//     amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
//                                    vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
//                                    vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
//                                    vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
//                                    vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
//                                    c0,
//                                    c1,
//                                    c2,
//                                    c3);
//     amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
//                                    vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
//                                    vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
//                                    vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
//                                    vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
//                                    c0,
//                                    c1,
//                                    c2,
//                                    c3);
// }

// __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
//                                                int8x16_t b0,
//                                                int8x16_t b1,
//                                                int8x16_t b2,
//                                                int8x16_t b3,
//                                                int32_t& c0,
//                                                int32_t& c1,
//                                                int32_t& c2,
//                                                int32_t& c3)
// {
//     constexpr auto I0 = Number<0>{};
//     constexpr auto I1 = Number<1>{};
//     constexpr auto I2 = Number<2>{};
//     constexpr auto I3 = Number<3>{};
//     amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
//                                    vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
//                                    vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
//                                    vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
//                                    vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
//                                    c0,
//                                    c1,
//                                    c2,
//                                    c3);
//     amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
//                                    vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
//                                    vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
//                                    vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
//                                    vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
//                                    c0,
//                                    c1,
//                                    c2,
//                                    c3);
//     amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
//                                    vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
//                                    vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
//                                    vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
//                                    vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
//                                    c0,
//                                    c1,
//                                    c2,
//                                    c3);
//     amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
//                                    vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
//                                    vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
//                                    vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
//                                    vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
//                                    c0,
//                                    c1,
//                                    c2,
//                                    c3);
// }
// Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
...
...
include/ck/utility/inner_product.hpp @ 7267c0b3
...
...
@@ -161,17 +161,17 @@ template <>
__device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{
#if defined(CK_USE_AMD_V_DOT4_I32_I8)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
    asm volatile("\n \
            v_dot4_i32_i8 %0, %1, %2, %0\n \
            "
                 : "=v"(c)
                 : "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
#else
    c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
#endif
#else
// #if defined(CK_USE_AMD_V_DOT4_I32_I8)
// #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
//     asm volatile("\n \
//             v_dot4_i32_i8 %0, %1, %2, %0\n \
//             "
//                  : "=v"(c)
//                  : "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
// #else
//     c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
// #endif
// #else
    const vector_type<int8_t, 4> a_vector{a};
    const vector_type<int8_t, 4> b_vector{b};
...
...
@@ -179,9 +179,10 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
        c += type_convert<int32_t>(a_vector.AsType<int8_t>()[i]) *
             type_convert<int32_t>(b_vector.AsType<int8_t>()[i]);
    });
#endif
// #endif
}
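A hypothetical call site for the specialization above, assuming these helpers live in namespace ck as elsewhere in the repository; whichever path the macros select, the observable effect is acc += sum over the four int8 lanes of a[i] * b[i] (illustrative, not part of this commit):

    __device__ void accumulate_dot4(const int8x4_t& a, const int8x4_t& b, int32_t& acc)
    {
        ck::inner_product<int8x4_t, int8x4_t, int32_t>(a, b, acc);
    }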
template <>
__device__ void
inner_product<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a, const int8x8_t& b, int32_t& c)
...
...