Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7dca8463
Commit
7dca8463
authored
Oct 21, 2022
by
aska-0096
Browse files
add arch limitation to wmma test
parent
36c38ad9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
3 deletions
+19
-3
test/wmma_op/wmma_op.cpp
test/wmma_op/wmma_op.cpp
+19
-3
No files found.
test/wmma_op/wmma_op.cpp
View file @
7dca8463
...
...
@@ -16,6 +16,7 @@
namespace
ck
{
__global__
void
matmul
(
const
half_t
*
a
,
const
half_t
*
b
,
float
*
c
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
const
int
lIdx
=
threadIdx
.
x
;
// a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and
...
...
@@ -52,10 +53,16 @@ __global__ void matmul(const half_t* a, const half_t* b, float* c)
// store results from unpacked c_thread_buf_ output
c
[
16
*
r
+
lane
]
=
c_thread_buf_
[
Number
<
ele
>
{}];
});
#else
ignore
=
a
;
ignore
=
b
;
ignore
=
c
;
#endif // end of if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
}
__global__
void
matmul_swizzle_a
(
const
half_t
*
a
,
const
half_t
*
b
,
float
*
c
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
const
int
lIdx
=
threadIdx
.
x
;
half16_t
a_frag
=
{};
...
...
@@ -85,6 +92,11 @@ __global__ void matmul_swizzle_a(const half_t* a, const half_t* b, float* c)
const
int
r
=
ele
;
c
[
16
*
8
*
blk
+
16
*
r
+
lane
]
=
c_thread_buf_
[
Number
<
ele
>
{}];
});
#else
ignore
=
a
;
ignore
=
b
;
ignore
=
c
;
#endif // end of if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
}
}
// namespace ck
...
...
@@ -152,16 +164,20 @@ int main(int, char*[])
device_c
.
FromDevice
(
wmma_c
.
data
());
bool
res
=
ck
::
utils
::
check_err
(
wmma_c
,
host_c
,
"Error: Incorrect results!"
,
1e-2
);
// run single wave wmma_swizzle_a on GPU
ck
::
matmul_swizzle_a
<<<
1
,
32
>>>
(
static_cast
<
const
ck
::
half_t
*>
(
device_a
.
GetDeviceBuffer
()),
static_cast
<
const
ck
::
half_t
*>
(
device_b
.
GetDeviceBuffer
()),
static_cast
<
float
*>
(
device_c
.
GetDeviceBuffer
()));
device_c
.
FromDevice
(
wmma_c_swizzle_a
.
data
());
bool
res_swizzle_a
=
// result check
bool
res
=
true
;
bool
res_swizzle_a
=
true
;
#if(defined(__gfx1100__))
res
=
ck
::
utils
::
check_err
(
wmma_c
,
host_c
,
"Error: Incorrect results!"
,
1e-2
);
res_swizzle_a
=
ck
::
utils
::
check_err
(
wmma_c_swizzle_a
,
host_c
,
"Error: Incorrect results!"
,
1e-2
);
#endif // end of if (defined(__gfx1100__))
if
(
res
&&
res_swizzle_a
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment