Project: gaoqiong / composable_kernel_onnxruntime

Commit e43d7bc6, authored Apr 01, 2019 by Chao Liu

    refactor

Parent: d058d164

Showing 13 changed files with 862 additions and 917 deletions.
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp                               +9   -7
driver/driver.hip.cpp                                                                      +1   -1
src/include/ConstantTensorDescriptor.hip.hpp                                               +4   -1
src/include/blockwise_gemm.hip.hpp                                                         +363 -485
src/include/common.hip.hpp                                                                 +34  -12
src/include/gridwise_direct_convolution_1.hip.hpp                                          +5   -5
src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp                           +5   -4
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp                +4   -4
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp                    +5   -4
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp             +4   -4
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp                    +221 -172
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp  +204 -196
src/include/threadwise_gemm.hip.hpp                                                        +3   -22
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp

@@ -270,7 +270,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        float time = launch_kernel(
+        constexpr auto gridwise_conv =
 #if 1
             gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
 #else
...
@@ -301,7 +301,9 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
                 WeiBlockCopyThreadPerDim0,
                 WeiBlockCopyThreadPerDim1,
                 InBlockCopyDataPerRead,
-                WeiBlockCopyDataPerRead>,
+                WeiBlockCopyDataPerRead>();
+
+        float time = launch_kernel(gridwise_conv.Run,
                                    dim3(GridSize),
                                    dim3(BlockSize),
                                    static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
...
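The change above splits kernel selection from kernel launch: the gridwise convolution is first instantiated as a constexpr object, and its static Run member is then handed to launch_kernel along with the grid/block dimensions and the device buffers. A minimal sketch of such a launch-and-time helper, assuming HIP events; only the call shape launch_kernel(kernel, grid, block, args...) returning a float time comes from the diff, the body is an assumption:

    // Hedged sketch: times one kernel launch with HIP events.
    #include <hip/hip_runtime.h>

    template <class Kernel, class... Args>
    float launch_kernel(Kernel kernel, dim3 grid, dim3 block, Args... args)
    {
        hipEvent_t start, stop;
        hipEventCreate(&start);
        hipEventCreate(&stop);

        hipEventRecord(start, nullptr);
        hipLaunchKernelGGL(kernel, grid, block, 0, nullptr, args...);
        hipEventRecord(stop, nullptr);
        hipEventSynchronize(stop);

        float elapsed_ms = 0;
        hipEventElapsedTime(&elapsed_ms, start, stop);
        hipEventDestroy(start);
        hipEventDestroy(stop);
        return elapsed_ms; // milliseconds for this launch
    }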
driver/driver.hip.cpp

@@ -580,7 +580,7 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
-#elif 0
+#elif 1
     // 1x1 filter, 14x14 image, C = 2048
     constexpr index_t N = 128;
     constexpr index_t C = 2048;
...
src/include/ConstantTensorDescriptor.hip.hpp

@@ -137,7 +137,10 @@ struct ConstantTensorDescriptor
             }
         };

-        return static_const_reduce_n<nDim>{}(GetElementSpace_f{}, add{}) + align.Get();
+        index_t element_space_unaligned =
+            static_const_reduce_n<nDim>{}(GetElementSpace_f{}, add{}) + 1;
+
+        return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
     }

     template <class... Is>
...
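GetElementSpace previously tacked align.Get() onto the reduced element count; it now adds 1 (for the last element) and rounds the total up to a multiple of the alignment. The round-up is the usual ceiling-division identity; a worked example (values are illustrative, not from the repo):

    // align * ((n + align - 1) / align) is the smallest multiple of align >= n,
    // e.g. for n = 70 and align = 4: 4 * ((70 + 3) / 4) = 4 * 18 = 72.
    constexpr unsigned round_up(unsigned n, unsigned align)
    {
        return align * ((n + align - 1) / align);
    }

    static_assert(round_up(70, 4) == 72, "rounds up to the next multiple");
    static_assert(round_up(72, 4) == 72, "already-aligned values are unchanged");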
src/include/blockwise_gemm.hip.hpp

 #pragma once
 #include "threadwise_gemm.hip.hpp"

 extern "C" __attribute__((address_space(3))) void* __to_local(void* p) [[hc]];

 template <index_t BlockSize,
           class BlockMatrixA,
...
@@ -335,7 +335,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
     }

     template <class FloatA, class FloatB, class FloatC, class Accumulator>
-    __device__ void Run(const FloatA* __restrict__ p_a_block,
+    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
                         const FloatB* __restrict__ p_b_block,
                         FloatC* __restrict__ p_c_thread,
                         Accumulator f_accum) const
...
@@ -370,8 +370,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()];

         FloatA* p_a_thread = p_thread;
         FloatB* p_b_thread = p_thread + a_thread_mtx.GetElementSpace();

         constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
         constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
...
@@ -387,9 +387,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
         auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;

         const float4* a_loc = (const float4*)(p_a_block + a_src_index);
         const float4* b_loc = (const float4*)(p_b_block + b_src_index);
         float4* reg         = (float4*)(p_thread);

         reg[0] = a_loc[0];
         reg[1] = a_loc[16];
...
@@ -476,7 +476,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         //"v"(__to_local((void *)(&p_b_block[32])))
         //);

         // C = A * B
         asm volatile("\n \
                     v_mac_f32 %0, %64, %72 \n \
...
@@ -544,8 +543,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     v_mac_f32 %62, %71, %78 \n \
                     v_mac_f32 %63, %71, %79 \n \
                     "
                      : "=v"(p_c_thread[0]),
                        "=v"(p_c_thread[1]),
                        "=v"(p_c_thread[2]),
                        "=v"(p_c_thread[3]),
...
@@ -609,8 +607,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                        "=v"(p_c_thread[61]),
                        "=v"(p_c_thread[62]),
                        "=v"(p_c_thread[63])
                      : "v"(p_a_thread[0]),
                        "v"(p_a_thread[1]),
                        "v"(p_a_thread[2]),
                        "v"(p_a_thread[3]),
...
@@ -689,18 +686,16 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                        "60"(p_c_thread[60]),
                        "61"(p_c_thread[61]),
                        "62"(p_c_thread[62]),
-                       "63"(p_c_thread[63])
-        );
+                       "63"(p_c_thread[63]));
 #else
         auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
         auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;

         auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0);

         const float4* a_loc = (const float4*)(p_a_block + a_src_index);
         const float4* b_loc = (const float4*)(p_b_block + b_src_index);
         float4* reg         = (float4*)(p_a_thread + dst_index);

         asm volatile("\n \
                     ds_read2_b64 %0, %84 offset1:1 \n \
...
@@ -773,8 +768,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     v_mac_f32 %66, %75, %82 \n \
                     v_mac_f32 %67, %75, %83 \n \
                     "
                      : "=v"(reg[0]),
                        "=v"(reg[1]),
                        "=v"(reg[2]),
                        "=v"(reg[3]),
...
@@ -842,8 +836,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                        "=v"(p_c_thread[61]),
                        "=v"(p_c_thread[62]),
                        "=v"(p_c_thread[63])
                      : "v"(p_a_thread[0]),
                        "v"(p_a_thread[1]),
                        "v"(p_a_thread[2]),
                        "v"(p_a_thread[3]),
...
@@ -859,8 +852,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                        "v"(p_b_thread[5]),
                        "v"(p_b_thread[6]),
                        "v"(p_b_thread[7]),
                        "v"(__to_local((void*)(a_loc))),
                        "v"(__to_local((void*)(b_loc))),
                        "4"(p_c_thread[0]),
                        "5"(p_c_thread[1]),
                        "6"(p_c_thread[2]),
...
@@ -924,14 +917,13 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                        "64"(p_c_thread[60]),
                        "65"(p_c_thread[61]),
                        "66"(p_c_thread[62]),
-                       "67"(p_c_thread[63])
-        );
+                       "67"(p_c_thread[63]));
 #endif
         }
     }

     template <class FloatA, class FloatB, class FloatC, class Accumulator>
-    __device__ void Run_asm(const FloatA* const __restrict__ p_a_block,
+    __device__ void Run(const FloatA* const __restrict__ p_a_block,
                             const FloatB* const __restrict__ p_b_block,
                             FloatC* const __restrict__ p_c_thread,
                             Accumulator f_accum) const
...
@@ -973,17 +965,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
         constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

-        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && MRepeat == 2 &&
-                          NRepeat == 2 && KPerThreadLoop == 1 && K == 1,
-                      "asm is not for this mtx shape");
-
         const FloatA* const p_a_block_thread_offset = p_a_block + mMyThreadOffsetA;

 #pragma unroll
         // loop over k
         for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
         {
-#if 0
 #pragma unroll
             // copy A-sub to form A
             for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
...
@@ -993,67 +980,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
                         mMyThreadOffsetA,
                     a_thread_mtx,
-                    a_thread_sub_mtx.NCol(
-                    p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
+                    p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
                     a_thread_sub_mtx.GetLengths());
             }
-#elif 1 // this produce right result
-            using vectorA_t = typename vector_type<FloatA, 4>::MemoryType; // this is float4*
-
-            asm volatile("\n \
-                    ds_read_b128 %0, %1 \n \
-                    s_waitcnt lgkmcnt(0)"
-                         : "=v"(*(reinterpret_cast<vectorA_t*>(
-                               p_a_thread + a_thread_mtx.Get1dIndex(0, 0))))
-                         : "v"(__to_local((void*)(p_a_block +
-                               a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA))));
-
-            asm volatile("\n \
-                    ds_read_b128 %0, %1 \n \
-                    s_waitcnt lgkmcnt(0)"
-                         : "=v"(*(reinterpret_cast<vectorA_t*>(
-                               p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC))))
-                         : "v"(__to_local((void*)(p_a_block +
-                               a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster) + mMyThreadOffsetA))));
-#elif 0
-            // this produce wrong result
-            using vectorA_t = typename vector_type<FloatA, 4>::MemoryType; // this is float4*
-
-            asm volatile("\n \
-                    ds_read_b128 %0, %2 \n \
-                    ds_read_b128 %1, %3 \n \
-                    s_waitcnt lgkmcnt(0)"
-                         : "=v"(*(reinterpret_cast<vectorA_t*>(
-                               p_a_thread + a_thread_mtx.Get1dIndex(0, 0)))),
-                           "=v"(*(reinterpret_cast<vectorA_t*>(
-                               p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC))))
-                         : "v"(__to_local((void*)(p_a_block +
-                               a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA))),
-                           "v"(__to_local((void*)(p_a_block +
-                               a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster) + mMyThreadOffsetA))));
-#elif 1
-            // this produce wrong result
-            using vectorA_t = typename vector_type<FloatA, 4>::MemoryType; // this is float4*
-
-            asm volatile("\n \
-                    ds_read_b128 %0, %1 \n \
-                    s_waitcnt lgkmcnt(0)"
-                         : "=v"(*(reinterpret_cast<vectorA_t*>(
-                               p_a_thread + a_thread_mtx.Get1dIndex(0, 0))))
-                         : "v"(__to_local((void*)(p_a_block_thread_offset))));
-
-            asm volatile("\n \
-                    ds_read_b128 %0, %1 offset:16 \n \
-                    s_waitcnt lgkmcnt(0)"
-                         : "=v"(*(reinterpret_cast<vectorA_t*>(
-                               p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC))))
-                         : "v"(__to_local((void*)(p_a_block_thread_offset))));
-#endif
-            //
 #pragma unroll
             // copy B-sub to form B
             for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
             {
...
@@ -1066,8 +997,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     b_thread_sub_mtx.GetLengths());
             }

             // C = A * B
-#if 1
             threadwise_gemm(a_thread_mtx,
                             True,
                             p_a_thread,
...
@@ -1078,58 +1008,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                             False,
                             p_c_thread,
                             f_accum);
-#elif 0
-            // inline asm
-            static_assert(c_thread_mtx.NRow() == 8 && c_thread_mtx.NCol() == 8,
-                          "asm is only for 8x8");
-
-            for(index_t k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
-            {
-                const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
-
-                for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
-                {
-                    const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
-                    const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);
-
-                    asm volatile("\n \
-                    v_mac_f32 %0, %8, %9 \n \
-                    v_mac_f32 %1, %8, %10 \n \
-                    v_mac_f32 %2, %8, %11 \n \
-                    v_mac_f32 %3, %8, %12 \n \
-                    v_mac_f32 %4, %8, %13 \n \
-                    v_mac_f32 %5, %8, %14 \n \
-                    v_mac_f32 %6, %8, %15 \n \
-                    v_mac_f32 %7, %8, %16 \n \
-                    "
-                                 : "=v"(p_c_thread[cindex + 0]),
-                                   "=v"(p_c_thread[cindex + 1]),
-                                   "=v"(p_c_thread[cindex + 2]),
-                                   "=v"(p_c_thread[cindex + 3]),
-                                   "=v"(p_c_thread[cindex + 4]),
-                                   "=v"(p_c_thread[cindex + 5]),
-                                   "=v"(p_c_thread[cindex + 6]),
-                                   "=v"(p_c_thread[cindex + 7])
-                                 : "v"(p_a_thread[aindex]),
-                                   "v"(p_b_thread[bindex + 0]),
-                                   "v"(p_b_thread[bindex + 1]),
-                                   "v"(p_b_thread[bindex + 2]),
-                                   "v"(p_b_thread[bindex + 3]),
-                                   "v"(p_b_thread[bindex + 4]),
-                                   "v"(p_b_thread[bindex + 5]),
-                                   "v"(p_b_thread[bindex + 6]),
-                                   "v"(p_b_thread[bindex + 7]),
-                                   "0"(p_c_thread[cindex + 0]),
-                                   "1"(p_c_thread[cindex + 1]),
-                                   "2"(p_c_thread[cindex + 2]),
-                                   "3"(p_c_thread[cindex + 3]),
-                                   "4"(p_c_thread[cindex + 4]),
-                                   "5"(p_c_thread[cindex + 5]),
-                                   "6"(p_c_thread[cindex + 6]),
-                                   "7"(p_c_thread[cindex + 7]));
-                }
-            }
-#endif
         }
     }
...
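The net effect in blockwise_gemm.hip.hpp is a name swap plus cleanup: the hand-written GCN-assembly path is demoted from Run to Run_asm, the generic threadwise_gemm-based path becomes the default Run, and several dead ds_read_b128 experiments (two annotated "this produce wrong result") are deleted. All of the v_mac_f32 blocks follow one constraint pattern; a minimal sketch of that idiom, reduced to a single multiply-accumulate (illustrative only, and meaningful only when compiled for an AMD GCN target):

    // c += a * b via one v_mac_f32. "v" pins each operand to a VGPR; the
    // "0" input constraint ties c's incoming value to output operand %0,
    // making the accumulator read-modify-write, as in the blocks above.
    __device__ void mac_one(float& c, float a, float b)
    {
        asm volatile("v_mac_f32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b), "0"(c));
    }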
src/include/common.hip.hpp

@@ -5,8 +5,6 @@
 #include "Array.hip.hpp"
 #include "functional.hip.hpp"

-extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]];
-
 __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

 __device__ index_t get_block_1d_id() { return blockIdx.x; }
...
@@ -23,21 +21,45 @@ struct is_same<T, T>
     static const bool value = true;
 };

-#if DEVICE_BACKEND_CUDA
-template <typename T>
-__host__ __device__ constexpr T max(T a, T b)
-{
-    return a > b ? a : b;
-}
-
-template <typename T>
-__host__ __device__ constexpr T min(T a, T b)
-{
-    return a < b ? a : b;
-}
-#endif
-
 __host__ __device__ constexpr index_t integer_divide_ceil(index_t a, index_t b)
 {
     return (a + b - 1) / b;
 }

+namespace mod_conv {
+
+template <class T>
+__host__ __device__ constexpr T max(T x, T y)
+{
+    return x > y ? x : y;
+}
+
+template <class T, class... Ts>
+__host__ __device__ constexpr T max(T x, Ts... xs)
+{
+    static_assert(sizeof...(xs) > 0, "not enough argument");
+
+    auto y = max(xs...);
+
+    static_assert(is_same<decltype(y), T>::value, "not the same type");
+
+    return x > y ? x : y;
+}
+
+template <class T>
+__host__ __device__ constexpr T min(T x, T y)
+{
+    return x < y ? x : y;
+}
+
+template <class T, class... Ts>
+__host__ __device__ constexpr T min(T x, Ts... xs)
+{
+    static_assert(sizeof...(xs) > 0, "not enough argument");
+
+    auto y = min(xs...);
+
+    static_assert(is_same<decltype(y), T>::value, "not the same type");
+
+    return x < y ? x : y;
+}
+
+} // namespace mod_conv
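The CUDA-only max/min shims are gone; in their place common.hip.hpp gains mod_conv::max and mod_conv::min, constexpr variadic reductions that fold any number of same-typed arguments pairwise (the static_asserts reject empty packs and mixed types). A hedged usage sketch, with illustrative values:

    #include "common.hip.hpp"

    // Folds right to left: max(1, 4, 2) -> max(1, max(4, 2)) -> 4.
    constexpr index_t biggest  = mod_conv::max(index_t(1), index_t(4), index_t(2));
    constexpr index_t smallest = mod_conv::min(index_t(1), index_t(4), index_t(2));

    static_assert(biggest == 4 && smallest == 1, "pairwise constexpr fold");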
src/include/gridwise_direct_convolution_1.hip.hpp

@@ -59,12 +59,12 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
     constexpr auto out_block_desc =
         make_ConstantTensorDescriptor(out_block_global_desc.GetLengths());

-    constexpr index_t in_block_size  = in_block_desc.GetElementSpace();
-    constexpr index_t wei_block_size = wei_block_desc.GetElementSpace();
+    constexpr index_t in_block_element_size  = in_block_desc.GetElementSpace();
+    constexpr index_t wei_block_element_size = wei_block_desc.GetElementSpace();
     constexpr index_t out_block_size = out_block_desc.GetElementSpace();

-    __shared__ Float p_in_block[in_block_size];
-    __shared__ Float p_wei_block[wei_block_size];
+    __shared__ Float p_in_block[in_block_element_size];
+    __shared__ Float p_wei_block[wei_block_element_size];
     __shared__ Float p_out_block[out_block_size];

     const index_t block_id = blockIdx.x;
...
src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp

@@ -63,17 +63,18 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
         Sequence<wei_ke_block_desc.GetStride(I0), Y * X, X, 1>{});

     // shared mem
-    constexpr index_t in_block_size =
+    constexpr index_t in_block_element_size =
         in_nchw_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
-    constexpr index_t wei_block_size =
+    constexpr index_t wei_block_element_size =
         wei_kcyx_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

     constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                       ? InBlockCopyDataPerRead
                                       : WeiBlockCopyDataPerRead;

-    __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
-    __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
+    __shared__ Float p_in_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
+    __shared__ Float p_wei_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];

     // threadwise tensors
     constexpr index_t HiPerThread = HoPerThread + Y - 1;
...
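Here and in the files below, the rename from *_block_size to *_block_element_size makes explicit that these counts are elements, not bytes; the surrounding max_align * ((n + max_align - 1) / max_align) expression then pads each LDS array up to a multiple of the widest vector read, so float2/float4 accesses stay aligned. A hedged numeric sketch (sizes are made up):

    // With 1030 input elements and float4-wide block copies (max_align = 4):
    //   4 * ((1030 + 4 - 1) / 4) = 4 * 258 = 1032 elements,
    // i.e. two elements of padding, so the next __shared__ array begins on a
    // 16-byte boundary and vector reads never straddle it.
    constexpr unsigned padded_elems(unsigned n, unsigned max_align)
    {
        return max_align * ((n + max_align - 1) / max_align);
    }

    static_assert(padded_elems(1030, 4) == 1032, "padded to a float4 boundary");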
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp

@@ -73,10 +73,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
         Sequence<wei_ke_vec_block_desc.GetStride(I0), Y * X, X, 1>{});

     // shared mem
-    constexpr index_t in_block_size =
+    constexpr index_t in_block_element_size =
         in_nchw_vec_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
-    constexpr index_t wei_block_size =
+    constexpr index_t wei_block_element_size =
         wei_kcyx_vec_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

     constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
...
@@ -84,9 +84,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
                                       : WeiBlockCopyDataPerRead;

     __shared__ in_vector_mem_t
-        p_in_vec_block[max_align * ((in_block_size + max_align - 1) / max_align)];
+        p_in_vec_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
     __shared__ in_vector_mem_t
-        p_wei_vec_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
+        p_wei_vec_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];

     // threadwise tensors
     constexpr index_t HiPerThread = HoPerThread + Y - 1;
...
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp

@@ -164,18 +164,19 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
                                       HoPerThread>{};

     // LDS: be careful of alignment
-    constexpr index_t in_block_size =
+    constexpr index_t in_block_element_size =
         in_chwn_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
-    constexpr index_t wei_block_size =
+    constexpr index_t wei_block_element_size =
         wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

     constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                       ? InBlockCopyDataPerRead
                                       : WeiBlockCopyDataPerRead;

-    __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
-    __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
+    __shared__ Float p_in_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
+    __shared__ Float p_wei_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];

     // register
     Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];
...
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp

@@ -204,11 +204,11 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
                                       true>{};

     // LDS
-    constexpr index_t in_block_size  = in_chwn_block_desc.GetElementSpace();
-    constexpr index_t wei_block_size = wei_cyxk_block_desc.GetElementSpace();
+    constexpr index_t in_block_element_size  = in_chwn_block_desc.GetElementSpace();
+    constexpr index_t wei_block_element_size = wei_cyxk_block_desc.GetElementSpace();

-    __shared__ Float p_in_block[in_block_size];
-    __shared__ Float p_wei_block[wei_block_size];
+    __shared__ Float p_in_block[in_block_element_size];
+    __shared__ Float p_wei_block[wei_block_element_size];

     // register
     Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];
...
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp

@@ -34,11 +34,57 @@ template <index_t GridSize,
           index_t WeiBlockCopyThreadPerDim1,
           index_t InBlockCopyDataPerRead,
           index_t WeiBlockCopyDataPerRead>
-__global__ void
-gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restrict__ p_in_global,
+class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
+{
+    public:
+    __host__ __device__ static index_t GetSharedMemorySize()
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        constexpr auto in_chwn_global_desc  = InGlobalDesc{};
+        constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
+        constexpr auto out_khwn_global_desc = OutGlobalDesc{};
+
+        constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
+        constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
+
+        constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
+        constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
+
+        constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
+
+        // tensor view of blockwise input and weight
+        // be careful of alignment
+        constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
+            Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
+
+        constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
+            Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
+
+        // tensor view of threadwise output in register
+        constexpr auto out_kb_thread_desc =
+            make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
+
+        constexpr index_t max_align =
+            mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
+
+        // LDS: be careful of alignment
+        constexpr index_t in_block_element_space =
+            in_cb_block_desc.GetElementSpace(Number<max_align>{});
+        constexpr index_t wei_block_element_space =
+            wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});
+
+        return (in_block_element_space + wei_block_element_space) * sizeof(Float);
+    }
+
+    __global__ static void Run(const Float* const __restrict__ p_in_global,
                                const Float* const __restrict__ p_wei_global,
                                Float* const __restrict__ p_out_global)
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
...
@@ -121,7 +167,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
                                           decltype(in_cb_block_desc),
                                           decltype(in_cb_block_desc.GetLengths())>{};
 #elif 0
         const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
                                                               Float,
                                                               decltype(in_cb_global_desc),
                                                               decltype(in_cb_block_desc),
...
@@ -129,7 +176,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
                                                               InBlockCopyThreadPerDim0,
                                                               InBlockCopyThreadPerDim1>{};
 #elif 1
         const auto blockwise_in_copy = Blockwise2dTensorCopy3<BlockSize,
                                                               Float,
                                                               decltype(in_cb_global_desc),
                                                               decltype(in_cb_block_desc),
...
@@ -147,7 +195,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
                                        decltype(wei_ek_block_desc),
                                        decltype(wei_ek_block_desc.GetLengths())>{};
 #elif 0
         const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
                                                                Float,
                                                                decltype(wei_ek_global_desc),
                                                                decltype(wei_ek_block_desc),
...
@@ -155,7 +204,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
                                                                WeiBlockCopyThreadPerDim0,
                                                                WeiBlockCopyThreadPerDim1>{};
 #elif 1
         const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
                                                                Float,
                                                                decltype(wei_ek_global_desc),
                                                                decltype(wei_ek_block_desc),
...
@@ -192,19 +242,17 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
                                       GemmKPerThreadLoop>{};

         // LDS: be careful of alignment
-        constexpr index_t in_block_size =
-            in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
-        constexpr index_t wei_block_size =
-            wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
-
-        constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
-                                          ? InBlockCopyDataPerRead
-                                          : WeiBlockCopyDataPerRead;
-
-        __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
-        __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
+        constexpr index_t max_align =
+            mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
+
+        constexpr index_t in_block_element_space =
+            in_cb_block_desc.GetElementSpace(Number<max_align>{});
+        constexpr index_t wei_block_element_space =
+            wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});
+
+        // LDS
+        __shared__ Float p_in_block[in_block_element_space];
+        __shared__ Float p_wei_block[wei_block_element_space];

         const Float* p_in_global_block_offset =
             p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);
...
@@ -236,12 +284,12 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
             for(index_t x = 0; x < X; ++x)
             {
                 auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
-#if 1
+#if 0
                 blockwise_gemm.Run
+#elif 1
+                blockwise_gemm.Run_asm
 #elif 1
                 blockwise_gemm.Run_RegisterDoubleBuffer
-#elif 0
-                blockwise_gemm.Run_asm
 #endif
                 (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                  p_in_block + y * Wi + x,
...
@@ -280,4 +328,5 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
             }
         }
     }
 }
+};
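This is the heart of the refactor: the gridwise convolution changes from a free __global__ function into a class whose static Run is the kernel and whose GetSharedMemorySize recomputes the same aligned LDS layout on the host, so the driver can query shared-memory use before launching. A stripped-down sketch of the pattern; the names and the placeholder size computation are assumptions, and it relies on the HIP/hcc toolchain accepting a static __global__ member, which plain CUDA does not:

    template <class Float, index_t BlockSize>
    class gridwise_op
    {
        public:
        // Host- and device-callable query; mirrors Run's LDS layout so the
        // two cannot drift apart silently.
        __host__ __device__ static index_t GetSharedMemorySize()
        {
            return BlockSize * sizeof(Float); // placeholder layout computation
        }

        __global__ static void Run(Float* const __restrict__ p)
        {
            __shared__ Float lds[BlockSize];
            lds[threadIdx.x] = p[threadIdx.x];
            __syncthreads();
            p[threadIdx.x] = lds[BlockSize - 1 - threadIdx.x];
        }
    };

    // usage (host side), mirroring the driver change in this commit:
    //   constexpr auto op = gridwise_op<float, 256>{};
    //   index_t lds_bytes = op.GetSharedMemorySize();
    //   float time = launch_kernel(op.Run, dim3(GridSize), dim3(BlockSize), p_dev);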
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp

@@ -34,15 +34,13 @@ template <index_t GridSize,
           index_t WeiBlockCopyThreadPerDim1,
           index_t InBlockCopyDataPerRead,
           index_t WeiBlockCopyDataPerRead>
-__global__ void
-#if 0
-__launch_bounds__(256,2)
-#endif
-gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
-    const Float* const __restrict__ p_in_global,
+class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
+{
+    public:
+    __global__ static void Run(const Float* const __restrict__ p_in_global,
                                const Float* const __restrict__ p_wei_global,
                                Float* const __restrict__ p_out_global)
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
...
@@ -125,7 +123,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
                                           decltype(in_cb_block_desc),
                                           decltype(in_cb_block_desc.GetLengths())>{};
 #elif 0
         const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
                                                               Float,
                                                               decltype(in_cb_global_desc),
                                                               decltype(in_cb_block_desc),
...
@@ -133,7 +132,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
                                                               InBlockCopyThreadPerDim0,
                                                               InBlockCopyThreadPerDim1>{};
 #elif 1
         const auto blockwise_in_copy = Blockwise2dTensorCopy3<BlockSize,
                                                               Float,
                                                               decltype(in_cb_global_desc),
                                                               decltype(in_cb_block_desc),
...
@@ -151,7 +151,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
                                        decltype(wei_ek_block_desc),
                                        decltype(wei_ek_block_desc.GetLengths())>{};
 #elif 0
         const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
                                                                Float,
                                                                decltype(wei_ek_global_desc),
                                                                decltype(wei_ek_block_desc),
...
@@ -159,7 +160,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
                                                                WeiBlockCopyThreadPerDim0,
                                                                WeiBlockCopyThreadPerDim1>{};
 #elif 1
         const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
                                                                Float,
                                                                decltype(wei_ek_global_desc),
                                                                decltype(wei_ek_block_desc),
...
@@ -210,10 +212,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
 #endif

         // LDS: be careful of alignment
-        constexpr index_t in_block_size =
+        constexpr index_t in_block_element_size =
             in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
-        constexpr index_t wei_block_size =
+        constexpr index_t wei_block_element_size =
             wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

         constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
...
@@ -221,11 +223,15 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
                                           : WeiBlockCopyDataPerRead;

         // LDS double buffer
-        __shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)];
-        __shared__ Float p_wei_block_0[max_align * ((wei_block_size + max_align - 1) / max_align)];
-        __shared__ Float p_in_block_1[max_align * ((in_block_size + max_align - 1) / max_align)];
-        __shared__ Float p_wei_block_1[max_align * ((wei_block_size + max_align - 1) / max_align)];
+        __shared__ Float
+            p_in_block_0[max_align * ((in_block_element_size + max_align - 1) / max_align)];
+        __shared__ Float
+            p_wei_block_0[max_align * ((wei_block_element_size + max_align - 1) / max_align)];
+        __shared__ Float
+            p_in_block_1[max_align * ((in_block_element_size + max_align - 1) / max_align)];
+        __shared__ Float
+            p_wei_block_1[max_align * ((wei_block_element_size + max_align - 1) / max_align)];

         const Float* p_in_global_block_offset =
             p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);
...
@@ -298,7 +304,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
 #if 1
             blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next);
             blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next);
 #endif
         }
...
@@ -370,4 +377,5 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
             }
         }
     }
 }
+};
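The _lds_double_buffer variant keeps two LDS copies of each tile (p_in_block_0/1, p_wei_block_0/1) so the global-to-register-to-LDS staging of tile i+1 can overlap the blockwise GEMM on tile i. A minimal skeleton of that loop structure; load_tile and compute_tile are hypothetical stand-ins for the blockwise copy and GEMM in the real kernel:

    // Skeleton only: stand-in device functions, not the repo's code.
    __device__ void load_tile(float* dst, int t);
    __device__ void compute_tile(const float* src);

    __global__ void double_buffered(int num_tiles)
    {
        __shared__ float buf0[256];
        __shared__ float buf1[256];

        float* p_now  = buf0;
        float* p_next = buf1;

        load_tile(p_now, 0); // prologue: stage the first tile
        __syncthreads();

        for(int t = 0; t + 1 < num_tiles; ++t)
        {
            load_tile(p_next, t + 1); // stage tile t+1 into the idle buffer
            compute_tile(p_now);      // GEMM on tile t from the ready buffer
            __syncthreads();          // everyone finishes before roles swap
            float* tmp = p_now; p_now = p_next; p_next = tmp;
        }

        compute_tile(p_now); // epilogue: the last tile has no successor
    }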
src/include/threadwise_gemm.hip.hpp

@@ -10,11 +10,9 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
     constexpr auto src_mtx = SrcMatrix{};
     constexpr auto dst_mtx = DstMatrix{};

-#if 1
-    // NRow = 1
+#if 0
     for(index_t i = 0; i < NRow; ++i)
     {
-        // NCol = 4
         for(index_t j = 0; j < NCol; ++j)
         {
             const index_t src_index = src_mtx.Get1dIndex(i, j);
...
@@ -23,7 +21,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
             p_dst[dst_index] = p_src[src_index];
         }
     }
-#elif 0
+#elif 1
     static_assert(NCol == 4, "only for NCol == 4");

     using vector_t = typename vector_type<Float, 4>::MemoryType;
...
@@ -33,22 +31,8 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
         const index_t src_index = src_mtx.Get1dIndex(i, 0);
         const index_t dst_index = dst_mtx.Get1dIndex(i, 0);

-#if 0
-        *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
-            *(reinterpret_cast<const vector_t*>(&p_src[src_index]));
-#elif 0
-        asm volatile("\n \
-                    ds_read2_b64 %0, %1 offset1:1 \n \
-                    s_waitcnt lgkmcnt(0)"
-                     : "=v"(*(reinterpret_cast<vector_t*>(&p_dst[dst_index])))
-                     : "v"(__to_local((void*)(&p_src[src_index]))));
-#elif 1
-        asm volatile("\n \
-                    ds_read_b128 %0, %1 \n \
-                    s_waitcnt lgkmcnt(0)"
-                     : "=v"(*(reinterpret_cast<vector_t*>(&p_dst[dst_index])))
-                     : "v"(__to_local((void*)(&p_src[src_index]))));
-#endif
+        *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
+            *(reinterpret_cast<const vector_t*>(&p_src[src_index]));
     }
 #endif
 }
...
@@ -84,13 +68,10 @@ __device__ void threadwise_gemm(MatrixA,
     constexpr index_t N = c_mtx.NCol();
     constexpr index_t K = a_mtx.NRow(); // A is transposed

-    // K = 1
     for(index_t k = 0; k < K; ++k)
     {
-        // M = 8
         for(index_t i = 0; i < M; ++i)
         {
-            // N = 8
             for(index_t j = 0; j < N; ++j)
             {
                 const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed
...
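threadwise_gemm itself is a rank-1-update loop nest: for each k it multiplies row k of the (transposed) A tile with row k of B and accumulates into C; the deleted comments had recorded the shapes of one tuned instance (K = 1, M = N = 8). A plain-C++ reference of the same computation, assuming dense row-major layout and no descriptors:

    // C[M x N] += A[K x M] (A stored transposed) times B[K x N].
    template <int M, int N, int K>
    void threadwise_gemm_ref(const float* a, const float* b, float* c)
    {
        for(int k = 0; k < K; ++k)         // reduction dimension
            for(int i = 0; i < M; ++i)     // rows of C
                for(int j = 0; j < N; ++j) // columns of C
                    c[i * N + j] += a[k * M + i] * b[k * N + j];
    }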