Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
ea028ac6
"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3b66cc0fc1244150356d43d788eaa52d816ec989"
Unverified
Commit
ea028ac6
authored
Mar 15, 2023
by
Haocong WANG
Committed by
GitHub
Mar 15, 2023
Browse files
Fix arch limitation bug (#639)
parent
5b57ab96
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
10 deletions
+10
-10
include/ck/utility/amd_wmma.hpp
include/ck/utility/amd_wmma.hpp
+10
-10
No files found.
include/ck/utility/amd_wmma.hpp
View file @
ea028ac6
...
@@ -25,7 +25,7 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
...
@@ -25,7 +25,7 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
// delete them.
// delete them.
// amd_assembly_wmma_f32_16x16x16_f16_w32(
// amd_assembly_wmma_f32_16x16x16_f16_w32(
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
#if defined(__gfx11__)
#if defined(__gfx11
00__) || defined(__gfx1101__) || defined(__gfx1102
__)
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32
(
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()[
Number
<
0
>
{}]);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()[
Number
<
0
>
{}]);
#else
#else
...
@@ -46,7 +46,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
...
@@ -46,7 +46,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
bhalf16_t
&
reg_a
,
const
bhalf16_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
bhalf16_t
&
reg_a
,
const
bhalf16_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
#if defined(__gfx11__)
#if defined(__gfx11
00__) || defined(__gfx1101__) || defined(__gfx1102
__)
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32
(
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()[
Number
<
0
>
{}]);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()[
Number
<
0
>
{}]);
...
@@ -71,7 +71,7 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
...
@@ -71,7 +71,7 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
// opsel usage
// opsel usage
// false: D0.[0:15] = result
// false: D0.[0:15] = result
// true : D0.[16:31]= result
// true : D0.[16:31]= result
#if defined(__gfx11__)
#if defined(__gfx11
00__) || defined(__gfx1101__) || defined(__gfx1102
__)
reg_c
.
template
AsType
<
half16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f16_16x16x16_f16_w32
(
reg_c
.
template
AsType
<
half16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f16_16x16x16_f16_w32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
half16_t
>()[
Number
<
0
>
{}],
Opsel
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
half16_t
>()[
Number
<
0
>
{}],
Opsel
);
#else
#else
...
@@ -95,7 +95,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
...
@@ -95,7 +95,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
// opsel usage
// opsel usage
// false: D0.[0:15] = result
// false: D0.[0:15] = result
// true : D0.[16:31]= result
// true : D0.[16:31]= result
#if defined(__gfx11__)
#if defined(__gfx11
00__) || defined(__gfx1101__) || defined(__gfx1102
__)
reg_c
.
template
AsType
<
bhalf16_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
bhalf16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32
(
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
bhalf16_t
>()[
Number
<
0
>
{}],
Opsel
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
bhalf16_t
>()[
Number
<
0
>
{}],
Opsel
);
...
@@ -117,7 +117,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
...
@@ -117,7 +117,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
int8x16_t
&
reg_a
,
const
int8x16_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
int8x16_t
&
reg_a
,
const
int8x16_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
#if defined(__gfx11__)
#if defined(__gfx11
00__) || defined(__gfx1101__) || defined(__gfx1102
__)
reg_c
.
template
AsType
<
int32x8_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
int32x8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32
(
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32
(
neg_a
,
neg_a
,
...
@@ -145,7 +145,7 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
...
@@ -145,7 +145,7 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half16_t
&
reg_a
,
const
half16_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
half16_t
&
reg_a
,
const
half16_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
#if defined(__gfx11__)
#if defined(__gfx11
00__) || defined(__gfx1101__) || defined(__gfx1102
__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_f16_w64
(
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_f16_w64
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}]);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}]);
#else
#else
...
@@ -166,7 +166,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
...
@@ -166,7 +166,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
bhalf16_t
&
reg_a
,
const
bhalf16_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
bhalf16_t
&
reg_a
,
const
bhalf16_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
#if defined(__gfx11__)
#if defined(__gfx11
00__) || defined(__gfx1101__) || defined(__gfx1102
__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64
(
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}]);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}]);
...
@@ -191,7 +191,7 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
...
@@ -191,7 +191,7 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
// opsel usage
// opsel usage
// false: D0.[0:15] = result
// false: D0.[0:15] = result
// true : D0.[16:31]= result
// true : D0.[16:31]= result
#if defined(__gfx11__)
#if defined(__gfx11
00__) || defined(__gfx1101__) || defined(__gfx1102
__)
reg_c
.
template
AsType
<
half8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f16_16x16x16_f16_w64
(
reg_c
.
template
AsType
<
half8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f16_16x16x16_f16_w64
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
half8_t
>()[
Number
<
0
>
{}],
Opsel
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
half8_t
>()[
Number
<
0
>
{}],
Opsel
);
#else
#else
...
@@ -215,7 +215,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
...
@@ -215,7 +215,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
// opsel usage
// opsel usage
// false: D0.[0:15] = result
// false: D0.[0:15] = result
// true : D0.[16:31]= result
// true : D0.[16:31]= result
#if defined(__gfx11__)
#if defined(__gfx11
00__) || defined(__gfx1101__) || defined(__gfx1102
__)
reg_c
.
template
AsType
<
bhalf8_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
bhalf8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64
(
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
bhalf8_t
>()[
Number
<
0
>
{}],
Opsel
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
bhalf8_t
>()[
Number
<
0
>
{}],
Opsel
);
...
@@ -237,7 +237,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
...
@@ -237,7 +237,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
int8x16_t
&
reg_a
,
const
int8x16_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
int8x16_t
&
reg_a
,
const
int8x16_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
#if defined(__gfx11__)
#if defined(__gfx11
00__) || defined(__gfx1101__) || defined(__gfx1102
__)
reg_c
.
template
AsType
<
int32x4_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
int32x4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64
(
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64
(
neg_a
,
neg_a
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment