Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
2d4fb7d5
Commit
2d4fb7d5
authored
Feb 06, 2025
by
Andriy Roshchenko
Browse files
Latest reproducer for SCALE MFMA
parent
f3af1da6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
16 deletions
+36
-16
test/mx_mfma_op/scale_mfma_repro.cpp
test/mx_mfma_op/scale_mfma_repro.cpp
+36
-16
No files found.
test/mx_mfma_op/scale_mfma_repro.cpp
View file @
2d4fb7d5
#include <hip/hip_ext.h>
#include <hip/hip_ext.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
/// Reinterprets the bits of `x` as a value of type `Y`.
///
/// The overload is only viable when `X` and `Y` have the same size; the
/// conversion itself is delegated to the compiler's `__builtin_bit_cast`.
///
/// @tparam Y destination type, supplied explicitly by the caller
/// @tparam X source type, deduced from the argument
/// @param x value whose object representation is reinterpreted
/// @return a `Y` produced by `__builtin_bit_cast(Y, x)`
template <typename Y, typename X, std::enable_if_t<sizeof(X) == sizeof(Y), bool> = false>
__host__ __device__ constexpr Y bit_cast(const X& x)
{
    // The implementation depends entirely on this compiler builtin.
    static_assert(__has_builtin(__builtin_bit_cast), "");
    // Same size condition as the enable_if guard, repeated with a readable message.
    static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");

    return __builtin_bit_cast(Y, x);
}
__global__
void
kernel
()
__global__
void
kernel
()
{
{
using
dataAB
=
uint8_t
__attribute__
((
ext_vector_type
(
32
)));
using
dataAB
=
uint8_t
__attribute__
((
ext_vector_type
(
32
)));
using
dataC
=
float
__attribute__
((
ext_vector_type
(
16
)));
using
dataC
=
float
__attribute__
((
ext_vector_type
(
16
)));
using
dataX
=
int32_t
__attribute__
((
ext_vector_type
(
2
)));
using
dataX
=
int32_t
__attribute__
((
ext_vector_type
(
2
)));
dataAB
regA
(
0
x38
);
dataAB
regA
(
0
);
dataAB
regB
(
0
x38
);
dataAB
regB
(
0
);
dataC
regC
(
1
.0
f
);
dataC
regC
(
0
.0
f
);
// dataC regCin(1.0f);
// dataC regCin(1.0f);
#if 1
#if 0
// dataX xa{127, 127}; // 1.0
dataX xa{0x3F800000, 0x3F800000};
dataX
xa
(
127
&
0xFF
);
// 1.0
dataX xb(0x3F800000);
dataX
xb
(
127
&
0xFF
);
// 1.0
#elif
0
dataX
xa
{
0x3F000000
,
0x3F000000
};
dataX
xb
(
0x3F800000
);
#elif 0
dataX
xa
{
0x3F800000
,
0x3F000000
};
dataX
xb
(
0x3F800000
);
#elif 1
dataX
xa
{
0x7F
,
0x7E
};
// expect 64 at c(0,0)
dataX
xb
(
0x3F800000
);
// dataX xb(0x7F);
#else
#else
dataX
xa
(
0
);
dataX
xa
(
0
);
dataX
xb
(
0
);
dataX
xb
(
0
);
#endif
#endif
#if
0
#if
1
if
(
threadIdx
.
x
==
0
)
if
(
threadIdx
.
x
==
0
)
{
{
// xa = 127; // 1.0
for
(
size_t
i
=
0
;
i
<
32
;
i
++
)
for
(
size_t
i
=
0
;
i
<
32
;
i
++
)
{
{
regA
[
i
]
=
0x38
;
// 1.0
regA
[
i
]
=
0x38
;
// 1.0
...
@@ -31,27 +50,23 @@ __global__ void kernel()
...
@@ -31,27 +50,23 @@ __global__ void kernel()
{
{
regB
[
i
]
=
0x38
;
// 1.0
regB
[
i
]
=
0x38
;
// 1.0
}
}
printf("thread: %u -- xA: %x\n", threadIdx.x, xa[threadIdx.x / 32]);
printf("thread: %u -- xB: %x\n", threadIdx.x, xb[threadIdx.x / 32]);
}
}
if
(
threadIdx
.
x
==
32
)
if
(
threadIdx
.
x
==
32
)
{
{
// xa = 126; // 0.5
for
(
size_t
i
=
0
;
i
<
32
;
i
++
)
for
(
size_t
i
=
0
;
i
<
32
;
i
++
)
{
{
regA[i] = 0x
C
0; //
-
2.0
regA
[
i
]
=
0x
4
0
;
// 2.0
}
}
for
(
size_t
i
=
0
;
i
<
32
;
i
++
)
for
(
size_t
i
=
0
;
i
<
32
;
i
++
)
{
{
regB
[
i
]
=
0x38
;
// 1.0
regB
[
i
]
=
0x38
;
// 1.0
}
}
printf("thread: %u -- xA: %x\n", threadIdx.x, xa[threadIdx.x / 32]);
printf("thread: %u -- xB: %x\n", threadIdx.x, xb[threadIdx.x / 32]);
}
}
#endif
#endif
__syncthreads
();
__syncthreads
();
#if 1
printf
(
"thread: %u -- regA: %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x "
printf
(
"thread: %u -- regA: %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x "
"%x %x %x %x %x %x %x %x %x %x
\n
"
,
"%x %x %x %x %x %x %x %x %x %x
\n
"
,
threadIdx
.
x
,
threadIdx
.
x
,
...
@@ -123,7 +138,7 @@ __global__ void kernel()
...
@@ -123,7 +138,7 @@ __global__ void kernel()
regB
[
29
],
regB
[
29
],
regB
[
30
],
regB
[
30
],
regB
[
31
]);
regB
[
31
]);
#endif
//__builtin_amdgcn_mfma_ld_scale_b32(xb[threadIdx.x / 32], 0, 0);
//__builtin_amdgcn_mfma_ld_scale_b32(xb[threadIdx.x / 32], 0, 0);
regC
=
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4
(
regA
,
regC
=
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4
(
regA
,
regB
,
regB
,
...
@@ -136,6 +151,11 @@ __global__ void kernel()
...
@@ -136,6 +151,11 @@ __global__ void kernel()
xb
[
threadIdx
.
x
/
32
]);
xb
[
threadIdx
.
x
/
32
]);
__syncthreads
();
__syncthreads
();
if
(
threadIdx
.
x
==
0
||
threadIdx
.
x
==
32
)
{
printf
(
"thread: %u -- xA: %x
\n
"
,
threadIdx
.
x
,
bit_cast
<
int32_t
>
(
xa
[
threadIdx
.
x
/
32
]));
printf
(
"thread: %u -- xB: %x
\n
"
,
threadIdx
.
x
,
bit_cast
<
int32_t
>
(
xb
[
threadIdx
.
x
/
32
]));
}
printf
(
"thread: %u -- regC: %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f
\n
"
,
printf
(
"thread: %u -- regC: %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f
\n
"
,
threadIdx
.
x
,
threadIdx
.
x
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment