yangql/composable_kernel-1, commit e43d7bc6
authored Apr 01, 2019 by Chao Liu
parent d058d164

Commit message: refactor
Showing 13 changed files with 862 additions and 917 deletions.
Changed files:

  driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp (+9, -7)
  driver/driver.hip.cpp (+1, -1)
  src/include/ConstantTensorDescriptor.hip.hpp (+4, -1)
  src/include/blockwise_gemm.hip.hpp (+363, -485)
  src/include/common.hip.hpp (+34, -12)
  src/include/gridwise_direct_convolution_1.hip.hpp (+5, -5)
  src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp (+5, -4)
  src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp (+4, -4)
  src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp (+5, -4)
  src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp (+4, -4)
  src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp (+221, -172)
  src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp (+204, -196)
  src/include/threadwise_gemm.hip.hpp (+3, -22)
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp

@@ -270,7 +270,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        float time = launch_kernel(
+        constexpr auto gridwise_conv =
 #if 1
             gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
 #else
...
@@ -301,12 +301,14 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
             WeiBlockCopyThreadPerDim0,
             WeiBlockCopyThreadPerDim1,
             InBlockCopyDataPerRead,
-            WeiBlockCopyDataPerRead>,
-            dim3(GridSize),
-            dim3(BlockSize),
-            static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
-            static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
-            static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
+            WeiBlockCopyDataPerRead>();
+
+        float time = launch_kernel(gridwise_conv.Run,
+                                   dim3(GridSize),
+                                   dim3(BlockSize),
+                                   static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
+                                   static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
+                                   static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));

         printf("Elapsed time : %f ms\n", time);
         usleep(std::min(time * 1000, float(10000)));
...
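This hunk switches the driver from launching a free kernel function template to instantiating the gridwise operator as an object and handing its static Run to launch_kernel. A minimal sketch of that pattern follows; gridwise_op, run_gridwise_op, and launch are hypothetical stand-ins for illustration, not the repository's actual convolution class or its launch_kernel helper.

    #include <hip/hip_runtime.h>

    // The gridwise operator becomes a class whose static Run() is the kernel body.
    template <int BlockSize>
    class gridwise_op
    {
        public:
        __device__ static void Run(const float* p_in, float* p_out)
        {
            // stand-in for the real gridwise body; assumes the grid exactly covers the buffer
            const int i = blockIdx.x * BlockSize + threadIdx.x;
            p_out[i] = 2.0f * p_in[i];
        }
    };

    // Trampoline kernel that forwards to the class's static Run().
    template <class Gridwise>
    __global__ void run_gridwise_op(const float* p_in, float* p_out)
    {
        Gridwise::Run(p_in, p_out);
    }

    void launch(const float* p_in, float* p_out, int grid, int block)
    {
        hipLaunchKernelGGL(run_gridwise_op<gridwise_op<256>>,
                           dim3(grid), dim3(block), 0, 0, p_in, p_out);
    }

The payoff is that the configured kernel gets a name (gridwise_conv above) that host code can query before launching, which the class refactor of gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn later in this commit exploits.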
driver/driver.hip.cpp

@@ -580,7 +580,7 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
-#elif 0
+#elif 1
     // 1x1 filter, 14x14 image, C = 2048
     constexpr index_t N = 128;
     constexpr index_t C = 2048;
...
src/include/ConstantTensorDescriptor.hip.hpp

@@ -137,7 +137,10 @@ struct ConstantTensorDescriptor
             }
         };

-        return static_const_reduce_n<nDim>{}(GetElementSpace_f{}, add{}) + align.Get();
+        index_t element_space_unaligned =
+            static_const_reduce_n<nDim>{}(GetElementSpace_f{}, add{}) + 1;
+
+        return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
     }

     template <class... Is>
...
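The new return statement replaces "add the alignment on top" with the standard round-up-to-a-multiple idiom, so the reported element space is now the smallest multiple of align.Get() that covers the unaligned space. A self-contained restatement of the arithmetic (round_up is a hypothetical name used only here):

    #include <cassert>

    // Round `space` up to the next multiple of `align` (align > 0), the same
    // arithmetic as the new GetElementSpace() return:
    //   align * ((space + align - 1) / align)
    constexpr unsigned round_up(unsigned space, unsigned align)
    {
        return align * ((space + align - 1) / align);
    }

    int main()
    {
        static_assert(round_up(17, 4) == 20, "17 rounds up to 20");
        static_assert(round_up(16, 4) == 16, "exact multiples are unchanged");
        assert(round_up(1, 8) == 8);
        return 0;
    }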
src/include/blockwise_gemm.hip.hpp

 #pragma once
 #include "threadwise_gemm.hip.hpp"

 extern "C" __attribute__((address_space(3))) void* __to_local(void* p) [[hc]];

 template <index_t BlockSize,
           class BlockMatrixA,
...
@@ -335,10 +335,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
     }

     template <class FloatA, class FloatB, class FloatC, class Accumulator>
-    __device__ void Run(const FloatA* __restrict__ p_a_block,
-                        const FloatB* __restrict__ p_b_block,
-                        FloatC* __restrict__ p_c_thread,
-                        Accumulator f_accum) const
+    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
+                            const FloatB* __restrict__ p_b_block,
+                            FloatC* __restrict__ p_c_thread,
+                            Accumulator f_accum) const
     {
         constexpr auto True  = integral_constant<bool, true>{};
         constexpr auto False = integral_constant<bool, false>{};
...
@@ -368,10 +368,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
             Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

         float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()];

         FloatA* p_a_thread = p_thread;
         FloatB* p_b_thread = p_thread + a_thread_mtx.GetElementSpace();

         constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
         constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
...
@@ -387,9 +387,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
         auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;

         const float4* a_loc = (const float4*)(p_a_block + a_src_index);
         const float4* b_loc = (const float4*)(p_b_block + b_src_index);
         float4* reg         = (float4*)(p_thread);

         reg[0] = a_loc[0];
         reg[1] = a_loc[16];
...
@@ -398,41 +398,41 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         //asm volatile("\n \
         //ds_read2_b64 %0, %1 offset1:1 \n \
         //s_waitcnt lgkmcnt(0)"
         //: "=v"(reg[0])
         //: "v"(__to_local((void *)(a_loc)))
         //);
         //asm volatile("\n \
         //ds_read2_b64 %0, %1 offset1:1 \n \
         //s_waitcnt lgkmcnt(0)"
         //: "=v"(reg[1])
         //: "v"(__to_local((void *)(a_loc + 16)))
         //);
         //asm volatile("\n \
         //ds_read2_b64 %0, %1 offset1:1 \n \
         //s_waitcnt lgkmcnt(0)"
         //: "=v"(reg[2])
         //: "v"(__to_local((void *)(b_loc)))
         //);
         //asm volatile("\n \
         //ds_read2_b64 %0, %1 offset1:1 \n \
         //s_waitcnt lgkmcnt(0)"
         //: "=v"(reg[3])
         //: "v"(__to_local((void *)(b_loc + 8)))
         //);
         //asm volatile("\n \
         //ds_read2_b64 %0, %4 offset1:1 \n \
         //ds_read2_b64 %1, %4 offset0:32 offset1:33 \n \
         //ds_read2_b64 %2, %5 offset1:1 \n \
         //ds_read2_b64 %3, %5 offset0:16 offset1:17 \n \
         //s_waitcnt lgkmcnt(0)"
         //: "=v"(reg[0]), "=v"(reg[1]), "=v"(reg[2]), "=v"(reg[3])
         //: "v"(__to_local((void *)(a_loc))), "v"(__to_local((void *)(b_loc)))
         //);
         //asm volatile("\n \
         //ds_read_b32 %0, %16 \n \
...
@@ -451,32 +451,31 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         //ds_read_b32 %13, %19 offset:1\n \
         //ds_read_b32 %14, %19 offset:2\n \
         //ds_read_b32 %15, %19 offset:3\n \
         //s_waitcnt lgkmcnt(0)"
         //:
         //"=v"(p_a_thread[0]),
         //"=v"(p_a_thread[1]),
         //"=v"(p_a_thread[2]),
         //"=v"(p_a_thread[3]),
         //"=v"(p_a_thread[4]),
         //"=v"(p_a_thread[5]),
         //"=v"(p_a_thread[6]),
         //"=v"(p_a_thread[7]),
         //"=v"(p_b_thread[0]),
         //"=v"(p_b_thread[1]),
         //"=v"(p_b_thread[2]),
         //"=v"(p_b_thread[3]),
         //"=v"(p_b_thread[4]),
         //"=v"(p_b_thread[5]),
         //"=v"(p_b_thread[6]),
         //"=v"(p_b_thread[7])
         //:
         //"v"(__to_local((void *)(&p_a_block[0]))),
         //"v"(__to_local((void *)(&p_a_block[64]))),
         //"v"(__to_local((void *)(&p_b_block[0]))),
         //"v"(__to_local((void *)(&p_b_block[32])))
         //);

         // C = A * B
         asm volatile("\n \
             v_mac_f32 %0, %64, %72 \n \
...
@@ -544,165 +543,161 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
             v_mac_f32 %62, %71, %78 \n \
             v_mac_f32 %63, %71, %79 \n \
             "
             : "=v"(p_c_thread[0]), "=v"(p_c_thread[1]), "=v"(p_c_thread[2]), "=v"(p_c_thread[3]),
               "=v"(p_c_thread[4]), "=v"(p_c_thread[5]), "=v"(p_c_thread[6]), "=v"(p_c_thread[7]),
               "=v"(p_c_thread[8]), "=v"(p_c_thread[9]), "=v"(p_c_thread[10]), "=v"(p_c_thread[11]),
               "=v"(p_c_thread[12]), "=v"(p_c_thread[13]), "=v"(p_c_thread[14]), "=v"(p_c_thread[15]),
               "=v"(p_c_thread[16]), "=v"(p_c_thread[17]), "=v"(p_c_thread[18]), "=v"(p_c_thread[19]),
               "=v"(p_c_thread[20]), "=v"(p_c_thread[21]), "=v"(p_c_thread[22]), "=v"(p_c_thread[23]),
               "=v"(p_c_thread[24]), "=v"(p_c_thread[25]), "=v"(p_c_thread[26]), "=v"(p_c_thread[27]),
               "=v"(p_c_thread[28]), "=v"(p_c_thread[29]), "=v"(p_c_thread[30]), "=v"(p_c_thread[31]),
               "=v"(p_c_thread[32]), "=v"(p_c_thread[33]), "=v"(p_c_thread[34]), "=v"(p_c_thread[35]),
               "=v"(p_c_thread[36]), "=v"(p_c_thread[37]), "=v"(p_c_thread[38]), "=v"(p_c_thread[39]),
               "=v"(p_c_thread[40]), "=v"(p_c_thread[41]), "=v"(p_c_thread[42]), "=v"(p_c_thread[43]),
               "=v"(p_c_thread[44]), "=v"(p_c_thread[45]), "=v"(p_c_thread[46]), "=v"(p_c_thread[47]),
               "=v"(p_c_thread[48]), "=v"(p_c_thread[49]), "=v"(p_c_thread[50]), "=v"(p_c_thread[51]),
               "=v"(p_c_thread[52]), "=v"(p_c_thread[53]), "=v"(p_c_thread[54]), "=v"(p_c_thread[55]),
               "=v"(p_c_thread[56]), "=v"(p_c_thread[57]), "=v"(p_c_thread[58]), "=v"(p_c_thread[59]),
               "=v"(p_c_thread[60]), "=v"(p_c_thread[61]), "=v"(p_c_thread[62]), "=v"(p_c_thread[63])
             : "v"(p_a_thread[0]), "v"(p_a_thread[1]), "v"(p_a_thread[2]), "v"(p_a_thread[3]),
               "v"(p_a_thread[4]), "v"(p_a_thread[5]), "v"(p_a_thread[6]), "v"(p_a_thread[7]),
               "v"(p_b_thread[0]), "v"(p_b_thread[1]), "v"(p_b_thread[2]), "v"(p_b_thread[3]),
               "v"(p_b_thread[4]), "v"(p_b_thread[5]), "v"(p_b_thread[6]), "v"(p_b_thread[7]),
               "0"(p_c_thread[0]), "1"(p_c_thread[1]), "2"(p_c_thread[2]), "3"(p_c_thread[3]),
               "4"(p_c_thread[4]), "5"(p_c_thread[5]), "6"(p_c_thread[6]), "7"(p_c_thread[7]),
               "8"(p_c_thread[8]), "9"(p_c_thread[9]), "10"(p_c_thread[10]), "11"(p_c_thread[11]),
               "12"(p_c_thread[12]), "13"(p_c_thread[13]), "14"(p_c_thread[14]), "15"(p_c_thread[15]),
               "16"(p_c_thread[16]), "17"(p_c_thread[17]), "18"(p_c_thread[18]), "19"(p_c_thread[19]),
               "20"(p_c_thread[20]), "21"(p_c_thread[21]), "22"(p_c_thread[22]), "23"(p_c_thread[23]),
               "24"(p_c_thread[24]), "25"(p_c_thread[25]), "26"(p_c_thread[26]), "27"(p_c_thread[27]),
               "28"(p_c_thread[28]), "29"(p_c_thread[29]), "30"(p_c_thread[30]), "31"(p_c_thread[31]),
               "32"(p_c_thread[32]), "33"(p_c_thread[33]), "34"(p_c_thread[34]), "35"(p_c_thread[35]),
               "36"(p_c_thread[36]), "37"(p_c_thread[37]), "38"(p_c_thread[38]), "39"(p_c_thread[39]),
               "40"(p_c_thread[40]), "41"(p_c_thread[41]), "42"(p_c_thread[42]), "43"(p_c_thread[43]),
               "44"(p_c_thread[44]), "45"(p_c_thread[45]), "46"(p_c_thread[46]), "47"(p_c_thread[47]),
               "48"(p_c_thread[48]), "49"(p_c_thread[49]), "50"(p_c_thread[50]), "51"(p_c_thread[51]),
               "52"(p_c_thread[52]), "53"(p_c_thread[53]), "54"(p_c_thread[54]), "55"(p_c_thread[55]),
               "56"(p_c_thread[56]), "57"(p_c_thread[57]), "58"(p_c_thread[58]), "59"(p_c_thread[59]),
               "60"(p_c_thread[60]), "61"(p_c_thread[61]), "62"(p_c_thread[62]), "63"(p_c_thread[63])
             );
 #else
         auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
         auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
         auto dst_index   = a_thread_sub_mtx.Get1dIndex(0, 0);

         const float4* a_loc = (const float4*)(p_a_block + a_src_index);
         const float4* b_loc = (const float4*)(p_b_block + b_src_index);
         float4* reg         = (float4*)(p_a_thread + dst_index);

         asm volatile("\n \
             ds_read2_b64 %0, %84 offset1:1 \n \
             ds_read2_b64 %1, %84 offset0:32 offset1:33 \n \
             ds_read2_b64 %2, %85 offset1:1 \n \
...
@@ -773,168 +768,165 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
             v_mac_f32 %66, %75, %82 \n \
             v_mac_f32 %67, %75, %83 \n \
             "
             : "=v"(reg[0]), "=v"(reg[1]), "=v"(reg[2]), "=v"(reg[3]),
               "=v"(p_c_thread[0]), "=v"(p_c_thread[1]), "=v"(p_c_thread[2]), "=v"(p_c_thread[3]),
               "=v"(p_c_thread[4]), "=v"(p_c_thread[5]), "=v"(p_c_thread[6]), "=v"(p_c_thread[7]),
               "=v"(p_c_thread[8]), "=v"(p_c_thread[9]), "=v"(p_c_thread[10]), "=v"(p_c_thread[11]),
               "=v"(p_c_thread[12]), "=v"(p_c_thread[13]), "=v"(p_c_thread[14]), "=v"(p_c_thread[15]),
               "=v"(p_c_thread[16]), "=v"(p_c_thread[17]), "=v"(p_c_thread[18]), "=v"(p_c_thread[19]),
               "=v"(p_c_thread[20]), "=v"(p_c_thread[21]), "=v"(p_c_thread[22]), "=v"(p_c_thread[23]),
               "=v"(p_c_thread[24]), "=v"(p_c_thread[25]), "=v"(p_c_thread[26]), "=v"(p_c_thread[27]),
               "=v"(p_c_thread[28]), "=v"(p_c_thread[29]), "=v"(p_c_thread[30]), "=v"(p_c_thread[31]),
               "=v"(p_c_thread[32]), "=v"(p_c_thread[33]), "=v"(p_c_thread[34]), "=v"(p_c_thread[35]),
               "=v"(p_c_thread[36]), "=v"(p_c_thread[37]), "=v"(p_c_thread[38]), "=v"(p_c_thread[39]),
               "=v"(p_c_thread[40]), "=v"(p_c_thread[41]), "=v"(p_c_thread[42]), "=v"(p_c_thread[43]),
               "=v"(p_c_thread[44]), "=v"(p_c_thread[45]), "=v"(p_c_thread[46]), "=v"(p_c_thread[47]),
               "=v"(p_c_thread[48]), "=v"(p_c_thread[49]), "=v"(p_c_thread[50]), "=v"(p_c_thread[51]),
               "=v"(p_c_thread[52]), "=v"(p_c_thread[53]), "=v"(p_c_thread[54]), "=v"(p_c_thread[55]),
               "=v"(p_c_thread[56]), "=v"(p_c_thread[57]), "=v"(p_c_thread[58]), "=v"(p_c_thread[59]),
               "=v"(p_c_thread[60]), "=v"(p_c_thread[61]), "=v"(p_c_thread[62]), "=v"(p_c_thread[63])
             : "v"(p_a_thread[0]), "v"(p_a_thread[1]), "v"(p_a_thread[2]), "v"(p_a_thread[3]),
               "v"(p_a_thread[4]), "v"(p_a_thread[5]), "v"(p_a_thread[6]), "v"(p_a_thread[7]),
               "v"(p_b_thread[0]), "v"(p_b_thread[1]), "v"(p_b_thread[2]), "v"(p_b_thread[3]),
               "v"(p_b_thread[4]), "v"(p_b_thread[5]), "v"(p_b_thread[6]), "v"(p_b_thread[7]),
               "v"(__to_local((void *)(a_loc))),
               "v"(__to_local((void *)(b_loc))),
               "4"(p_c_thread[0]), "5"(p_c_thread[1]), "6"(p_c_thread[2]), "7"(p_c_thread[3]),
               "8"(p_c_thread[4]), "9"(p_c_thread[5]), "10"(p_c_thread[6]), "11"(p_c_thread[7]),
               "12"(p_c_thread[8]), "13"(p_c_thread[9]), "14"(p_c_thread[10]), "15"(p_c_thread[11]),
               "16"(p_c_thread[12]), "17"(p_c_thread[13]), "18"(p_c_thread[14]), "19"(p_c_thread[15]),
               "20"(p_c_thread[16]), "21"(p_c_thread[17]), "22"(p_c_thread[18]), "23"(p_c_thread[19]),
               "24"(p_c_thread[20]), "25"(p_c_thread[21]), "26"(p_c_thread[22]), "27"(p_c_thread[23]),
               "28"(p_c_thread[24]), "29"(p_c_thread[25]), "30"(p_c_thread[26]), "31"(p_c_thread[27]),
               "32"(p_c_thread[28]), "33"(p_c_thread[29]), "34"(p_c_thread[30]), "35"(p_c_thread[31]),
               "36"(p_c_thread[32]), "37"(p_c_thread[33]), "38"(p_c_thread[34]), "39"(p_c_thread[35]),
               "40"(p_c_thread[36]), "41"(p_c_thread[37]), "42"(p_c_thread[38]), "43"(p_c_thread[39]),
               "44"(p_c_thread[40]), "45"(p_c_thread[41]), "46"(p_c_thread[42]), "47"(p_c_thread[43]),
               "48"(p_c_thread[44]), "49"(p_c_thread[45]), "50"(p_c_thread[46]), "51"(p_c_thread[47]),
               "52"(p_c_thread[48]), "53"(p_c_thread[49]), "54"(p_c_thread[50]), "55"(p_c_thread[51]),
               "56"(p_c_thread[52]), "57"(p_c_thread[53]), "58"(p_c_thread[54]), "59"(p_c_thread[55]),
               "60"(p_c_thread[56]), "61"(p_c_thread[57]), "62"(p_c_thread[58]), "63"(p_c_thread[59]),
               "64"(p_c_thread[60]), "65"(p_c_thread[61]), "66"(p_c_thread[62]), "67"(p_c_thread[63])
             );
 #endif
         }
     }

     template <class FloatA, class FloatB, class FloatC, class Accumulator>
-    __device__ void Run_asm(const FloatA* const __restrict__ p_a_block,
-                            const FloatB* const __restrict__ p_b_block,
-                            FloatC* const __restrict__ p_c_thread,
-                            Accumulator f_accum) const
+    __device__ void Run(const FloatA* const __restrict__ p_a_block,
+                        const FloatB* const __restrict__ p_b_block,
+                        FloatC* const __restrict__ p_c_thread,
+                        Accumulator f_accum) const
     {
         constexpr auto True  = integral_constant<bool, true>{};
         constexpr auto False = integral_constant<bool, false>{};
...
@@ -973,17 +965,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
         constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

-        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && MRepeat == 2 &&
-                          NRepeat == 2 && KPerThreadLoop == 1 && K == 1,
-                      "asm is not for this mtx shape");
-
         const FloatA* const p_a_block_thread_offset = p_a_block + mMyThreadOffsetA;

 #pragma unroll
         // loop over k
         for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
         {
-#if 0
 #pragma unroll
             // copy A-sub to form A
             for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
...
@@ -993,67 +980,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
                         mMyThreadOffsetA,
                     a_thread_mtx,
-                    a_thread_sub_mtx.NCol(
                     p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
                     a_thread_sub_mtx.GetLengths());
             }
-#elif 1
-            // this produce right result
-            using vectorA_t = typename vector_type<FloatA, 4>::MemoryType; // this is float4*
-
-            asm volatile("\n \
-                ds_read_b128 %0, %1 \n \
-                s_waitcnt lgkmcnt(0)"
-                : "=v"(*(reinterpret_cast<vectorA_t*>(
-                      p_a_thread + a_thread_mtx.Get1dIndex(0, 0))))
-                : "v"(__to_local((void*)(p_a_block + a_block_mtx.Get1dIndex(k_begin, 0) +
-                                         mMyThreadOffsetA))));
-
-            asm volatile("\n \
-                ds_read_b128 %0, %1 \n \
-                s_waitcnt lgkmcnt(0)"
-                : "=v"(*(reinterpret_cast<vectorA_t*>(
-                      p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC))))
-                : "v"(__to_local((void*)(p_a_block +
-                                         a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster) +
-                                         mMyThreadOffsetA))));
-#elif 0
-            // this produce wrong result
-            using vectorA_t = typename vector_type<FloatA, 4>::MemoryType; // this is float4*
-
-            asm volatile("\n \
-                ds_read_b128 %0, %2 \n \
-                ds_read_b128 %1, %3 \n \
-                s_waitcnt lgkmcnt(0)"
-                : "=v"(*(reinterpret_cast<vectorA_t*>(
-                      p_a_thread + a_thread_mtx.Get1dIndex(0, 0)))),
-                  "=v"(*(reinterpret_cast<vectorA_t*>(
-                      p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC))))
-                : "v"(__to_local((void*)(p_a_block + a_block_mtx.Get1dIndex(k_begin, 0) +
-                                         mMyThreadOffsetA))),
-                  "v"(__to_local((void*)(p_a_block +
-                                         a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster) +
-                                         mMyThreadOffsetA))));
-#elif 1
-            // this produce wrong result
-            using vectorA_t = typename vector_type<FloatA, 4>::MemoryType; // this is float4*
-
-            asm volatile("\n \
-                ds_read_b128 %0, %1 \n \
-                s_waitcnt lgkmcnt(0)"
-                : "=v"(*(reinterpret_cast<vectorA_t*>(
-                      p_a_thread + a_thread_mtx.Get1dIndex(0, 0))))
-                : "v"(__to_local((void*)(p_a_block_thread_offset))));
-
-            asm volatile("\n \
-                ds_read_b128 %0, %1 offset:16 \n \
-                s_waitcnt lgkmcnt(0)"
-                : "=v"(*(reinterpret_cast<vectorA_t*>(
-                      p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC))))
-                : "v"(__to_local((void*)(p_a_block_thread_offset))));
-#endif
-
-#pragma unroll
+//#pragma unroll
             // copy B-sub to form B
             for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
             {
...
@@ -1066,8 +997,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     b_thread_sub_mtx.GetLengths());
             }

             // C = A * B
-#if 1
             threadwise_gemm(a_thread_mtx,
                             True,
                             p_a_thread,
...
@@ -1078,58 +1008,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                             False,
                             p_c_thread,
                             f_accum);
-#elif 0
-            // inline asm
-            static_assert(c_thread_mtx.NRow() == 8 && c_thread_mtx.NCol() == 8,
-                          "asm is only for 8x8");
-
-            for(index_t k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
-            {
-                const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
-
-                for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
-                {
-                    const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
-                    const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);
-
-                    asm volatile("\n \
-                        v_mac_f32 %0, %8, %9 \n \
-                        v_mac_f32 %1, %8, %10 \n \
-                        v_mac_f32 %2, %8, %11 \n \
-                        v_mac_f32 %3, %8, %12 \n \
-                        v_mac_f32 %4, %8, %13 \n \
-                        v_mac_f32 %5, %8, %14 \n \
-                        v_mac_f32 %6, %8, %15 \n \
-                        v_mac_f32 %7, %8, %16 \n \
-                        "
-                        : "=v"(p_c_thread[cindex + 0]),
-                          "=v"(p_c_thread[cindex + 1]),
-                          "=v"(p_c_thread[cindex + 2]),
-                          "=v"(p_c_thread[cindex + 3]),
-                          "=v"(p_c_thread[cindex + 4]),
-                          "=v"(p_c_thread[cindex + 5]),
-                          "=v"(p_c_thread[cindex + 6]),
-                          "=v"(p_c_thread[cindex + 7])
-                        : "v"(p_a_thread[aindex]),
-                          "v"(p_b_thread[bindex + 0]),
-                          "v"(p_b_thread[bindex + 1]),
-                          "v"(p_b_thread[bindex + 2]),
-                          "v"(p_b_thread[bindex + 3]),
-                          "v"(p_b_thread[bindex + 4]),
-                          "v"(p_b_thread[bindex + 5]),
-                          "v"(p_b_thread[bindex + 6]),
-                          "v"(p_b_thread[bindex + 7]),
-                          "0"(p_c_thread[cindex + 0]),
-                          "1"(p_c_thread[cindex + 1]),
-                          "2"(p_c_thread[cindex + 2]),
-                          "3"(p_c_thread[cindex + 3]),
-                          "4"(p_c_thread[cindex + 4]),
-                          "5"(p_c_thread[cindex + 5]),
-                          "6"(p_c_thread[cindex + 6]),
-                          "7"(p_c_thread[cindex + 7]));
-                }
-            }
-#endif
         }
     }
...
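One detail worth calling out in the long constraint lists above: every accumulator appears twice, once as an output "=v"(p_c_thread[i]) and once as a tied input with the matching digit constraint "i"(p_c_thread[i]), because v_mac_f32 reads its destination register as well as writing it. A minimal sketch of the idiom for a single multiply-accumulate follows; fma_one is a hypothetical helper, device code only, for GCN targets where v_mac_f32 exists:

    #include <hip/hip_runtime.h>

    // c += a * b via one v_mac_f32. The "0" constraint ties the input value of
    // c to the same register as output operand %0, so the instruction sees the
    // old accumulator and the compiler sees the new one.
    __device__ float fma_one(float a, float b, float c)
    {
        asm volatile("v_mac_f32 %0, %1, %2"
                     : "=v"(c)   // output: the accumulator VGPR
                     : "v"(a),
                       "v"(b),
                       "0"(c));  // tied input: same register as %0
        return c;
    }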
src/include/common.hip.hpp

@@ -5,8 +5,6 @@
 #include "Array.hip.hpp"
 #include "functional.hip.hpp"

-extern "C" __attribute__((address_space(3))) void* __to_local(void* p) [[hc]];
-
 __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

 __device__ index_t get_block_1d_id() { return blockIdx.x; }
...
@@ -23,21 +21,45 @@ struct is_same<T, T>
     static const bool value = true;
 };

-#if DEVICE_BACKEND_CUDA
-template <typename T>
-__host__ __device__ constexpr T max(T a, T b)
-{
-    return a > b ? a : b;
-}
-
-template <typename T>
-__host__ __device__ constexpr T min(T a, T b)
-{
-    return a < b ? a : b;
-}
-#endif
-
 __host__ __device__ constexpr index_t integer_divide_ceil(index_t a, index_t b)
 {
     return (a + b - 1) / b;
 }

+namespace mod_conv
+{
+template <class T>
+__host__ __device__ constexpr T max(T x, T y)
+{
+    return x > y ? x : y;
+}
+
+template <class T, class... Ts>
+__host__ __device__ constexpr T max(T x, Ts... xs)
+{
+    static_assert(sizeof...(xs) > 0, "not enough argument");
+
+    auto y = max(xs...);
+
+    static_assert(is_same<decltype(y), T>::value, "not the same type");
+
+    return x > y ? x : y;
+}
+
+template <class T>
+__host__ __device__ constexpr T min(T x, T y)
+{
+    return x < y ? x : y;
+}
+
+template <class T, class... Ts>
+__host__ __device__ constexpr T min(T x, Ts... xs)
+{
+    static_assert(sizeof...(xs) > 0, "not enough argument");
+
+    auto y = min(xs...);
+
+    static_assert(is_same<decltype(y), T>::value, "not the same type");
+
+    return x < y ? x : y;
+}
+} // namespace mod_conv
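The variadic overloads let callers take the maximum or minimum of any number of same-typed compile-time values; the gridwise kernels below use mod_conv::max for max_align. A sketch of how the recursion unwinds, restated outside the HIP qualifiers with the hypothetical name max_of:

    #include <type_traits>

    // Two-argument base case, as in mod_conv::max above.
    template <class T>
    constexpr T max_of(T x, T y)
    {
        return x > y ? x : y;
    }

    // Variadic case: reduce the tail first, then compare with the head.
    template <class T, class... Ts>
    constexpr T max_of(T x, Ts... xs)
    {
        static_assert(sizeof...(xs) > 0, "not enough argument");
        auto y = max_of(xs...);
        static_assert(std::is_same<decltype(y), T>::value, "not the same type");
        return x > y ? x : y;
    }

    static_assert(max_of(4, 2) == 4, "two arguments");
    static_assert(max_of(1, 8, 4, 2) == 8, "recursion unwinds over the tail");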
src/include/gridwise_direct_convolution_1.hip.hpp

@@ -59,12 +59,12 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
     constexpr auto out_block_desc =
         make_ConstantTensorDescriptor(out_block_global_desc.GetLengths());

-    constexpr index_t in_block_size  = in_block_desc.GetElementSpace();
-    constexpr index_t wei_block_size = wei_block_desc.GetElementSpace();
+    constexpr index_t in_block_element_size  = in_block_desc.GetElementSpace();
+    constexpr index_t wei_block_element_size = wei_block_desc.GetElementSpace();
     constexpr index_t out_block_size = out_block_desc.GetElementSpace();

-    __shared__ Float p_in_block[in_block_size];
-    __shared__ Float p_wei_block[wei_block_size];
+    __shared__ Float p_in_block[in_block_element_size];
+    __shared__ Float p_wei_block[wei_block_element_size];
     __shared__ Float p_out_block[out_block_size];

     const index_t block_id = blockIdx.x;
...
src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp

@@ -63,17 +63,18 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
         Sequence<wei_ke_block_desc.GetStride(I0), Y * X, X, 1>{});

     // shared mem
-    constexpr index_t in_block_size =
+    constexpr index_t in_block_element_size =
         in_nchw_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
-    constexpr index_t wei_block_size =
+    constexpr index_t wei_block_element_size =
         wei_kcyx_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

     constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                       ? InBlockCopyDataPerRead
                                       : WeiBlockCopyDataPerRead;

-    __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
-    __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
+    __shared__ Float p_in_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
+    __shared__ Float p_wei_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];

     // threadwise tensors
     constexpr index_t HiPerThread = HoPerThread + Y - 1;
...
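The shared-memory declarations above use the same round-up idiom as the ConstantTensorDescriptor change: each LDS array is padded to a multiple of max_align elements so the array that follows stays aligned for vector loads. A short restatement of the arithmetic with illustrative values (padded and the two read widths here are placeholders, not the kernel's real configuration):

    // Pad element_size up to a multiple of max_align, as in the __shared__
    // array bounds above.
    constexpr unsigned padded(unsigned element_size, unsigned max_align)
    {
        return max_align * ((element_size + max_align - 1) / max_align);
    }

    constexpr unsigned InBlockCopyDataPerRead  = 4; // e.g. float4 loads (illustrative)
    constexpr unsigned WeiBlockCopyDataPerRead = 2; // e.g. float2 loads (illustrative)
    constexpr unsigned max_align =
        InBlockCopyDataPerRead > WeiBlockCopyDataPerRead ? InBlockCopyDataPerRead
                                                         : WeiBlockCopyDataPerRead;

    static_assert(max_align == 4, "larger of the two read widths");
    static_assert(padded(1026, max_align) == 1028, "rounded up to a float4 boundary");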
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp

@@ -73,10 +73,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
         Sequence<wei_ke_vec_block_desc.GetStride(I0), Y * X, X, 1>{});

     // shared mem
-    constexpr index_t in_block_size =
+    constexpr index_t in_block_element_size =
         in_nchw_vec_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
-    constexpr index_t wei_block_size =
+    constexpr index_t wei_block_element_size =
         wei_kcyx_vec_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

     constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
...
@@ -84,9 +84,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
                                       : WeiBlockCopyDataPerRead;

     __shared__ in_vector_mem_t
-        p_in_vec_block[max_align * ((in_block_size + max_align - 1) / max_align)];
+        p_in_vec_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
     __shared__ in_vector_mem_t
-        p_wei_vec_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
+        p_wei_vec_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];

     // threadwise tensors
     constexpr index_t HiPerThread = HoPerThread + Y - 1;
...
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp

@@ -164,18 +164,19 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
         HoPerThread>{};

     // LDS: be careful of alignment
-    constexpr index_t in_block_size =
+    constexpr index_t in_block_element_size =
         in_chwn_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
-    constexpr index_t wei_block_size =
+    constexpr index_t wei_block_element_size =
         wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

     constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                       ? InBlockCopyDataPerRead
                                       : WeiBlockCopyDataPerRead;

-    __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
-    __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
+    __shared__ Float p_in_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
+    __shared__ Float p_wei_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];

     // register
     Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];
...
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp

@@ -204,11 +204,11 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
         true>{};

     // LDS
-    constexpr index_t in_block_size  = in_chwn_block_desc.GetElementSpace();
-    constexpr index_t wei_block_size = wei_cyxk_block_desc.GetElementSpace();
+    constexpr index_t in_block_element_size  = in_chwn_block_desc.GetElementSpace();
+    constexpr index_t wei_block_element_size = wei_cyxk_block_desc.GetElementSpace();

-    __shared__ Float p_in_block[in_block_size];
-    __shared__ Float p_wei_block[wei_block_size];
+    __shared__ Float p_in_block[in_block_element_size];
+    __shared__ Float p_wei_block[wei_block_element_size];

     // register
     Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];
...
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
View file @
e43d7bc6
...
@@ -34,63 +34,109 @@ template <index_t GridSize,
...
@@ -34,63 +34,109 @@ template <index_t GridSize,
index_t
WeiBlockCopyThreadPerDim1
,
index_t
WeiBlockCopyThreadPerDim1
,
index_t
InBlockCopyDataPerRead
,
index_t
InBlockCopyDataPerRead
,
index_t
WeiBlockCopyDataPerRead
>
index_t
WeiBlockCopyDataPerRead
>
__global__
void
class
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
public:
constexpr
auto
I1
=
Number
<
1
>
{};
__host__
__device__
static
index_t
GetSharedMemorySize
()
constexpr
auto
I2
=
Number
<
2
>
{};
{
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_chwn_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_cyxk_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_khwn_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
Hi
=
in_chwn_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Wi
=
in_chwn_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Y
=
wei_cyxk_global_desc
.
GetLength
(
I1
);
constexpr
index_t
X
=
wei_cyxk_global_desc
.
GetLength
(
I2
);
constexpr
index_t
BGhostRead
=
(
Y
-
1
)
*
Wi
+
(
X
-
1
);
// tensor view of blockwise input and weight
// be careful of alignment
constexpr
auto
in_cb_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
BPerBlock
+
BGhostRead
>
{},
Number
<
InBlockCopyDataPerRead
>
{});
constexpr
auto
wei_cyxk_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
Y
,
X
,
KPerBlock
>
{},
Number
<
WeiBlockCopyDataPerRead
>
{});
// tensor view of threadwise output in register
constexpr
auto
out_kb_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
,
BPerThread
>
{});
constexpr
index_t
max_align
=
mod_conv
::
max
(
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
);
// LDS: be careful of alignment
constexpr
index_t
in_block_element_space
=
in_cb_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
constexpr
index_t
wei_block_element_space
=
wei_cyxk_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
return
(
in_block_element_space
+
wei_block_element_space
)
*
sizeof
(
Float
);
}
constexpr
auto
in_chwn_global_desc
=
InGlobalDesc
{};
__global__
static
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
constexpr
auto
wei_cyxk_global_desc
=
WeiGlobalDesc
{};
const
Float
*
const
__restrict__
p_wei_global
,
constexpr
auto
out_khwn_global_desc
=
OutGlobalDesc
{};
Float
*
const
__restrict__
p_out_global
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_chwn_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_cyxk_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_khwn_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
C
=
in_chwn_global_desc
.
GetLength
(
I0
);
constexpr
index_t
C
=
in_chwn_global_desc
.
GetLength
(
I0
);
constexpr
index_t
Hi
=
in_chwn_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Hi
=
in_chwn_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Wi
=
in_chwn_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Wi
=
in_chwn_global_desc
.
GetLength
(
I2
);
constexpr
index_t
N
=
in_chwn_global_desc
.
GetLength
(
I3
);
constexpr
index_t
N
=
in_chwn_global_desc
.
GetLength
(
I3
);
constexpr
index_t
K
=
out_khwn_global_desc
.
GetLength
(
I0
);
constexpr
index_t
K
=
out_khwn_global_desc
.
GetLength
(
I0
);
constexpr
index_t
Ho
=
out_khwn_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Ho
=
out_khwn_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Wo
=
out_khwn_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Wo
=
out_khwn_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Y
=
wei_cyxk_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Y
=
wei_cyxk_global_desc
.
GetLength
(
I1
);
constexpr
index_t
X
=
wei_cyxk_global_desc
.
GetLength
(
I2
);
constexpr
index_t
X
=
wei_cyxk_global_desc
.
GetLength
(
I2
);
constexpr
index_t
B
=
N
*
Hi
*
Wi
;
constexpr
index_t
B
=
N
*
Hi
*
Wi
;
constexpr
index_t
BGhostRead
=
(
Y
-
1
)
*
Wi
+
(
X
-
1
);
constexpr
index_t
BGhostRead
=
(
Y
-
1
)
*
Wi
+
(
X
-
1
);
// divide block work by 2d: [K, B]
// divide block work by 2d: [K, B]
constexpr
index_t
KBlockWork
=
(
K
+
KPerBlock
-
1
)
/
KPerBlock
;
constexpr
index_t
KBlockWork
=
(
K
+
KPerBlock
-
1
)
/
KPerBlock
;
constexpr
index_t
BBlockWork
=
(
B
+
BPerBlock
-
1
)
/
BPerBlock
;
constexpr
index_t
BBlockWork
=
(
B
+
BPerBlock
-
1
)
/
BPerBlock
;
const
index_t
k_block_work_id
=
get_block_1d_id
()
/
BBlockWork
;
const
index_t
k_block_work_id
=
get_block_1d_id
()
/
BBlockWork
;
const
index_t
b_block_work_id
=
get_block_1d_id
()
-
k_block_work_id
*
BBlockWork
;
const
index_t
b_block_work_id
=
get_block_1d_id
()
-
k_block_work_id
*
BBlockWork
;
const
index_t
k_block_data_begin
=
k_block_work_id
*
KPerBlock
;
const
index_t
k_block_data_begin
=
k_block_work_id
*
KPerBlock
;
const
index_t
b_block_data_begin
=
b_block_work_id
*
BPerBlock
;
const
index_t
b_block_data_begin
=
b_block_work_id
*
BPerBlock
;
// flattend (2d) tensor view of gridwise input
// flattend (2d) tensor view of gridwise input
constexpr
auto
in_cb_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
,
B
>
{});
constexpr
auto
in_cb_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
,
B
>
{});
constexpr
auto
wei_ek_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
*
Y
*
X
,
K
>
{});
constexpr
auto
wei_ek_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
*
Y
*
X
,
K
>
{});
// tensor view of blockwise input and weight
// tensor view of blockwise input and weight
// be careful of alignment
// be careful of alignment
constexpr
auto
in_cb_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
in_cb_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
BPerBlock
+
BGhostRead
>
{},
Number
<
InBlockCopyDataPerRead
>
{});
Sequence
<
CPerBlock
,
BPerBlock
+
BGhostRead
>
{},
Number
<
InBlockCopyDataPerRead
>
{});
constexpr
auto
wei_ek_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
wei_ek_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
*
Y
*
X
,
KPerBlock
>
{},
Number
<
WeiBlockCopyDataPerRead
>
{});
Sequence
<
CPerBlock
*
Y
*
X
,
KPerBlock
>
{},
Number
<
WeiBlockCopyDataPerRead
>
{});
constexpr
auto
wei_cyxk_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
wei_cyxk_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
Y
,
X
,
KPerBlock
>
{},
Number
<
WeiBlockCopyDataPerRead
>
{});
Sequence
<
CPerBlock
,
Y
,
X
,
KPerBlock
>
{},
Number
<
WeiBlockCopyDataPerRead
>
{});
// tensor view of threadwise output in register
// tensor view of threadwise output in register
constexpr
auto
out_kb_thread_desc
=
constexpr
auto
out_kb_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
,
BPerThread
>
{});
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
,
BPerThread
>
{});
#if 0
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
...
@@ -121,20 +167,22 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
...
@@ -121,20 +167,22 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
decltype(in_cb_block_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths())>{};
decltype(in_cb_block_desc.GetLengths())>{};
#elif
0
#elif
0
const
auto
blockwise_in_copy
=
Blockwise2dTensorCopy2
<
BlockSize
,
const
auto
blockwise_in_copy
=
Float
,
Blockwise2dTensorCopy2
<
BlockSize
,
decltype
(
in_cb_global_desc
),
Float
,
decltype
(
in_cb_block_desc
),
decltype
(
in_cb_global_desc
),
decltype
(
in_cb_block_desc
.
GetLengths
()),
decltype
(
in_cb_block_desc
),
InBlockCopyThreadPerDim0
,
decltype
(
in_cb_block_desc
.
GetLengths
()),
InBlockCopyThreadPerDim1
>
{};
InBlockCopyThreadPerDim0
,
InBlockCopyThreadPerDim1
>
{};
#elif 1
#elif 1
const
auto
blockwise_in_copy
=
Blockwise2dTensorCopy3
<
BlockSize
,
const
auto
blockwise_in_copy
=
Float
,
Blockwise2dTensorCopy3
<
BlockSize
,
decltype
(
in_cb_global_desc
),
Float
,
decltype
(
in_cb_block_desc
),
decltype
(
in_cb_global_desc
),
decltype
(
in_cb_block_desc
.
GetLengths
()),
decltype
(
in_cb_block_desc
),
InBlockCopyDataPerRead
>
{};
decltype
(
in_cb_block_desc
.
GetLengths
()),
InBlockCopyDataPerRead
>
{};
#endif
#endif
// blockwise wei copy
// blockwise wei copy
...
@@ -147,137 +195,138 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
...
@@ -147,137 +195,138 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths())>{};
decltype(wei_ek_block_desc.GetLengths())>{};
#elif
0
#elif
0
const
auto
blockwise_wei_copy
=
Blockwise2dTensorCopy2
<
BlockSize
,
const
auto
blockwise_wei_copy
=
Float
,
Blockwise2dTensorCopy2
<
BlockSize
,
decltype
(
wei_ek_global_desc
),
Float
,
decltype
(
wei_ek_block_desc
),
decltype
(
wei_ek_global_desc
),
decltype
(
wei_ek_block_desc
.
GetLengths
()),
decltype
(
wei_ek_block_desc
),
WeiBlockCopyThreadPerDim0
,
decltype
(
wei_ek_block_desc
.
GetLengths
()),
WeiBlockCopyThreadPerDim1
>
{};
WeiBlockCopyThreadPerDim0
,
WeiBlockCopyThreadPerDim1
>
{};
#elif 1
#elif 1
const
auto
blockwise_wei_copy
=
Blockwise2dTensorCopy3
<
BlockSize
,
const
auto
blockwise_wei_copy
=
Float
,
Blockwise2dTensorCopy3
<
BlockSize
,
decltype
(
wei_ek_global_desc
),
Float
,
decltype
(
wei_ek_block_desc
),
decltype
(
wei_ek_global_desc
),
decltype
(
wei_ek_block_desc
.
GetLengths
()),
decltype
(
wei_ek_block_desc
),
WeiBlockCopyDataPerRead
>
{};
decltype
(
wei_ek_block_desc
.
GetLengths
()),
WeiBlockCopyDataPerRead
>
{};
#endif
#endif
// a series of blockwise GEMM
// a series of blockwise GEMM
// c_mtx += transpose(a_mtx) * b_mtx
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
// a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K]
// a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K]
// b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
// b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
// c_mtx[K,B] is out_block[K,B]
// c_mtx[K,B] is out_block[K,B]
    constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});

    constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{}, Number<BPerBlock>{}, Number<in_cb_block_desc.GetStride(I0)>{});

    constexpr auto c_kxb_thread_mtx_desc =
        make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<BPerThread>{});

    const auto blockwise_gemm =
        BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
                                                                decltype(a_cxk_block_mtx_desc),
                                                                decltype(b_cxb_block_mtx_desc),
                                                                decltype(c_kxb_thread_mtx_desc),
                                                                GemmMPerThreadSubC,
                                                                GemmNPerThreadSubC,
                                                                GemmMLevel0Cluster,
                                                                GemmNLevel0Cluster,
                                                                GemmMLevel1Cluster,
                                                                GemmNLevel1Cluster,
                                                                GemmKPerThreadLoop>{};
    // LDS: be careful of alignment
-   constexpr index_t in_block_size =
-       in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
-   constexpr index_t wei_block_size =
-       wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
-   constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
-                                     ? InBlockCopyDataPerRead
-                                     : WeiBlockCopyDataPerRead;
-   __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
-   __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
+   constexpr index_t max_align = mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
+   constexpr index_t in_block_element_space =
+       in_cb_block_desc.GetElementSpace(Number<max_align>{});
+   constexpr index_t wei_block_element_space =
+       wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});
+   // LDS
+   __shared__ Float p_in_block[in_block_element_space];
+   __shared__ Float p_wei_block[wei_block_element_space];
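A minimal standalone sketch, not part of the commit, of the round-up the removed lines computed by hand: max_align * ((size + max_align - 1) / max_align) pads a buffer length up to the next multiple of max_align, so a vector read of max_align elements never runs past the buffer; the refactor obtains the same padded size directly from GetElementSpace(Number<max_align>{}).

#include <cstdio>

// round_up(37, 4) == 40; round_up(40, 4) == 40
constexpr unsigned round_up(unsigned size, unsigned align)
{
    return align * ((size + align - 1) / align);
}

int main()
{
    std::printf("%u %u\n", round_up(37, 4), round_up(40, 4)); // prints "40 40"
    return 0;
}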
    const Float* p_in_global_block_offset =
        p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);

    const Float* p_wei_global_block_offset =
        p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);

    // register
    Float p_out_thread[out_kb_thread_desc.GetElementSpace()];

    // set threadwise output tensor to 0
    threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);
    for(index_t c_block_data_begin = 0; c_block_data_begin < C;
        c_block_data_begin += CPerBlock,
                p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
                p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
                __syncthreads())
    {
        // load data
        blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
        blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);

        __syncthreads();

        // compute on current data
        // a series of GEMM
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
-#if 1
+#if 0
                blockwise_gemm.Run
-#elif 1
-                blockwise_gemm.Run_asm
#elif 1
                blockwise_gemm.Run_RegisterDoubleBuffer
+#elif 0
+                blockwise_gemm.Run_asm
#endif
                    (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                     p_in_block + y * Wi + x,
                     p_out_thread,
                     f_accum);
            }
        }
    }
    // output: register to global mem,
    const auto c_thread_mtx_begin =
        blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

    const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
    const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;

    for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
    {
        for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
        {
            const auto c_thread_mtx_distance =
                blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);

            index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
            index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;

            index_t h_data = b_data / (Wi * N);
            index_t itmp   = b_data - h_data * (Wi * N);
            index_t w_data = itmp / N;
            index_t n_data = itmp - w_data * N;

            if(n_data < N && h_data < Ho && w_data < Wo)
            {
                p_out_global[out_khwn_global_desc.Get1dIndex(k_data, h_data, w_data, n_data)] =
                    p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)];
            }
        }
    }
    }
};
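A worked check of the write-back index math above, with illustrative numbers not taken from the commit: the flattened B index encodes b = h * (Wi * N) + w * N + n, and the division/remainder chain recovers (h, w, n).

#include <cassert>

int main()
{
    const int Wi = 4, N = 2;                // assumed tiny sizes for illustration
    const int b_data = 21;
    const int h = b_data / (Wi * N);        // 21 / 8 = 2
    const int itmp = b_data - h * (Wi * N); // 21 - 16 = 5
    const int w = itmp / N;                 // 5 / 2 = 2
    const int n = itmp - w * N;             // 5 - 4 = 1
    assert(h * (Wi * N) + w * N + n == b_data); // recomposes to 21
    return 0;
}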
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
View file @ e43d7bc6
...
@@ -34,67 +34,65 @@ template <index_t GridSize,
          index_t WeiBlockCopyThreadPerDim1,
          index_t InBlockCopyDataPerRead,
          index_t WeiBlockCopyDataPerRead>
-__global__ void
-#if 0
-__launch_bounds__(256,2)
-#endif
-gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
-    const Float* const __restrict__ p_in_global,
-    const Float* const __restrict__ p_wei_global,
-    Float* const __restrict__ p_out_global)
+class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
{
+    public:
+    __global__ static void Run(const Float* const __restrict__ p_in_global,
+                               const Float* const __restrict__ p_wei_global,
+                               Float* const __restrict__ p_out_global)
+    {
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
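This hunk is the heart of the refactor: the free __global__ function becomes a class exposing a static __global__ Run, which is what lets the driver name the kernel via gridwise_conv.Run at the launch site. A minimal sketch of the pattern, under the same hcc/HIP assumption the commit relies on (standard CUDA does not allow __global__ member functions; names below are hypothetical, not the library's):

template <class Float>
class example_gridwise_op
{
    public:
    // static __global__ entry point: the kernel is now selected as a type
    __global__ static void Run(const Float* const __restrict__ p_in,
                               Float* const __restrict__ p_out)
    {
        // body elided; a real kernel would index p_in/p_out per thread
    }
};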
    constexpr auto in_chwn_global_desc  = InGlobalDesc{};
    constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
    constexpr auto out_khwn_global_desc = OutGlobalDesc{};

    constexpr index_t C  = in_chwn_global_desc.GetLength(I0);
    constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
    constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
    constexpr index_t N  = in_chwn_global_desc.GetLength(I3);

    constexpr index_t K  = out_khwn_global_desc.GetLength(I0);
    constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
    constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);

    constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
    constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);

    constexpr index_t B = N * Hi * Wi;

    constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
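The BGhostRead term matches the indexing used later in the GEMM loop: the input block is a flattened [CPerBlock, BPerBlock + BGhostRead] window, and the (y, x) filter tap reads it at offset y * Wi + x, so the extra (Y - 1) * Wi + (X - 1) elements are exactly the overhang the largest tap needs.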
    // divide block work by 2d: [K, B]
    constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
    constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;

    const index_t k_block_work_id = get_block_1d_id() / BBlockWork;
    const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;

    const index_t k_block_data_begin = k_block_work_id * KPerBlock;
    const index_t b_block_data_begin = b_block_work_id * BPerBlock;

    // flattened (2d) tensor view of gridwise input
    constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
    constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});

    // tensor view of blockwise input and weight
    // be careful of alignment
    constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});

    constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});

    constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});

    // tensor view of threadwise output in register
    constexpr auto out_kb_thread_desc =
        make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
#if 0
    if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
...
@@ -125,20 +123,22 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
                               decltype(in_cb_block_desc),
                               decltype(in_cb_block_desc.GetLengths())>{};
#elif 0
    const auto blockwise_in_copy =
        Blockwise2dTensorCopy2<BlockSize,
                               Float,
                               decltype(in_cb_global_desc),
                               decltype(in_cb_block_desc),
                               decltype(in_cb_block_desc.GetLengths()),
                               InBlockCopyThreadPerDim0,
                               InBlockCopyThreadPerDim1>{};
#elif 1
    const auto blockwise_in_copy =
        Blockwise2dTensorCopy3<BlockSize,
                               Float,
                               decltype(in_cb_global_desc),
                               decltype(in_cb_block_desc),
                               decltype(in_cb_block_desc.GetLengths()),
                               InBlockCopyDataPerRead>{};
#endif
    // blockwise wei copy
...
@@ -151,36 +151,38 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
                               decltype(wei_ek_block_desc),
                               decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
    const auto blockwise_wei_copy =
        Blockwise2dTensorCopy2<BlockSize,
                               Float,
                               decltype(wei_ek_global_desc),
                               decltype(wei_ek_block_desc),
                               decltype(wei_ek_block_desc.GetLengths()),
                               WeiBlockCopyThreadPerDim0,
                               WeiBlockCopyThreadPerDim1>{};
#elif 1
    const auto blockwise_wei_copy =
        Blockwise2dTensorCopy3<BlockSize,
                               Float,
                               decltype(wei_ek_global_desc),
                               decltype(wei_ek_block_desc),
                               decltype(wei_ek_block_desc.GetLengths()),
                               WeiBlockCopyDataPerRead>{};
#endif
    // a series of blockwise GEMM
    // c_mtx += transpose(a_mtx) * b_mtx
    // a_mtx and b_mtx saved in LDS, c_mtx saved in register
    // a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K]
    // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
    // c_mtx[K,B] is out_block[K,B]
    constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});

    constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{}, Number<BPerBlock>{}, Number<in_cb_block_desc.GetStride(I0)>{});

    constexpr auto c_kxb_thread_mtx_desc =
        make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<BPerThread>{});

#if 0
    const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC<BlockSize,
...
@@ -195,144 +197,149 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
                                                                 GemmThreadPerRowPerCluster,
                                                                 true>{};
#else
    const auto blockwise_gemm =
        BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
                                                                decltype(a_cxk_block_mtx_desc),
                                                                decltype(b_cxb_block_mtx_desc),
                                                                decltype(c_kxb_thread_mtx_desc),
                                                                GemmMPerThreadSubC,
                                                                GemmNPerThreadSubC,
                                                                GemmMLevel0Cluster,
                                                                GemmNLevel0Cluster,
                                                                GemmMLevel1Cluster,
                                                                GemmNLevel1Cluster,
                                                                GemmKPerThreadLoop>{};
#endif
    // LDS: be careful of alignment
-   constexpr index_t in_block_size =
+   constexpr index_t in_block_element_size =
        in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
-   constexpr index_t wei_block_size =
+   constexpr index_t wei_block_element_size =
        wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
    constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                      ? InBlockCopyDataPerRead
                                      : WeiBlockCopyDataPerRead;

    // LDS double buffer
-   __shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)];
-   __shared__ Float p_wei_block_0[max_align * ((wei_block_size + max_align - 1) / max_align)];
-   __shared__ Float p_in_block_1[max_align * ((in_block_size + max_align - 1) / max_align)];
-   __shared__ Float p_wei_block_1[max_align * ((wei_block_size + max_align - 1) / max_align)];
+   __shared__ Float p_in_block_0[max_align * ((in_block_element_size + max_align - 1) / max_align)];
+   __shared__ Float p_wei_block_0[max_align * ((wei_block_element_size + max_align - 1) / max_align)];
+   __shared__ Float p_in_block_1[max_align * ((in_block_element_size + max_align - 1) / max_align)];
+   __shared__ Float p_wei_block_1[max_align * ((wei_block_element_size + max_align - 1) / max_align)];
    const Float* p_in_global_block_offset =
        p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);

    const Float* p_wei_global_block_offset =
        p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);

    // preload data into LDS
    blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0);
    blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0);

    p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
    p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);

    // register
    Float p_out_thread[out_kb_thread_desc.GetElementSpace()];

    // set threadwise output tensor to 0
    threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);

    bool even_loop = true;
    for(index_t c_block_data_begin = 0; c_block_data_begin + CPerBlock < C;
        c_block_data_begin += CPerBlock,
                p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
                p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
                even_loop = !even_loop)
    {
        Float* p_in_block_now  = even_loop ? p_in_block_0 : p_in_block_1;
        Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1;

        Float* p_in_block_next  = even_loop ? p_in_block_1 : p_in_block_0;
        Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0;

        __syncthreads();

        // load next data
#if 0
        blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
        blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next);
#elif 1
        Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
        Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

        blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                   p_in_register_clipboard);
        blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                    p_wei_register_clipboard);
#endif

        // compute on current data
        // a series of GEMM
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 1
                blockwise_gemm.Run
#else
                blockwise_gemm.Run_RegisterDoubleBuffer
#endif
                    (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                     p_in_block_now + y * Wi + x,
                     p_out_thread,
                     f_accum);
            }
        }

#if 1
        blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next);
        blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next);
#endif
    }
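Restating the control flow of the loop above: sync; stage the global loads for chunk i + 1 in a register "clipboard"; run the Y * X GEMMs against LDS buffer i; spill the clipboard into LDS buffer i + 1; flip even_loop. A host-side sketch of the same ping-pong, not from the commit and with assumed simplified types:

#include <algorithm>
#include <vector>

void double_buffer_demo(const std::vector<float>& global, int chunk, int nchunk)
{
    std::vector<float> lds[2] = {std::vector<float>(chunk), std::vector<float>(chunk)};
    std::vector<float> clipboard(chunk);

    // preload chunk 0 into buffer 0 (mirrors the blockwise_*_copy.Run preload)
    std::copy(global.begin(), global.begin() + chunk, lds[0].begin());

    bool even_loop = true;
    for(int i = 0; i + 1 < nchunk; ++i, even_loop = !even_loop)
    {
        std::vector<float>& now  = lds[even_loop ? 0 : 1];
        std::vector<float>& next = lds[even_loop ? 1 : 0];

        // "load next data": stage chunk i + 1 in the clipboard
        std::copy(global.begin() + (i + 1) * chunk,
                  global.begin() + (i + 2) * chunk,
                  clipboard.begin());

        // "compute on current data" would consume `now` here
        (void)now;

        // "store clipboard" into the buffer the next iteration will read
        next = clipboard;
    }
    // the "last computation" block then consumes the final buffer
}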
    // last computation
    {
        Float* p_in_block_now  = even_loop ? p_in_block_0 : p_in_block_1;
        Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1;

        __syncthreads();

        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 1
                blockwise_gemm.Run
#else
                blockwise_gemm.Run_RegisterDoubleBuffer
#endif
                    (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                     p_in_block_now + y * Wi + x,
                     p_out_thread,
                     f_accum);
            }
        }
    }
    // output: register to global mem,
    const auto c_thread_mtx_begin =
        blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

    const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
    const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;

#if 0
    if(get_block_1d_id() == 0)
...
@@ -348,26 +355,27 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
    }
#endif
    for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
    {
        for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
        {
            const auto c_thread_mtx_distance =
                blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);

            index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
            index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;

            index_t h_data = b_data / (Wi * N);
            index_t itmp   = b_data - h_data * (Wi * N);
            index_t w_data = itmp / N;
            index_t n_data = itmp - w_data * N;

            if(n_data < N && h_data < Ho && w_data < Wo)
            {
                p_out_global[out_khwn_global_desc.Get1dIndex(k_data, h_data, w_data, n_data)] =
                    p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)];
            }
        }
    }
    }
};
src/include/threadwise_gemm.hip.hpp
View file @ e43d7bc6
...
@@ -10,11 +10,9 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
    constexpr auto src_mtx = SrcMatrix{};
    constexpr auto dst_mtx = DstMatrix{};

-#if 1
+#if 0
+    //NRow = 1
    for(index_t i = 0; i < NRow; ++i)
    {
+        //NCol = 4
        for(index_t j = 0; j < NCol; ++j)
        {
            const index_t src_index = src_mtx.Get1dIndex(i, j);
...
@@ -23,7 +21,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
            p_dst[dst_index] = p_src[src_index];
        }
    }
-#elif 0
+#elif 1
    static_assert(NCol == 4, "only for NCol == 4");

    using vector_t = typename vector_type<Float, 4>::MemoryType;
...
@@ -33,22 +31,8 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
        const index_t src_index = src_mtx.Get1dIndex(i, 0);
        const index_t dst_index = dst_mtx.Get1dIndex(i, 0);

-#if 0
-        *(reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
+        *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
            *(reinterpret_cast<const vector_t*>(&p_src[src_index]));
-#elif 0
-        asm volatile("\n \
-        ds_read2_b64 %0, %1 offset1:1 \n \
-        s_waitcnt lgkmcnt(0)"
-                     : "=v"(*(reinterpret_cast<vector_t*>(&p_dst[dst_index])))
-                     : "v"(__to_local((void*)(&p_src[src_index]))));
-#elif 1
-        asm volatile("\n \
-        ds_read_b128 %0, %1 \n \
-        s_waitcnt lgkmcnt(0)"
-                     : "=v"(*(reinterpret_cast<vector_t*>(&p_dst[dst_index])))
-                     : "v"(__to_local((void*)(&p_src[src_index]))));
-#endif
    }
#endif
}
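What survives here is the plain vector copy: four consecutive floats reinterpreted as one 16-byte vector, so a single 128-bit move replaces four scalar ones, which is what the deleted ds_read_b128 inline asm spelled out by hand. A standalone sketch with an assumed float4-like type (the library's vector_type<Float, 4>::MemoryType is not shown in this hunk):

#include <cstring>

struct alignas(16) float4_t { float data[4]; };

// the cast form used above; requires 16-byte-aligned pointers
void copy4(const float* p_src, float* p_dst)
{
    *(reinterpret_cast<float4_t*>(p_dst)) =
        *(reinterpret_cast<const float4_t*>(p_src));
}

// portable spelling of the same 16-byte copy for host code
void copy4_portable(const float* p_src, float* p_dst)
{
    std::memcpy(p_dst, p_src, sizeof(float4_t));
}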
...
@@ -84,13 +68,10 @@ __device__ void threadwise_gemm(MatrixA,
    constexpr index_t N = c_mtx.NCol();

    constexpr index_t K = a_mtx.NRow(); // A is transposed

+    // K = 1
    for(index_t k = 0; k < K; ++k)
    {
+        // M = 8
        for(index_t i = 0; i < M; ++i)
        {
+            // N = 8
            for(index_t j = 0; j < N; ++j)
            {
                const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed
...
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment