gaoqiong / composable_kernel / Commits / 753b98b5

Commit 753b98b5, authored Apr 03, 2019 by Jing Zhang

refactor inline asm

parent 114fdb58
Showing 2 changed files with 41 additions and 519 deletions (+41 -519):

    src/include/blockwise_gemm.hip.hpp    +25  -500
    src/include/threadwise_gemm.hip.hpp   +16  -19
src/include/blockwise_gemm.hip.hpp (view file @ 753b98b5)

#pragma once
#include "threadwise_gemm.hip.hpp"

extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]];

inline __device__ void
outerProduct4x4(float4& a, float4& b, float4& c0, float4& c1, float4& c2, float4& c3)
{
    asm volatile("\n \
            v_mac_f32 %0, %4, %5 \n \
            v_mac_f32 %1, %4, %6 \n \
            v_mac_f32 %2, %4, %7 \n \
            v_mac_f32 %3, %4, %8 \n \
            "
                 :
                 : "v"(c0.x), "v"(c0.y), "v"(c0.z), "v"(c0.w), \
                   "v"(a.x), "v"(b.x), "v"(b.y), "v"(b.z), "v"(b.w));

    asm volatile("\n \
            v_mac_f32 %0, %4, %5 \n \
            v_mac_f32 %1, %4, %6 \n \
            v_mac_f32 %2, %4, %7 \n \
            v_mac_f32 %3, %4, %8 \n \
            "
                 :
                 : "v"(c1.x), "v"(c1.y), "v"(c1.z), "v"(c1.w), \
                   "v"(a.y), "v"(b.x), "v"(b.y), "v"(b.z), "v"(b.w));

    asm volatile("\n \
            v_mac_f32 %0, %4, %5 \n \
            v_mac_f32 %1, %4, %6 \n \
            v_mac_f32 %2, %4, %7 \n \
            v_mac_f32 %3, %4, %8 \n \
            "
                 :
                 : "v"(c2.x), "v"(c2.y), "v"(c2.z), "v"(c2.w), \
                   "v"(a.z), "v"(b.x), "v"(b.y), "v"(b.z), "v"(b.w));

    asm volatile("\n \
            v_mac_f32 %0, %4, %5 \n \
            v_mac_f32 %1, %4, %6 \n \
            v_mac_f32 %2, %4, %7 \n \
            v_mac_f32 %3, %4, %8 \n \
            "
                 :
                 : "v"(c3.x), "v"(c3.y), "v"(c3.z), "v"(c3.w), \
                   "v"(a.w), "v"(b.x), "v"(b.y), "v"(b.z), "v"(b.w));
}
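
For reference, each of the four asm blocks above performs one row of a rank-1 update: c0 accumulates a.x * b, c1 accumulates a.y * b, and so on. Below is a host-side sketch of the same arithmetic, using a hypothetical plain vec4 struct in place of the device float4 type; it is a correctness reference only, not the device code.

// Reference-only sketch of the math outerProduct4x4 encodes (hypothetical vec4 type).
struct vec4 { float x, y, z, w; };

inline void outer_product_4x4_ref(const vec4& a, const vec4& b,
                                  vec4& c0, vec4& c1, vec4& c2, vec4& c3)
{
    vec4*       c[4] = {&c0, &c1, &c2, &c3};
    const float s[4] = {a.x, a.y, a.z, a.w};
    for(int i = 0; i < 4; ++i)
    {
        // one v_mac_f32 per component: c[i] += a[i] * b
        c[i]->x += s[i] * b.x;
        c[i]->y += s[i] * b.y;
        c[i]->z += s[i] * b.z;
        c[i]->w += s[i] * b.w;
    }
}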

template <uint32_t cnt>
inline __device__ void lgkmcnt()
{
    if(cnt == 0)
    {
        asm volatile("\n \
                s_waitcnt lgkmcnt(0) \n \
                " ::);
    }
    if(cnt == 1)
    {
        asm volatile("\n \
                s_waitcnt lgkmcnt(1) \n \
                " ::);
    }
    if(cnt == 2)
    {
        asm volatile("\n \
                s_waitcnt lgkmcnt(2) \n \
                " ::);
    }
    if(cnt == 3)
    {
        asm volatile("\n \
                s_waitcnt lgkmcnt(3) \n \
                " ::);
    }
    if(cnt == 4)
    {
        asm volatile("\n \
                s_waitcnt lgkmcnt(4) \n \
                " ::);
    }
    if(cnt == 5)
    {
        asm volatile("\n \
                s_waitcnt lgkmcnt(5) \n \
                " ::);
    }
    if(cnt == 6)
    {
        asm volatile("\n \
                s_waitcnt lgkmcnt(6) \n \
                " ::);
    }
}
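
Because cnt is a compile-time template parameter, only the matching branch survives constant folding and a single s_waitcnt lgkmcnt(N) is emitted, waiting until at most N lgkm (LDS/GDS/constant) operations remain outstanding. A sketch of the same dispatch written with C++17 if constexpr, which makes that intent explicit (assumes C++17 is available; the plain-if version above behaves the same once folded):

// Sketch only: compile-time selection of the waitcnt immediate (assumes C++17).
template <uint32_t cnt>
inline __device__ void lgkmcnt_sketch()
{
    static_assert(cnt <= 6, "only lgkmcnt(0..6) are handled here");
    if constexpr(cnt == 0) { asm volatile("s_waitcnt lgkmcnt(0)" ::); }
    else if constexpr(cnt == 1) { asm volatile("s_waitcnt lgkmcnt(1)" ::); }
    else if constexpr(cnt == 2) { asm volatile("s_waitcnt lgkmcnt(2)" ::); }
    else if constexpr(cnt == 3) { asm volatile("s_waitcnt lgkmcnt(3)" ::); }
    else if constexpr(cnt == 4) { asm volatile("s_waitcnt lgkmcnt(4)" ::); }
    else if constexpr(cnt == 5) { asm volatile("s_waitcnt lgkmcnt(5)" ::); }
    else { asm volatile("s_waitcnt lgkmcnt(6)" ::); }
}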

template <uint32_t off>
inline __device__ void
shared_read_b128(float4& a0, float4& a1, float4& b0, float4& b1, uint32_t& ldsA, uint32_t& ldsB)
{
    if(off == 0)
    {
        asm volatile("\n \
                ds_read_b128 %0, %4 offset:0 \n \
                ds_read_b128 %1, %4 offset:256 \n \
                ds_read_b128 %2, %5 offset:0 \n \
                ds_read_b128 %3, %5 offset:256 \n \
                "
                     : "=v"(a0), "=v"(a1), "=v"(b0), "=v"(b1)
                     : "v"(ldsA), "v"(ldsB));
    }
    if(off == 1 * 512)
    {
        asm volatile("\n \
                ds_read_b128 %0, %4 offset:1*512 \n \
                ds_read_b128 %1, %4 offset:1*512+256 \n \
                ds_read_b128 %2, %5 offset:1*512 \n \
                ds_read_b128 %3, %5 offset:1*512+256 \n \
                "
                     : "=v"(a0), "=v"(a1), "=v"(b0), "=v"(b1)
                     : "v"(ldsA), "v"(ldsB));
    }
    if(off == 2 * 512)
    {
        asm volatile("\n \
                ds_read_b128 %0, %4 offset:2*512 \n \
                ds_read_b128 %1, %4 offset:2*512+256 \n \
                ds_read_b128 %2, %5 offset:2*512 \n \
                ds_read_b128 %3, %5 offset:2*512+256 \n \
                "
                     : "=v"(a0), "=v"(a1), "=v"(b0), "=v"(b1)
                     : "v"(ldsA), "v"(ldsB));
    }
    if(off == 3 * 512)
    {
        asm volatile("\n \
                ds_read_b128 %0, %4 offset:3*512 \n \
                ds_read_b128 %1, %4 offset:3*512+256 \n \
                ds_read_b128 %2, %5 offset:3*512 \n \
                ds_read_b128 %3, %5 offset:3*512+256 \n \
                "
                     : "=v"(a0), "=v"(a1), "=v"(b0), "=v"(b1)
                     : "v"(ldsA), "v"(ldsB));
    }
    if(off == 4 * 512)
    {
        asm volatile("\n \
                ds_read_b128 %0, %4 offset:4*512 \n \
                ds_read_b128 %1, %4 offset:4*512+256 \n \
                ds_read_b128 %2, %5 offset:4*512 \n \
                ds_read_b128 %3, %5 offset:4*512+256 \n \
                "
                     : "=v"(a0), "=v"(a1), "=v"(b0), "=v"(b1)
                     : "v"(ldsA), "v"(ldsB));
    }
    if(off == 5 * 512)
    {
        asm volatile("\n \
                ds_read_b128 %0, %4 offset:5*512 \n \
                ds_read_b128 %1, %4 offset:5*512+256 \n \
                ds_read_b128 %2, %5 offset:5*512 \n \
                ds_read_b128 %3, %5 offset:5*512+256 \n \
                "
                     : "=v"(a0), "=v"(a1), "=v"(b0), "=v"(b1)
                     : "v"(ldsA), "v"(ldsB));
    }
    if(off == 6 * 512)
    {
        asm volatile("\n \
                ds_read_b128 %0, %4 offset:6*512 \n \
                ds_read_b128 %1, %4 offset:6*512+256 \n \
                ds_read_b128 %2, %5 offset:6*512 \n \
                ds_read_b128 %3, %5 offset:6*512+256 \n \
                "
                     : "=v"(a0), "=v"(a1), "=v"(b0), "=v"(b1)
                     : "v"(ldsA), "v"(ldsB));
    }
    if(off == 7 * 512)
    {
        asm volatile("\n \
                ds_read_b128 %0, %4 offset:7*512 \n \
                ds_read_b128 %1, %4 offset:7*512+256 \n \
                ds_read_b128 %2, %5 offset:7*512 \n \
                ds_read_b128 %3, %5 offset:7*512+256 \n \
                "
                     : "=v"(a0), "=v"(a1), "=v"(b0), "=v"(b1)
                     : "v"(ldsA), "v"(ldsB));
    }
}
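
Each shared_read_b128<off> call issues four 16-byte LDS loads: two float4 fragments of A at byte offsets off and off+256 from ldsA, and two fragments of B at the same offsets from ldsB, with the offset baked into the ds_read_b128 immediate. A rough HIP-level sketch of the same access pattern without inline asm (assumes the LDS locations can be addressed through ordinary pointers to __shared__ data and leaves instruction selection to the compiler):

// Sketch only: the access pattern of shared_read_b128<off>, expressed as plain loads.
template <uint32_t off>
inline __device__ void shared_read_b128_sketch(float4& a0, float4& a1,
                                               float4& b0, float4& b1,
                                               const char* lds_a, const char* lds_b)
{
    a0 = *reinterpret_cast<const float4*>(lds_a + off);       // offset:off
    a1 = *reinterpret_cast<const float4*>(lds_a + off + 256); // offset:off+256
    b0 = *reinterpret_cast<const float4*>(lds_b + off);       // offset:off
    b1 = *reinterpret_cast<const float4*>(lds_b + off + 256); // offset:off+256
}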

template <index_t BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          ...
...
@@ -557,341 +382,41 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
        // loop over k
        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
        {
#if 1
            auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
            auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;

            uint32_t a_loc = block_off + a_src_index;
            uint32_t b_loc = b_src_index;
            // const float4* a_loc = (const float4*)(p_b_block + block_off + a_src_index);
            // const float4* b_loc = (const float4*)(p_b_block + b_src_index);

            float4* reg = (float4*)(p_thread);
            float4* c_v = (float4*)(p_c_thread);

            // shared_read_b128<0>(reg[0], reg[1], reg[2], reg[3], a_loc, b_loc);
            // reg[0] = a_loc[0];
            // reg[1] = a_loc[16];
            // reg[2] = b_loc[0];
            // reg[3] = b_loc[8];

            // asm volatile("\n \
            // ds_read_b128 %0, %1 \n \
            // "
            // : "=v"(reg[0])
            // : "v"(a_loc)
            // );
            // asm volatile("\n \
            // ds_read_b128 %0, %1 \n \
            // "
            // : "=v"(reg[1])
            // : "v"(a_loc + 256)
            // );
            auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
            auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;

            Float4* reg_a = (Float4*)(p_a_thread);
            Float4* reg_b = (Float4*)(p_b_thread);

            void* a_loc = (void*)(p_a_block + a_src_index);
            void* b_loc = (void*)(p_b_block + b_src_index);

            // asm volatile("\n \
            // ds_read_b128 %0, %1 \n \
            // ds_read_b128 %0, %2 \n \
            // ds_read_b128 %1, %2 offset:256\n \
            // "
            // : "=v"(reg[2])
            // : "v"(b_loc)
            // : "=v"(reg_a[0]), "=v"(reg_a[1])
            // : "v"(__to_local(a_loc))
            // );
            // asm volatile("\n \
            // ds_read_b128 %0, %1\n \
            // "
            // : "=v"(reg[3])
            // : "v"(b_loc + 128)
            // );

            ds_read_b128(reg_a[0], a_loc, 0);
            ds_read_b128(reg_a[1], a_loc, 256);
            asm volatile("\n \
                    ds_read_b128 %0, %4 \n \
                    ds_read_b128 %1, %4 offset:16 \n \
                    ds_read_b128 %2, %5 \n \
                    ds_read_b128 %3, %5 offset:8 \n \
                    "
                         : "=v"(reg[0]), "=v"(reg[1]), "=v"(reg[2]), "=v"(reg[3])
                         : "v"(a_loc), "v"(b_loc));

            lgkmcnt<0>();

            outerProduct4x4(reg[0], reg[2], c_v[0], c_v[1], c_v[2], c_v[3]);
            outerProduct4x4(reg[0], reg[3], c_v[4], c_v[5], c_v[6], c_v[7]);
            outerProduct4x4(reg[1], reg[2], c_v[8], c_v[9], c_v[10], c_v[11]);
            outerProduct4x4(reg[1], reg[3], c_v[12], c_v[13], c_v[14], c_v[15]);

            ds_read_b128(reg_b[0], b_loc, 0);
            ds_read_b128(reg_b[1], b_loc, 128);
            // asm volatile("\n \
            // ds_read_b32 %0, %16 \n \
            // ds_read_b32 %1, %16 offset:1\n \
            // ds_read_b32 %2, %16 offset:2\n \
            // ds_read_b32 %3, %16 offset:3\n \
            // ds_read_b32 %4, %17 \n \
            // ds_read_b32 %5, %17 offset:1\n \
            // ds_read_b32 %6, %17 offset:2\n \
            // ds_read_b32 %7, %17 offset:3\n \
            // ds_read_b32 %8, %18 \n \
            // ds_read_b32 %9, %18 offset:1\n \
            // ds_read_b32 %10, %18 offset:2\n \
            // ds_read_b32 %11, %18 offset:3\n \
            // ds_read_b32 %12, %19 \n \
            // ds_read_b32 %13, %19 offset:1\n \
            // ds_read_b32 %14, %19 offset:2\n \
            // ds_read_b32 %15, %19 offset:3\n \
            // s_waitcnt lgkmcnt(0)"
            // :
            // "=v"(p_a_thread[0]),
            // "=v"(p_a_thread[1]),
            // "=v"(p_a_thread[2]),
            // "=v"(p_a_thread[3]),
            // "=v"(p_a_thread[4]),
            // "=v"(p_a_thread[5]),
            // "=v"(p_a_thread[6]),
            // "=v"(p_a_thread[7]),
            // "=v"(p_b_thread[0]),
            // "=v"(p_b_thread[1]),
            // "=v"(p_b_thread[2]),
            // "=v"(p_b_thread[3]),
            // "=v"(p_b_thread[4]),
            // "=v"(p_b_thread[5]),
            // "=v"(p_b_thread[6]),
            // "=v"(p_b_thread[7])
            // :
            // "v"(__to_local((void *)(&p_a_block[0]))),
            // "v"(__to_local((void *)(&p_a_block[64]))),
            // "v"(__to_local((void *)(&p_b_block[0]))),
            // "v"(__to_local((void *)(&p_b_block[32])))
            // );

            lgkmcnt(0);

            // C = A * B
#else
            auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
            auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
            auto dst_index   = a_thread_sub_mtx.Get1dIndex(0, 0);

            const float4* a_loc = (const float4*)(p_a_block + a_src_index);
            const float4* b_loc = (const float4*)(p_b_block + b_src_index);
            float4* reg         = (float4*)(p_a_thread + dst_index);
            asm volatile("\n \
                    ds_read2_b64 %0, %84 offset1:1 \n \
                    ds_read2_b64 %1, %84 offset0:32 offset1:33 \n \
                    ds_read2_b64 %2, %85 offset1:1 \n \
                    ds_read2_b64 %3, %85 offset0:16 offset1:17 \n \
                    s_waitcnt lgkmcnt(0) \n \
                    v_mac_f32 %4, %68, %76 \n \
                    v_mac_f32 %5, %68, %77 \n \
                    v_mac_f32 %6, %68, %78 \n \
                    v_mac_f32 %7, %68, %79 \n \
                    v_mac_f32 %8, %68, %80 \n \
                    v_mac_f32 %9, %68, %81 \n \
                    v_mac_f32 %10, %68, %82 \n \
                    v_mac_f32 %11, %68, %83 \n \
                    v_mac_f32 %12, %69, %76 \n \
                    v_mac_f32 %13, %69, %77 \n \
                    v_mac_f32 %14, %69, %78 \n \
                    v_mac_f32 %15, %69, %79 \n \
                    v_mac_f32 %16, %69, %80 \n \
                    v_mac_f32 %17, %69, %81 \n \
                    v_mac_f32 %18, %69, %82 \n \
                    v_mac_f32 %19, %69, %83 \n \
                    v_mac_f32 %20, %70, %76 \n \
                    v_mac_f32 %21, %70, %77 \n \
                    v_mac_f32 %22, %70, %78 \n \
                    v_mac_f32 %23, %70, %79 \n \
                    v_mac_f32 %24, %70, %80 \n \
                    v_mac_f32 %25, %70, %81 \n \
                    v_mac_f32 %26, %70, %82 \n \
                    v_mac_f32 %27, %70, %83 \n \
                    v_mac_f32 %28, %71, %76 \n \
                    v_mac_f32 %29, %71, %77 \n \
                    v_mac_f32 %30, %71, %78 \n \
                    v_mac_f32 %31, %71, %79 \n \
                    v_mac_f32 %32, %71, %80 \n \
                    v_mac_f32 %33, %71, %81 \n \
                    v_mac_f32 %34, %71, %82 \n \
                    v_mac_f32 %35, %71, %83 \n \
                    v_mac_f32 %36, %72, %76 \n \
                    v_mac_f32 %37, %72, %77 \n \
                    v_mac_f32 %38, %72, %78 \n \
                    v_mac_f32 %39, %72, %79 \n \
                    v_mac_f32 %40, %72, %80 \n \
                    v_mac_f32 %41, %72, %81 \n \
                    v_mac_f32 %42, %72, %82 \n \
                    v_mac_f32 %43, %72, %83 \n \
                    v_mac_f32 %44, %73, %76 \n \
                    v_mac_f32 %45, %73, %77 \n \
                    v_mac_f32 %46, %73, %78 \n \
                    v_mac_f32 %47, %73, %79 \n \
                    v_mac_f32 %48, %73, %80 \n \
                    v_mac_f32 %49, %73, %81 \n \
                    v_mac_f32 %50, %73, %82 \n \
                    v_mac_f32 %51, %73, %83 \n \
                    v_mac_f32 %52, %74, %76 \n \
                    v_mac_f32 %53, %74, %77 \n \
                    v_mac_f32 %54, %74, %78 \n \
                    v_mac_f32 %55, %74, %79 \n \
                    v_mac_f32 %56, %74, %80 \n \
                    v_mac_f32 %57, %74, %81 \n \
                    v_mac_f32 %58, %74, %82 \n \
                    v_mac_f32 %59, %74, %83 \n \
                    v_mac_f32 %60, %75, %76 \n \
                    v_mac_f32 %61, %75, %77 \n \
                    v_mac_f32 %62, %75, %78 \n \
                    v_mac_f32 %63, %75, %79 \n \
                    v_mac_f32 %64, %75, %80 \n \
                    v_mac_f32 %65, %75, %81 \n \
                    v_mac_f32 %66, %75, %82 \n \
                    v_mac_f32 %67, %75, %83 \n \
                    "
                         : "=v"(reg[0]), "=v"(reg[1]), "=v"(reg[2]), "=v"(reg[3]),
                           "=v"(p_c_thread[0]), "=v"(p_c_thread[1]), "=v"(p_c_thread[2]), "=v"(p_c_thread[3]),
                           "=v"(p_c_thread[4]), "=v"(p_c_thread[5]), "=v"(p_c_thread[6]), "=v"(p_c_thread[7]),
                           "=v"(p_c_thread[8]), "=v"(p_c_thread[9]), "=v"(p_c_thread[10]), "=v"(p_c_thread[11]),
                           "=v"(p_c_thread[12]), "=v"(p_c_thread[13]), "=v"(p_c_thread[14]), "=v"(p_c_thread[15]),
                           "=v"(p_c_thread[16]), "=v"(p_c_thread[17]), "=v"(p_c_thread[18]), "=v"(p_c_thread[19]),
                           "=v"(p_c_thread[20]), "=v"(p_c_thread[21]), "=v"(p_c_thread[22]), "=v"(p_c_thread[23]),
                           "=v"(p_c_thread[24]), "=v"(p_c_thread[25]), "=v"(p_c_thread[26]), "=v"(p_c_thread[27]),
                           "=v"(p_c_thread[28]), "=v"(p_c_thread[29]), "=v"(p_c_thread[30]), "=v"(p_c_thread[31]),
                           "=v"(p_c_thread[32]), "=v"(p_c_thread[33]), "=v"(p_c_thread[34]), "=v"(p_c_thread[35]),
                           "=v"(p_c_thread[36]), "=v"(p_c_thread[37]), "=v"(p_c_thread[38]), "=v"(p_c_thread[39]),
                           "=v"(p_c_thread[40]), "=v"(p_c_thread[41]), "=v"(p_c_thread[42]), "=v"(p_c_thread[43]),
                           "=v"(p_c_thread[44]), "=v"(p_c_thread[45]), "=v"(p_c_thread[46]), "=v"(p_c_thread[47]),
                           "=v"(p_c_thread[48]), "=v"(p_c_thread[49]), "=v"(p_c_thread[50]), "=v"(p_c_thread[51]),
                           "=v"(p_c_thread[52]), "=v"(p_c_thread[53]), "=v"(p_c_thread[54]), "=v"(p_c_thread[55]),
                           "=v"(p_c_thread[56]), "=v"(p_c_thread[57]), "=v"(p_c_thread[58]), "=v"(p_c_thread[59]),
                           "=v"(p_c_thread[60]), "=v"(p_c_thread[61]), "=v"(p_c_thread[62]), "=v"(p_c_thread[63])
                         : "v"(p_a_thread[0]), "v"(p_a_thread[1]), "v"(p_a_thread[2]), "v"(p_a_thread[3]),
                           "v"(p_a_thread[4]), "v"(p_a_thread[5]), "v"(p_a_thread[6]), "v"(p_a_thread[7]),
                           "v"(p_b_thread[0]), "v"(p_b_thread[1]), "v"(p_b_thread[2]), "v"(p_b_thread[3]),
                           "v"(p_b_thread[4]), "v"(p_b_thread[5]), "v"(p_b_thread[6]), "v"(p_b_thread[7]),
                           "v"(__to_local((void*)(a_loc))),
                           "v"(__to_local((void*)(b_loc))),
                           "4"(p_c_thread[0]), "5"(p_c_thread[1]), "6"(p_c_thread[2]), "7"(p_c_thread[3]),
                           "8"(p_c_thread[4]), "9"(p_c_thread[5]), "10"(p_c_thread[6]), "11"(p_c_thread[7]),
                           "12"(p_c_thread[8]), "13"(p_c_thread[9]), "14"(p_c_thread[10]), "15"(p_c_thread[11]),
                           "16"(p_c_thread[12]), "17"(p_c_thread[13]), "18"(p_c_thread[14]), "19"(p_c_thread[15]),
                           "20"(p_c_thread[16]), "21"(p_c_thread[17]), "22"(p_c_thread[18]), "23"(p_c_thread[19]),
                           "24"(p_c_thread[20]), "25"(p_c_thread[21]), "26"(p_c_thread[22]), "27"(p_c_thread[23]),
                           "28"(p_c_thread[24]), "29"(p_c_thread[25]), "30"(p_c_thread[26]), "31"(p_c_thread[27]),
                           "32"(p_c_thread[28]), "33"(p_c_thread[29]), "34"(p_c_thread[30]), "35"(p_c_thread[31]),
                           "36"(p_c_thread[32]), "37"(p_c_thread[33]), "38"(p_c_thread[34]), "39"(p_c_thread[35]),
                           "40"(p_c_thread[36]), "41"(p_c_thread[37]), "42"(p_c_thread[38]), "43"(p_c_thread[39]),
                           "44"(p_c_thread[40]), "45"(p_c_thread[41]), "46"(p_c_thread[42]), "47"(p_c_thread[43]),
                           "48"(p_c_thread[44]), "49"(p_c_thread[45]), "50"(p_c_thread[46]), "51"(p_c_thread[47]),
                           "52"(p_c_thread[48]), "53"(p_c_thread[49]), "54"(p_c_thread[50]), "55"(p_c_thread[51]),
                           "56"(p_c_thread[52]), "57"(p_c_thread[53]), "58"(p_c_thread[54]), "59"(p_c_thread[55]),
                           "60"(p_c_thread[56]), "61"(p_c_thread[57]), "62"(p_c_thread[58]), "63"(p_c_thread[59]),
                           "64"(p_c_thread[60]), "65"(p_c_thread[61]), "66"(p_c_thread[62]), "67"(p_c_thread[63]));
#endif
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread,
                            b_thread_mtx,
                            False,
                            p_b_thread,
                            c_thread_mtx,
                            False,
                            p_c_thread,
                            f_accum);
        }
    }
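
In outline, each KPerThreadLoop step of the k-loop above follows a load / wait / multiply-accumulate pattern. A condensed summary of the #if 1 path (a description of the flow, not the literal emitted instruction sequence):

// Per k-step, inline-asm path:
//   1. a_src_index / b_src_index locate this thread's A and B fragments in LDS
//   2. ds_read_b128 pulls float4 fragments of A and B from LDS into registers
//   3. lgkmcnt<0>() waits until the outstanding LDS reads have completed
//   4. four outerProduct4x4 calls apply rank-1 updates to the accumulators c_v[0..15]
//   5. threadwise_gemm(...) then performs the generic per-thread GEMM on the staged fragments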
...
...

src/include/threadwise_gemm.hip.hpp (view file @ 753b98b5)

#pragma once

#include "inline_asm.hpp"

template <class Float,
          class SrcMatrix,
          class DstMatrix,
          index_t NRow,
          index_t NCol>
__device__ void threadwise_matrix_copy(SrcMatrix,
                                       const Float* __restrict__ p_src,
...
...
@@ -21,18 +23,18 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
            p_dst[dst_index] = p_src[src_index];
        }
    }
#elif 1
#else
    static_assert(NCol == 4, "only for NCol == 4");

    using vector_t = typename vector_type<Float, 4>::MemoryType;

    for(index_t i = 0; i < NRow; ++i)
    {
        const index_t src_index = src_mtx.Get1dIndex(i, 0);
        const index_t dst_index = dst_mtx.Get1dIndex(i, 0);

        *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
            *(reinterpret_cast<const vector_t*>(&p_src[src_index]));

        Float4* reg_p = (Float4*)&p_dst[dst_index];
        Float4* loc_p = (Float4*)&p_src[src_index];

        ds_read_b128(reg_p[0], (void*)&loc_p[0]);
    }
#endif
}
...
...
@@ -70,25 +72,20 @@ __device__ void threadwise_gemm(MatrixA,
    for(index_t k = 0; k < K; ++k)
    {
        for(index_t i = 0; i < M; ++i)
        {
            for(index_t j = 0; j < N; ++j)
        for(index_t i = 0; i < M; i += 4)
        {
            const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed

            const Float4* a_vec = (const Float4*)&p_a_thread[aindex];

            for(index_t j = 0; j < N; j += 4)
            {
                const index_t bindex = b_mtx.Get1dIndex(k, j);
                const index_t cindex = c_mtx.Get1dIndex(i, j);
#if 0
                f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
#elif 1
                asm volatile("\n \
                        v_mac_f32 %0, %1, %2 \n \
                        "
                             : "=v"(p_c_thread[cindex])
                             : "v"(p_a_thread[aindex]), "v"(p_b_thread[bindex]), "0"(p_c_thread[cindex]));
#endif
                const Float4* b_vec = (const Float4*)&p_b_thread[bindex];
                Float4* c_vec       = (Float4*)&p_c_thread[cindex];

                outerProduct4x4(a_vec[0], b_vec[0], c_vec[0], c_vec[2], c_vec[4], c_vec[6]);
            }
        }
    }
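
The 4x4-tiled loop above computes the same result as the scalar triple loop kept in the #if 0 branch, assuming f_accum is a plain += accumulation. For reference, that scalar form in full (A stored K x M, i.e. transposed; B stored K x N; C stored M x N; names as in the surrounding function):

// Scalar reference sketch of the tiled loop, assuming f_accum(c, x) is c += x.
for(index_t k = 0; k < K; ++k)
    for(index_t i = 0; i < M; ++i)
        for(index_t j = 0; j < N; ++j)
            p_c_thread[c_mtx.Get1dIndex(i, j)] +=
                p_a_thread[a_mtx.Get1dIndex(k, i)] * p_b_thread[b_mtx.Get1dIndex(k, j)];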
...
...