Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
13af8cc4
Commit
13af8cc4
authored
Dec 13, 2022
by
aska-0096
Browse files
add inline asm for wmmaop test
parent
e43df26a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
7 deletions
+10
-7
test/wmma_op/wmma_op_util.hpp
test/wmma_op/wmma_op_util.hpp
+10
-7
No files found.
test/wmma_op/wmma_op_util.hpp
View file @
13af8cc4
...
@@ -97,7 +97,7 @@ builtin_wmma_naive_selector<int4x16_t,
...
@@ -97,7 +97,7 @@ builtin_wmma_naive_selector<int4x16_t,
template
<
typename
src_t
,
typename
dst_t
,
typename
acc_t
,
index_t
acc_num
>
template
<
typename
src_t
,
typename
dst_t
,
typename
acc_t
,
index_t
acc_num
>
__global__
void
matmul
(
const
src_t
*
a
,
const
src_t
*
b
,
dst_t
*
c
)
__global__
void
matmul
(
const
src_t
*
a
,
const
src_t
*
b
,
dst_t
*
c
)
{
{
__shared__
src_t
p_shared
[
16
*
16
*
2
];
__shared__
src_t
p_shared
[
16
*
16
*
2
];
const
int
lIdx
=
threadIdx
.
x
;
const
int
lIdx
=
threadIdx
.
x
;
// a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and
// a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and
// b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the
// b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the
...
@@ -115,7 +115,7 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
...
@@ -115,7 +115,7 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
// lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11
// lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11
// see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482
// see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482
// TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101
// TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101
const
int
lane
=
lIdx
%
16
;
const
int
lane
=
lIdx
%
16
;
const
int
lane_lo
=
lIdx
/
2
;
const
int
lane_lo
=
lIdx
/
2
;
const
int
lane_hi
=
lIdx
%
2
;
const
int
lane_hi
=
lIdx
%
2
;
for
(
int
ele
=
0
;
ele
<
8
;
++
ele
)
for
(
int
ele
=
0
;
ele
<
8
;
++
ele
)
...
@@ -129,15 +129,15 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
...
@@ -129,15 +129,15 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
}
}
__syncthreads
();
__syncthreads
();
for
(
int
ele
=
0
;
ele
<
8
;
++
ele
)
for
(
int
ele
=
0
;
ele
<
8
;
++
ele
)
{
{
p_shared
[
8
*
16
*
lane_hi
+
8
*
lane_lo
+
ele
]
=
a_temp
[
ele
];
p_shared
[
8
*
16
*
lane_hi
+
8
*
lane_lo
+
ele
]
=
a_temp
[
ele
];
}
}
for
(
int
ele
=
0
;
ele
<
8
;
++
ele
)
for
(
int
ele
=
0
;
ele
<
8
;
++
ele
)
{
{
p_shared
[
8
*
16
*
lane_hi
+
8
*
lane_lo
+
ele
+
16
*
16
]
=
b_temp
[
ele
];
p_shared
[
8
*
16
*
lane_hi
+
8
*
lane_lo
+
ele
+
16
*
16
]
=
b_temp
[
ele
];
}
}
asm
volatile
(
"\
asm
volatile
(
"\
...
@@ -147,12 +147,12 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
...
@@ -147,12 +147,12 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
for
(
int
ele
=
0
;
ele
<
16
;
++
ele
)
for
(
int
ele
=
0
;
ele
<
16
;
++
ele
)
{
{
b_frag
[
ele
]
=
p_shared
[(
ele
/
8
)
*
16
*
8
+
8
*
lane
+
ele
%
8
+
16
*
16
];
b_frag
[
ele
]
=
p_shared
[(
ele
/
8
)
*
16
*
8
+
8
*
lane
+
ele
%
8
+
16
*
16
];
}
}
// follow the original design
// follow the original design
for
(
int
ele
=
0
;
ele
<
16
;
++
ele
)
for
(
int
ele
=
0
;
ele
<
16
;
++
ele
)
{
{
a_frag
[
ele
]
=
p_shared
[(
ele
/
8
)
*
16
*
8
+
8
*
lane
+
ele
%
8
];
a_frag
[
ele
]
=
p_shared
[(
ele
/
8
)
*
16
*
8
+
8
*
lane
+
ele
%
8
];
}
}
asm
volatile
(
"\
asm
volatile
(
"\
...
@@ -163,6 +163,9 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
...
@@ -163,6 +163,9 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
// sync threads, similar to mma_sync
// sync threads, similar to mma_sync
// __syncthreads();
// __syncthreads();
builtin_wmma_naive_selector
<
src_vec
,
acc_vec
>
(
a_frag
,
b_frag
,
c_thread_buf_
);
builtin_wmma_naive_selector
<
src_vec
,
acc_vec
>
(
a_frag
,
b_frag
,
c_thread_buf_
);
// Only the fp16_fp32 inline-asm wmma is implemented (for experimental purposes), so restrict
// the test case to fp16 when enabling the following call:
// ck::amd_assembly_wmma_f32_16x16x16_f16_w32(a_frag, b_frag,
// c_thread_buf_.GetVectorTypeReference(Number<0>{}).template AsType<float8_t>()(Number<0>{}));
__syncthreads
();
__syncthreads
();
// wait for results, similar to mma_sync
// wait for results, similar to mma_sync
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
ele
)
{
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
ele
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment