Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
b57d60c0
Commit
b57d60c0
authored
Apr 06, 2019
by
Chao Liu
Browse files
refactor
parent
bd0098af
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
386 additions
and
373 deletions
+386
-373
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp
+52
-46
driver/driver.hip.cpp
driver/driver.hip.cpp
+2
-2
src/include/blockwise_batched_gemm.hip.hpp
src/include/blockwise_batched_gemm.hip.hpp
+4
-7
src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp
...dwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp
+328
-0
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp
...idwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp
+0
-318
No files found.
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp
View file @
b57d60c0
#pragma once
#pragma once
#include <unistd.h>
#include <unistd.h>
#include "device.hpp"
#include "device.hpp"
#include "gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
void
device_implicit_gemm_convolution_1_chwn_cyxk_khwn
(
InDesc
,
void
device_implicit_gemm_convolution_1_chwn_cyxk_khwn
(
InDesc
,
...
@@ -260,14 +261,6 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
...
@@ -260,14 +261,6 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
...
@@ -276,6 +269,13 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
...
@@ -276,6 +269,13 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
OutThreadCopyDataPerWrite
=
2
;
constexpr
index_t
OutThreadCopyDataPerWrite
=
2
;
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
...
@@ -289,8 +289,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
...
@@ -289,8 +289,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
float
time
=
launch_kernel
(
constexpr
auto
gridwise_conv
=
g
ridwise
_i
mplicit
_g
emm_
convolution_
1_chwn_cyxk_khwn
<
GridSize
,
G
ridwise
ConvolutionI
mplicit
G
emm_
v
1_chwn_cyxk_khwn
<
GridSize
,
BlockSize
,
BlockSize
,
T
,
T
,
decltype
(
in_chwn_desc
),
decltype
(
in_chwn_desc
),
...
@@ -305,12 +305,6 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
...
@@ -305,12 +305,6 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
KPerThread
,
KPerThread
,
HoPerThread
,
HoPerThread
,
WoPerThread
,
WoPerThread
,
Sequence
<
InBlockCopy_ThreadPerDimC
,
InBlockCopy_ThreadPerDimH
,
InBlockCopy_ThreadPerDimW
,
InBlockCopy_ThreadPerDimN
>
,
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
,
GemmMPerThreadSubC
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmMLevel0Cluster
,
...
@@ -318,14 +312,26 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
...
@@ -318,14 +312,26 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
GemmMLevel1Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmKPerThreadLoop
,
OutThreadCopyDataPerWrite
>
,
Sequence
<
InBlockCopy_ThreadPerDimC
,
InBlockCopy_ThreadPerDimH
,
InBlockCopy_ThreadPerDimW
,
InBlockCopy_ThreadPerDimN
>
,
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
,
OutThreadCopyDataPerWrite
>
{};
float
time
=
launch_kernel
(
run_gridwise_convolution
<
decltype
(
gridwise_conv
),
T
>
,
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
static_cast
<
T
*>
(
in_chwn_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
in_chwn_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_cyxk_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_cyxk_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_khwn_device_buf
.
GetDeviceBuffer
()));
static_cast
<
T
*>
(
out_khwn_device_buf
.
GetDeviceBuffer
()));
printf
(
"Elapsed time : %f ms
\n
"
,
time
);
printf
(
"Elapsed time : %f ms, %f TFlop/s
\n
"
,
time
,
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1024
)
*
1024
*
1024
*
1024
)
/
(
time
/
1000
));
usleep
(
std
::
min
(
time
*
1000
,
float
(
10000
)));
usleep
(
std
::
min
(
time
*
1000
,
float
(
10000
)));
}
}
...
...
driver/driver.hip.cpp
View file @
b57d60c0
...
@@ -661,9 +661,9 @@ int main(int argc, char* argv[])
...
@@ -661,9 +661,9 @@ int main(int argc, char* argv[])
device_direct_convolution_2_nchw_kcyx_nkhw
device_direct_convolution_2_nchw_kcyx_nkhw
#elif 0
#elif 0
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 0
device_implicit_gemm_convolution_1_chwn_cyxk_khwn
#elif 1
#elif 1
device_implicit_gemm_convolution_1_chwn_cyxk_khwn
#elif 0
device_implicit_gemm_convolution_2_chwn_cyxk_khwn
device_implicit_gemm_convolution_2_chwn_cyxk_khwn
#endif
#endif
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
...
...
src/include/blockwise_batched_gemm.hip.hpp
View file @
b57d60c0
...
@@ -164,11 +164,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
...
@@ -164,11 +164,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
n_repeat
*
NPerLevel1Cluster
+
n_in_sub_c
};
n_repeat
*
NPerLevel1Cluster
+
n_in_sub_c
};
}
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
,
class
Accumulator
>
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
Run
(
const
FloatA
*
__restrict__
p_a_block
,
__device__
void
Run
(
const
FloatA
*
__restrict__
p_a_block
,
const
FloatB
*
__restrict__
p_b_block
,
const
FloatB
*
__restrict__
p_b_block
,
FloatC
*
__restrict__
p_c_thread
,
FloatC
*
__restrict__
p_c_thread
)
const
Accumulator
f_accum
)
const
{
{
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
...
@@ -250,8 +249,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
...
@@ -250,8 +249,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
p_b_thread
,
p_b_thread
,
c_thread_mtx
,
c_thread_mtx
,
False
,
False
,
p_c_thread
+
ib
*
ThreadMatrixStrideC
,
p_c_thread
+
ib
*
ThreadMatrixStrideC
);
f_accum
);
// read next batch of a, b
// read next batch of a, b
if
(
BlockMatrixStrideA
!=
0
)
if
(
BlockMatrixStrideA
!=
0
)
...
@@ -296,8 +294,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
...
@@ -296,8 +294,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
p_b_thread
,
p_b_thread
,
c_thread_mtx
,
c_thread_mtx
,
False
,
False
,
p_c_thread
+
(
BatchPerThread
-
1
)
*
ThreadMatrixStrideC
,
p_c_thread
+
(
BatchPerThread
-
1
)
*
ThreadMatrixStrideC
);
f_accum
);
}
}
}
}
...
...
src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp
0 → 100644
View file @
b57d60c0
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"
template
<
index_t
GridSize
,
index_t
BlockSize
,
class
Float
,
class
InGlobalDesc
,
class
WeiGlobalDesc
,
class
OutGlobalDesc
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
CPerBlock
,
index_t
HoPerBlock
,
index_t
WoPerBlock
,
index_t
NPerThread
,
index_t
KPerThread
,
index_t
HoPerThread
,
index_t
WoPerThread
,
index_t
GemmMPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmKPerThreadLoop
,
class
InBlockCopyThreadPerDims
,
index_t
InBlockCopyDataPerRead
,
index_t
WeiBlockCopyDataPerRead
,
index_t
OutThreadCopyDataPerWrite
>
struct
GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
const
{
// NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
// for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
// if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock
static_assert
(
NPerBlock
%
NPerThread
==
0
,
"wrong! NPerBlock % NPerThread !=0"
);
static_assert
((
NPerThread
<
NPerBlock
&&
WoPerThread
==
1
)
||
NPerThread
==
NPerBlock
,
"wrong!"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_chwn_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_cyxk_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_khwn_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
C
=
in_chwn_global_desc
.
GetLength
(
I0
);
constexpr
index_t
K
=
out_khwn_global_desc
.
GetLength
(
I0
);
constexpr
index_t
Ho
=
out_khwn_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Wo
=
out_khwn_global_desc
.
GetLength
(
I2
);
constexpr
index_t
N
=
out_khwn_global_desc
.
GetLength
(
I3
);
constexpr
index_t
Y
=
wei_cyxk_global_desc
.
GetLength
(
I1
);
constexpr
index_t
X
=
wei_cyxk_global_desc
.
GetLength
(
I2
);
constexpr
index_t
HiPerBlock
=
HoPerBlock
+
Y
-
1
;
constexpr
index_t
WiPerBlock
=
WoPerBlock
+
X
-
1
;
// divide block work: [K, Ho, Wo, N]
constexpr
index_t
KBlockWork
=
(
K
+
KPerBlock
-
1
)
/
KPerBlock
;
constexpr
index_t
HBlockWork
=
(
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
;
constexpr
index_t
WBlockWork
=
(
Wo
+
WoPerBlock
-
1
)
/
WoPerBlock
;
constexpr
index_t
NBlockWork
=
(
N
+
NPerBlock
-
1
)
/
NPerBlock
;
const
index_t
k_block_work_id
=
get_block_1d_id
()
/
(
HBlockWork
*
WBlockWork
*
NBlockWork
);
index_t
itmp
=
get_block_1d_id
()
-
k_block_work_id
*
(
HBlockWork
*
WBlockWork
*
NBlockWork
);
const
index_t
h_block_work_id
=
itmp
/
(
WBlockWork
*
NBlockWork
);
itmp
-=
h_block_work_id
*
(
WBlockWork
*
NBlockWork
);
const
index_t
w_block_work_id
=
itmp
/
NBlockWork
;
const
index_t
n_block_work_id
=
itmp
-
w_block_work_id
*
NBlockWork
;
const
index_t
k_block_data_begin
=
k_block_work_id
*
KPerBlock
;
const
index_t
ho_block_data_begin
=
h_block_work_id
*
HoPerBlock
;
const
index_t
wo_block_data_begin
=
w_block_work_id
*
WoPerBlock
;
const
index_t
n_block_data_begin
=
n_block_work_id
*
NPerBlock
;
const
index_t
hi_block_data_begin
=
ho_block_data_begin
;
const
index_t
wi_block_data_begin
=
wo_block_data_begin
;
// flattend (2d) tensor view of gridwise weight
constexpr
auto
wei_ek_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
*
Y
*
X
,
K
>
{});
// tensor view of blockwise input and weight in LDS
// be careful of alignment
constexpr
auto
in_chwn_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
HiPerBlock
,
WiPerBlock
,
NPerBlock
>
{},
Number
<
InBlockCopyDataPerRead
>
{});
constexpr
auto
wei_ek_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
*
Y
*
X
,
KPerBlock
>
{},
Number
<
WeiBlockCopyDataPerRead
>
{});
constexpr
auto
wei_cyxk_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
Y
,
X
,
KPerBlock
>
{},
Number
<
WeiBlockCopyDataPerRead
>
{});
// tensor view of threadwise output in register
constexpr
auto
out_khwn_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
,
HoPerThread
,
WoPerThread
,
NPerThread
>
{});
// blockwise copy
// input: format is [C, Hi, Wi, N]
const
auto
blockwise_in_copy
=
Blockwise4dTensorCopy3
<
BlockSize
,
Float
,
decltype
(
in_chwn_global_desc
),
decltype
(
in_chwn_block_desc
),
decltype
(
in_chwn_block_desc
.
GetLengths
()),
InBlockCopyThreadPerDims
,
InBlockCopyDataPerRead
>
{};
// blockwise wei copy
// format is [CPerBlock*Y*X,KPerBlock]
const
auto
blockwise_wei_copy
=
Blockwise2dTensorCopy3
<
BlockSize
,
Float
,
decltype
(
wei_ek_global_desc
),
decltype
(
wei_ek_block_desc
),
decltype
(
wei_ek_block_desc
.
GetLengths
()),
WeiBlockCopyDataPerRead
>
{};
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
constexpr
auto
a_cxk_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
wei_cyxk_block_desc
.
GetStride
(
I0
)
>
{});
constexpr
auto
b_cxwn_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
WoPerBlock
*
NPerBlock
>
{},
Number
<
in_chwn_block_desc
.
GetStride
(
I0
)
>
{});
constexpr
auto
c_kxwn_thread_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThread
>
{},
Number
<
WoPerThread
*
NPerThread
>
{},
Number
<
out_khwn_thread_desc
.
GetStride
(
I1
)
>
{});
const
auto
blockwise_batch_gemm
=
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
<
BlockSize
,
decltype
(
a_cxk_block_mtx_desc
),
decltype
(
b_cxwn_block_mtx_desc
),
decltype
(
c_kxwn_thread_mtx_desc
),
0
,
in_chwn_block_desc
.
GetStride
(
I1
),
out_khwn_thread_desc
.
GetStride
(
I1
),
HoPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
HoPerThread
>
{};
// LDS: be careful of alignment
constexpr
index_t
in_block_element_size
=
in_chwn_block_desc
.
GetElementSpace
(
Number
<
InBlockCopyDataPerRead
>
{});
constexpr
index_t
wei_block_element_size
=
wei_cyxk_block_desc
.
GetElementSpace
(
Number
<
WeiBlockCopyDataPerRead
>
{});
constexpr
index_t
max_align
=
InBlockCopyDataPerRead
>
WeiBlockCopyDataPerRead
?
InBlockCopyDataPerRead
:
WeiBlockCopyDataPerRead
;
__shared__
Float
p_in_block
[
max_align
*
((
in_block_element_size
+
max_align
-
1
)
/
max_align
)];
__shared__
Float
p_wei_block
[
max_align
*
((
wei_block_element_size
+
max_align
-
1
)
/
max_align
)];
// register
Float
p_out_thread
[
out_khwn_thread_desc
.
GetElementSpace
()];
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero
(
out_khwn_thread_desc
,
p_out_thread
);
const
Float
*
p_in_global_block_begin
=
p_in_global
+
in_chwn_global_desc
.
Get1dIndex
(
0
,
hi_block_data_begin
,
wi_block_data_begin
,
n_block_data_begin
);
const
Float
*
p_wei_global_block_begin
=
p_wei_global
+
wei_cyxk_global_desc
.
Get1dIndex
(
0
,
0
,
0
,
k_block_data_begin
);
for
(
index_t
c_block_data_begin
=
0
;
c_block_data_begin
<
C
;
c_block_data_begin
+=
CPerBlock
,
p_in_global_block_begin
+=
CPerBlock
*
in_chwn_global_desc
.
GetStride
(
I0
),
p_wei_global_block_begin
+=
CPerBlock
*
wei_cyxk_global_desc
.
GetStride
(
I0
),
__syncthreads
())
{
// input: global mem to LDS
blockwise_in_copy
.
Run
(
p_in_global_block_begin
,
p_in_block
);
// weight: global mem to LDS
blockwise_wei_copy
.
Run
(
p_wei_global_block_begin
,
p_wei_block
);
__syncthreads
();
// a series of batched GEMM
for
(
index_t
y
=
0
;
y
<
Y
;
++
y
)
{
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
{
blockwise_batch_gemm
.
Run
(
p_wei_block
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block
+
in_chwn_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_out_thread
);
}
}
}
// output: register to global mem,
#if 0
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
{
for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
{
for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
{
for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
{
const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
const auto c_thread_mtx_distance =
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
const index_t ho_thread =
c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
const index_t wo_thread = b_thread / NPerBlock;
const index_t n_thread = b_thread % NPerBlock;
p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
ho_block_data_begin + ho_thread,
wo_block_data_begin + wo_thread,
n_block_data_begin + n_thread)] =
p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
}
}
}
}
#elif
1
const
auto
c_thread_mtx_begin
=
blockwise_batch_gemm
.
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
const
index_t
k_thread_data_begin
=
c_thread_mtx_begin
.
row
;
const
index_t
ho_thread_data_begin
=
c_thread_mtx_begin
.
batch
;
const
index_t
wo_thread_data_begin
=
c_thread_mtx_begin
.
col
/
NPerBlock
;
const
index_t
n_thread_data_begin
=
c_thread_mtx_begin
.
col
-
NPerBlock
*
wo_thread_data_begin
;
// this is for v2 GEMM
// output is a 8d tensor
if
(
NPerThread
<
NPerBlock
&&
WoPerThread
==
1
)
{
constexpr
index_t
N1_
=
GemmNPerThreadSubC
;
constexpr
index_t
W1_
=
WoPerBlock
/
((
WoPerThread
*
NPerThread
)
/
GemmNPerThreadSubC
);
constexpr
index_t
K2_
=
GemmMPerThreadSubC
;
constexpr
index_t
K1_
=
KPerBlock
/
KPerThread
;
constexpr
auto
out_8d_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
/
(
K1_
*
K2_
),
K1_
,
K2_
,
Ho
,
Wo
/
W1_
,
W1_
,
N
/
N1_
,
N1_
>
{});
constexpr
auto
out_8d_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerBlock
/
(
K1_
*
K2_
),
1
,
K2_
,
HoPerThread
,
WoPerBlock
/
W1_
,
1
,
1
,
N1_
>
{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
print_ConstantTensorDescriptor(out_8d_thread_desc, "out_8d_thread_desc");
print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
print_ConstantTensorDescriptor(out_8d_global_desc, "out_8d_global_desc");
}
#endif
threadwise_8d_tensor_copy
(
out_8d_thread_desc
,
p_out_thread
,
out_8d_global_desc
,
p_out_global
+
out_khwn_global_desc
.
Get1dIndex
(
k_block_data_begin
+
k_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
,
n_block_data_begin
+
n_thread_data_begin
),
out_8d_thread_desc
.
GetLengths
(),
Number
<
OutThreadCopyDataPerWrite
>
{});
}
else
if
(
NPerThread
==
NPerBlock
)
{
// not implemented yet
assert
(
false
);
}
else
{
assert
(
false
);
}
#endif
}
};
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp
deleted
100644 → 0
View file @
bd0098af
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"
template
<
index_t
GridSize
,
index_t
BlockSize
,
class
Float
,
class
InGlobalDesc
,
class
WeiGlobalDesc
,
class
OutGlobalDesc
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
CPerBlock
,
index_t
HoPerBlock
,
index_t
WoPerBlock
,
index_t
NPerThread
,
index_t
KPerThread
,
index_t
HoPerThread
,
index_t
WoPerThread
,
class
InBlockCopyThreadPerDims
,
index_t
InBlockCopyDataPerRead
,
index_t
WeiBlockCopyDataPerRead
,
index_t
GemmMPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmKPerThreadLoop
,
index_t
OutThreadCopyDataPerWrite
>
__global__
void
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
{
// NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
// for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
// if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock
static_assert
(
NPerBlock
%
NPerThread
==
0
,
"wrong! NPerBlock % NPerThread !=0"
);
static_assert
((
NPerThread
<
NPerBlock
&&
WoPerThread
==
1
)
||
NPerThread
==
NPerBlock
,
"wrong!"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_chwn_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_cyxk_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_khwn_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
C
=
in_chwn_global_desc
.
GetLength
(
I0
);
constexpr
index_t
K
=
out_khwn_global_desc
.
GetLength
(
I0
);
constexpr
index_t
Ho
=
out_khwn_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Wo
=
out_khwn_global_desc
.
GetLength
(
I2
);
constexpr
index_t
N
=
out_khwn_global_desc
.
GetLength
(
I3
);
constexpr
index_t
Y
=
wei_cyxk_global_desc
.
GetLength
(
I1
);
constexpr
index_t
X
=
wei_cyxk_global_desc
.
GetLength
(
I2
);
constexpr
index_t
HiPerBlock
=
HoPerBlock
+
Y
-
1
;
constexpr
index_t
WiPerBlock
=
WoPerBlock
+
X
-
1
;
// divide block work: [K, Ho, Wo, N]
constexpr
index_t
KBlockWork
=
(
K
+
KPerBlock
-
1
)
/
KPerBlock
;
constexpr
index_t
HBlockWork
=
(
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
;
constexpr
index_t
WBlockWork
=
(
Wo
+
WoPerBlock
-
1
)
/
WoPerBlock
;
constexpr
index_t
NBlockWork
=
(
N
+
NPerBlock
-
1
)
/
NPerBlock
;
const
index_t
k_block_work_id
=
get_block_1d_id
()
/
(
HBlockWork
*
WBlockWork
*
NBlockWork
);
index_t
itmp
=
get_block_1d_id
()
-
k_block_work_id
*
(
HBlockWork
*
WBlockWork
*
NBlockWork
);
const
index_t
h_block_work_id
=
itmp
/
(
WBlockWork
*
NBlockWork
);
itmp
-=
h_block_work_id
*
(
WBlockWork
*
NBlockWork
);
const
index_t
w_block_work_id
=
itmp
/
NBlockWork
;
const
index_t
n_block_work_id
=
itmp
-
w_block_work_id
*
NBlockWork
;
const
index_t
k_block_data_begin
=
k_block_work_id
*
KPerBlock
;
const
index_t
ho_block_data_begin
=
h_block_work_id
*
HoPerBlock
;
const
index_t
wo_block_data_begin
=
w_block_work_id
*
WoPerBlock
;
const
index_t
n_block_data_begin
=
n_block_work_id
*
NPerBlock
;
const
index_t
hi_block_data_begin
=
ho_block_data_begin
;
const
index_t
wi_block_data_begin
=
wo_block_data_begin
;
// flattend (2d) tensor view of gridwise weight
constexpr
auto
wei_ek_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
*
Y
*
X
,
K
>
{});
// tensor view of blockwise input and weight in LDS
// be careful of alignment
constexpr
auto
in_chwn_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
HiPerBlock
,
WiPerBlock
,
NPerBlock
>
{},
Number
<
InBlockCopyDataPerRead
>
{});
constexpr
auto
wei_ek_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
*
Y
*
X
,
KPerBlock
>
{},
Number
<
WeiBlockCopyDataPerRead
>
{});
constexpr
auto
wei_cyxk_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
Y
,
X
,
KPerBlock
>
{},
Number
<
WeiBlockCopyDataPerRead
>
{});
// tensor view of threadwise output in register
constexpr
auto
out_khwn_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
,
HoPerThread
,
WoPerThread
,
NPerThread
>
{});
// blockwise copy
// input: format is [C, Hi, Wi, N]
const
auto
blockwise_in_copy
=
Blockwise4dTensorCopy3
<
BlockSize
,
Float
,
decltype
(
in_chwn_global_desc
),
decltype
(
in_chwn_block_desc
),
decltype
(
in_chwn_block_desc
.
GetLengths
()),
InBlockCopyThreadPerDims
,
InBlockCopyDataPerRead
>
{};
// blockwise wei copy
// format is [CPerBlock*Y*X,KPerBlock]
const
auto
blockwise_wei_copy
=
Blockwise2dTensorCopy3
<
BlockSize
,
Float
,
decltype
(
wei_ek_global_desc
),
decltype
(
wei_ek_block_desc
),
decltype
(
wei_ek_block_desc
.
GetLengths
()),
WeiBlockCopyDataPerRead
>
{};
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
constexpr
auto
a_cxk_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
wei_cyxk_block_desc
.
GetStride
(
I0
)
>
{});
constexpr
auto
b_cxwn_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
WoPerBlock
*
NPerBlock
>
{},
Number
<
in_chwn_block_desc
.
GetStride
(
I0
)
>
{});
constexpr
auto
c_kxwn_thread_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThread
>
{},
Number
<
WoPerThread
*
NPerThread
>
{},
Number
<
out_khwn_thread_desc
.
GetStride
(
I1
)
>
{});
const
auto
blockwise_batch_gemm
=
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
<
BlockSize
,
decltype
(
a_cxk_block_mtx_desc
),
decltype
(
b_cxwn_block_mtx_desc
),
decltype
(
c_kxwn_thread_mtx_desc
),
0
,
in_chwn_block_desc
.
GetStride
(
I1
),
out_khwn_thread_desc
.
GetStride
(
I1
),
HoPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
HoPerThread
>
{};
// LDS: be careful of alignment
constexpr
index_t
in_block_element_size
=
in_chwn_block_desc
.
GetElementSpace
(
Number
<
InBlockCopyDataPerRead
>
{});
constexpr
index_t
wei_block_element_size
=
wei_cyxk_block_desc
.
GetElementSpace
(
Number
<
WeiBlockCopyDataPerRead
>
{});
constexpr
index_t
max_align
=
InBlockCopyDataPerRead
>
WeiBlockCopyDataPerRead
?
InBlockCopyDataPerRead
:
WeiBlockCopyDataPerRead
;
__shared__
Float
p_in_block
[
max_align
*
((
in_block_element_size
+
max_align
-
1
)
/
max_align
)];
__shared__
Float
p_wei_block
[
max_align
*
((
wei_block_element_size
+
max_align
-
1
)
/
max_align
)];
// register
Float
p_out_thread
[
out_khwn_thread_desc
.
GetElementSpace
()];
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero
(
out_khwn_thread_desc
,
p_out_thread
);
const
Float
*
p_in_global_block_begin
=
p_in_global
+
in_chwn_global_desc
.
Get1dIndex
(
0
,
hi_block_data_begin
,
wi_block_data_begin
,
n_block_data_begin
);
const
Float
*
p_wei_global_block_begin
=
p_wei_global
+
wei_cyxk_global_desc
.
Get1dIndex
(
0
,
0
,
0
,
k_block_data_begin
);
for
(
index_t
c_block_data_begin
=
0
;
c_block_data_begin
<
C
;
c_block_data_begin
+=
CPerBlock
,
p_in_global_block_begin
+=
CPerBlock
*
in_chwn_global_desc
.
GetStride
(
I0
),
p_wei_global_block_begin
+=
CPerBlock
*
wei_cyxk_global_desc
.
GetStride
(
I0
),
__syncthreads
())
{
// input: global mem to LDS
blockwise_in_copy
.
Run
(
p_in_global_block_begin
,
p_in_block
);
// weight: global mem to LDS
blockwise_wei_copy
.
Run
(
p_wei_global_block_begin
,
p_wei_block
);
__syncthreads
();
// a series of batched GEMM
for
(
index_t
y
=
0
;
y
<
Y
;
++
y
)
{
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
{
#if 0
blockwise_batch_gemm.Run
#elif
1
blockwise_batch_gemm
.
Run_v3
#endif
(
p_wei_block
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block
+
in_chwn_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_out_thread
,
[](
auto
&
acc
,
const
auto
&&
v
)
{
acc
+=
v
;
});
}
}
}
// output: register to global mem,
#if 0
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
{
for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
{
for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
{
for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
{
const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
const auto c_thread_mtx_distance =
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
const index_t ho_thread =
c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
const index_t wo_thread = b_thread / NPerBlock;
const index_t n_thread = b_thread % NPerBlock;
p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
ho_block_data_begin + ho_thread,
wo_block_data_begin + wo_thread,
n_block_data_begin + n_thread)] =
p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
}
}
}
}
#elif
1
const
auto
c_thread_mtx_begin
=
blockwise_batch_gemm
.
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
const
index_t
k_thread_data_begin
=
c_thread_mtx_begin
.
row
;
const
index_t
ho_thread_data_begin
=
c_thread_mtx_begin
.
batch
;
const
index_t
wo_thread_data_begin
=
c_thread_mtx_begin
.
col
/
NPerBlock
;
const
index_t
n_thread_data_begin
=
c_thread_mtx_begin
.
col
-
NPerBlock
*
wo_thread_data_begin
;
// this is for v2 GEMM
// output is a 8d tensor
if
(
NPerThread
<
NPerBlock
&&
WoPerThread
==
1
)
{
constexpr
index_t
N1_
=
GemmNPerThreadSubC
;
constexpr
index_t
W1_
=
WoPerBlock
/
((
WoPerThread
*
NPerThread
)
/
GemmNPerThreadSubC
);
constexpr
index_t
K2_
=
GemmMPerThreadSubC
;
constexpr
index_t
K1_
=
KPerBlock
/
KPerThread
;
constexpr
auto
out_8d_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
/
(
K1_
*
K2_
),
K1_
,
K2_
,
Ho
,
Wo
/
W1_
,
W1_
,
N
/
N1_
,
N1_
>
{});
constexpr
auto
out_8d_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerBlock
/
(
K1_
*
K2_
),
1
,
K2_
,
HoPerThread
,
WoPerBlock
/
W1_
,
1
,
1
,
N1_
>
{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
print_ConstantTensorDescriptor(out_8d_thread_desc, "out_8d_thread_desc");
print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
print_ConstantTensorDescriptor(out_8d_global_desc, "out_8d_global_desc");
}
#endif
threadwise_8d_tensor_copy
(
out_8d_thread_desc
,
p_out_thread
,
out_8d_global_desc
,
p_out_global
+
out_khwn_global_desc
.
Get1dIndex
(
k_block_data_begin
+
k_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
,
n_block_data_begin
+
n_thread_data_begin
),
out_8d_thread_desc
.
GetLengths
(),
Number
<
OutThreadCopyDataPerWrite
>
{});
}
else
if
(
NPerThread
==
NPerBlock
)
{
// not implemented yet
assert
(
false
);
}
else
{
assert
(
false
);
}
#endif
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment