gaoqiong / composable_kernel

Commit 1b323316, authored Feb 06, 2019 by Chao Liu
parent 5e776504

add another blockwise gemm

Showing 14 changed files with 840 additions and 111 deletions (+840, -111)
driver/conv.cu  +7 -4
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh  +2 -25
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh  +210 -0
src/include/blockwise_gemm.cuh  +244 -25
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh  +5 -6
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh  +5 -6
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_pipeline.cuh  +5 -6
src/include/gridwise_implicit_gemm_convolution_1_nchw_kcsr.cuh  +5 -5
src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh  +5 -6
src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh  +7 -7
src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh  +327 -0
src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline.cuh  +6 -7
src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh  +6 -7
src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.cuh  +6 -7
driver/conv.cu

@@ -14,6 +14,7 @@
 #include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh"
 #include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
 #include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh"
+#include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh"
 //#include "device_winograd_convolution.cuh"

 struct GeneratorTensor_1

@@ -391,7 +392,7 @@ int main()
     constexpr unsigned HPad = 0;
     constexpr unsigned WPad = 0;
-#elif 1
+#elif 0
     // 3x3, 34x34
     constexpr unsigned N = 64;
     constexpr unsigned C = 256;

@@ -484,7 +485,7 @@ int main()
     constexpr unsigned HPad = 1;
     constexpr unsigned WPad = 1;
-#elif 0
+#elif 1
     // 1x1 filter, 28x28 image
     constexpr unsigned N = 16;
     constexpr unsigned C = 256;

@@ -591,8 +592,10 @@ int main()
     device_implicit_gemm_convolution_1_chwn_csrk_khwn
 #elif 0
     device_implicit_gemm_convolution_2_cnhw_srck_knhw
-#elif 1
+#elif 0
     device_implicit_gemm_convolution_2_cnhw_csrk_knhw
+#elif 1
+    device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2
 #endif
     (in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat);

@@ -608,7 +611,7 @@ int main()
         nrepeat);
 #endif

-#if 1
+#if 0
     if(S == 3 && R == 3)
     {
         host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh

@@ -67,7 +67,7 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
     Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc));

-#if 1
+#if 0
     // 3x3, 34x34
     constexpr unsigned BPerBlock = 128;
     constexpr unsigned KPerBlock = 64;

@@ -90,31 +90,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
     constexpr unsigned WeiBlockCopyDataPerRead = 4;
     constexpr unsigned BlockSize = 128;
 #elif 0
     // 1x1, 28x28
-    constexpr unsigned BPerBlock = 64;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 8;
-    constexpr unsigned BPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned GemmThreadPerColumnPerCluster = 4;
-    constexpr unsigned GemmThreadPerRowPerCluster = 8;
-    constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned InBlockCopyThreadPerDim1 = 16;
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
-    constexpr unsigned InBlockCopyDataPerRead = 4;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
-    constexpr unsigned BlockSize = 64;
 #elif 1
-    // 1x1, 28x28
+    // 1x1, 28x28 try
     constexpr unsigned BPerBlock = 64;
     constexpr unsigned KPerBlock = 64;
     constexpr unsigned CPerBlock = 8;
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh  (new file, 0 → 100644)

#pragma once
#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh"
#include <unistd.h>

template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InDesc,
                                                              const Tensor<T>& in_nchw,
                                                              WeiDesc,
                                                              const Tensor<T>& wei_kcsr,
                                                              OutDesc,
                                                              Tensor<T>& out_nkhw,
                                                              unsigned nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcsr_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    constexpr unsigned N  = in_nchw_desc.GetLength(I0);
    constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
    constexpr unsigned Wi = in_nchw_desc.GetLength(I3);

    constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
    constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);

    constexpr unsigned K = wei_kcsr_desc.GetLength(I0);
    constexpr unsigned C = wei_kcsr_desc.GetLength(I1);
    constexpr unsigned S = wei_kcsr_desc.GetLength(I2);
    constexpr unsigned R = wei_kcsr_desc.GetLength(I3);

    constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1);

    // convert in_nchw to in_cnhw
    auto in_cnhw_desc = make_ConstantTensorDescriptor(Sequence<C, N, Hi, Wi>{});
    ostream_ConstantTensorDescriptor(in_cnhw_desc, std::cout << "in_cnhw_desc: ");
    Tensor<T> in_cnhw(make_TensorDescriptor(in_cnhw_desc));

    auto f_reorder_nchw2cnhw = [&](auto n, auto c, auto hi, auto wi) {
        in_cnhw(c, n, hi, wi) = in_nchw(n, c, hi, wi);
    };

    make_ParallelTensorFunctor(f_reorder_nchw2cnhw, N, C, Hi, Wi)(
        std::thread::hardware_concurrency());

    // convert wei_kcsr to wei_csrk
    auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence<C, S, R, K>{});
    ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: ");
    Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));

    auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) {
        wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r);
    };

    make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, S, R)(
        std::thread::hardware_concurrency());
    // convert out_nkhw to out_knhw
    auto out_knhw_desc = make_ConstantTensorDescriptor(Sequence<K, N, Ho, Wo>{});
    ostream_ConstantTensorDescriptor(out_knhw_desc, std::cout << "out_knhw_desc: ");
    Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc));
#if 0
    // 1x1, 28x28
    constexpr unsigned BPerBlock = 64;
    constexpr unsigned KPerBlock = 64;
    constexpr unsigned CPerBlock = 8;

    constexpr unsigned BPerThread = 4;
    constexpr unsigned KPerThread = 16;

    constexpr unsigned GemmMPerThreadSubC = 16;
    constexpr unsigned GemmNPerThreadSubC = 4;
    constexpr unsigned GemmMLevel0Cluster = 4;
    constexpr unsigned GemmNLevel0Cluster = 8;
    constexpr unsigned GemmMLevel1Cluster = 1;
    constexpr unsigned GemmNLevel1Cluster = 2;
    constexpr unsigned GemmKPerThreadLoop = 1;
    constexpr unsigned GemmThreadPerColumnPerCluster = 4;
    constexpr unsigned GemmThreadPerRowPerCluster = 8;

    constexpr unsigned InBlockCopyThreadPerDim0 = 4;
    constexpr unsigned InBlockCopyThreadPerDim1 = 16;
    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
    constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;

    constexpr unsigned InBlockCopyDataPerRead  = 4;
    constexpr unsigned WeiBlockCopyDataPerRead = 4;

    constexpr unsigned BlockSize = 64;
#elif 1
    // 1x1, 28x28 try
    constexpr unsigned BPerBlock = 64;
    constexpr unsigned KPerBlock = 64;
    constexpr unsigned CPerBlock = 8;

    constexpr unsigned BPerThread = 8;
    constexpr unsigned KPerThread = 8;

    constexpr unsigned GemmMPerThreadSubC = 4;
    constexpr unsigned GemmNPerThreadSubC = 4;
    constexpr unsigned GemmMLevel0Cluster = 8;
    constexpr unsigned GemmNLevel0Cluster = 2;
    constexpr unsigned GemmMLevel1Cluster = 1;
    constexpr unsigned GemmNLevel1Cluster = 4;
    constexpr unsigned GemmKPerThreadLoop = 1;
    constexpr unsigned GemmThreadPerColumnPerCluster = 8;
    constexpr unsigned GemmThreadPerRowPerCluster = 8;

    constexpr unsigned InBlockCopyThreadPerDim0 = 4;
    constexpr unsigned InBlockCopyThreadPerDim1 = 16;
    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
    constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;

    constexpr unsigned InBlockCopyDataPerRead  = 4;
    constexpr unsigned WeiBlockCopyDataPerRead = 4;

    constexpr unsigned BlockSize = 64;
#endif

    constexpr unsigned GridSize =
        ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

    dim3 block_dim(BlockSize);
    dim3 grid_dim(GridSize);

    printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

    // mem
    std::size_t data_sz = sizeof(T);
    DeviceMem in_cnhw_device_buf(
        data_sz *
        (in_cnhw.mDesc.GetElementSpace() + BGhostRead + BPerBlock)); // reserve extra space for BGhostRead
    DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace());
    DeviceMem out_knhw_device_buf(data_sz * out_knhw.mDesc.GetElementSpace());

    in_cnhw_device_buf.ToDevice(in_cnhw.mData.data());
    wei_csrk_device_buf.ToDevice(wei_csrk.mData.data());
    out_knhw_device_buf.ToDevice(out_knhw.mData.data());

    for(unsigned i = 0; i < nrepeat; ++i)
    {
        cudaEvent_t start, stop;
        float elapsedTime;

        cudaEventCreate(&start);
        cudaEventRecord(start, 0);

        gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2<GridSize,
                                                                   BlockSize,
                                                                   T,
                                                                   decltype(in_cnhw_desc),
                                                                   decltype(wei_csrk_desc),
                                                                   decltype(out_knhw_desc),
                                                                   BPerBlock,
                                                                   KPerBlock,
                                                                   CPerBlock,
                                                                   BPerThread,
                                                                   KPerThread,
                                                                   GemmThreadPerColumnPerCluster,
                                                                   GemmThreadPerRowPerCluster,
                                                                   GemmMPerThreadSubC,
                                                                   GemmNPerThreadSubC,
                                                                   GemmMLevel0Cluster,
                                                                   GemmNLevel0Cluster,
                                                                   GemmMLevel1Cluster,
                                                                   GemmNLevel1Cluster,
                                                                   GemmKPerThreadLoop,
                                                                   InBlockCopyThreadPerDim0,
                                                                   InBlockCopyThreadPerDim1,
                                                                   WeiBlockCopyThreadPerDim0,
                                                                   WeiBlockCopyThreadPerDim1,
                                                                   InBlockCopyDataPerRead,
                                                                   WeiBlockCopyDataPerRead>
            <<<grid_dim, block_dim>>>(in_cnhw_desc,
                                      static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
                                      wei_csrk_desc,
                                      static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer()),
                                      out_knhw_desc,
                                      static_cast<T*>(out_knhw_device_buf.GetDeviceBuffer()));

        cudaEventCreate(&stop);
        cudaEventRecord(stop, 0);
        cudaEventSynchronize(stop);

        cudaEventElapsedTime(&elapsedTime, start, stop);
        printf("Elapsed time : %f ms\n", elapsedTime);

        usleep(std::min(elapsedTime * 1000, float(10000)));
    }

    checkCudaErrors(cudaGetLastError());

    out_knhw_device_buf.FromDevice(out_knhw.mData.data());

    // convert out_knhw to out_nkhw
    auto f_reorder_knhw2nkhw = [&](auto n, auto k, auto ho, auto wo) {
        out_nkhw(n, k, ho, wo) = out_knhw(k, n, ho, wo);
    };

    make_ParallelTensorFunctor(f_reorder_knhw2nkhw, N, K, Ho, Wo)(
        std::thread::hardware_concurrency());
}
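Note on the launch configuration above: the grid is a product of two ceiling divisions, one block per BPerBlock columns of the flattened B = N * Hi * Wi dimension, times one block per KPerBlock output channels. Below is a minimal host-side sketch of that arithmetic using the 1x1/28x28 sizes from conv.cu (N = 16, Hi = Wi = 28, BPerBlock = KPerBlock = 64); K = 128 is an assumed value for illustration only, since the driver does not show K for this case.

    // Sketch of the GridSize arithmetic above; K = 128 is assumed, not from the commit.
    #include <cstdio>

    int main()
    {
        constexpr unsigned N = 16, Hi = 28, Wi = 28; // 1x1 filter, 28x28 image (conv.cu)
        constexpr unsigned K = 128;                  // assumed for illustration
        constexpr unsigned BPerBlock = 64, KPerBlock = 64;

        constexpr unsigned B = N * Hi * Wi;          // flattened pixel dimension
        constexpr unsigned GridSize =
            ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

        std::printf("B = %u, GridSize = %u\n", B, GridSize); // B = 12544, GridSize = 392
        return 0;
    }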
src/include/blockwise_gemm.cuh

@@ -22,9 +22,9 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
     struct MatrixIndex
     {
-        unsigned batch_begin;
-        unsigned row_begin;
-        unsigned col_begin;
+        unsigned batch;
+        unsigned row;
+        unsigned col;
     };

     __device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()

@@ -32,15 +32,15 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
         const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
         const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile

-        const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id());
+        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

-        mMyThreadOffsetA = c_thread_mtx_index.batch_begin * BlockMatrixStrideA +
-                           ((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0)
-                                      : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin));
+        mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
+                           ((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0)
+                                      : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row));

-        mMyThreadOffsetB = c_thread_mtx_index.batch_begin * BlockMatrixStrideB +
-                           ((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin)
-                                      : b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0));
+        mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
+                           ((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col)
+                                      : b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0));

 #if 0
         if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)

@@ -52,16 +52,16 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
             printf("%u %u, %u %u %u, %u %u\n",
                    get_block_1d_id(),
                    get_thread_local_1d_id(),
-                   c_thread_mtx_index.batch_begin,
-                   c_thread_mtx_index.row_begin,
-                   c_thread_mtx_index.col_begin,
+                   c_thread_mtx_index.batch,
+                   c_thread_mtx_index.row,
+                   c_thread_mtx_index.col,
                    mMyThreadOffsetA,
                    mMyThreadOffsetB);
         }
 #endif
     }

-    __device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
+    __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
     {
         if(TransA && (!TransB) && (!TransC))

@@ -237,8 +237,8 @@ struct BlockwiseGemmBlockABlockBThreadC
     struct MatrixIndex
     {
-        unsigned row_begin;
-        unsigned col_begin;
+        unsigned row;
+        unsigned col;
     };

     __device__ BlockwiseGemmBlockABlockBThreadC()

@@ -246,13 +246,13 @@ struct BlockwiseGemmBlockABlockBThreadC
         const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
         const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile

-        const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id());
+        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

-        mMyThreadOffsetA = (!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0)
-                                     : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin);
+        mMyThreadOffsetA = (!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0)
+                                     : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);

-        mMyThreadOffsetB = (!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin)
-                                     : b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0);
+        mMyThreadOffsetB = (!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col)
+                                     : b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0);

 #if 0
         if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)

@@ -264,16 +264,16 @@ struct BlockwiseGemmBlockABlockBThreadC
             printf("%u %u, %u %u %u, %u %u\n",
                    get_block_1d_id(),
                    get_thread_local_1d_id(),
-                   c_thread_mtx_index.batch_begin,
-                   c_thread_mtx_index.row_begin,
-                   c_thread_mtx_index.col_begin,
+                   c_thread_mtx_index.batch,
+                   c_thread_mtx_index.row,
+                   c_thread_mtx_index.col,
                    mMyThreadOffsetA,
                    mMyThreadOffsetB);
         }
 #endif
     }

-    __device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
+    __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
     {
         if(TransA && (!TransB) && (!TransC))

@@ -359,6 +359,13 @@ struct BlockwiseGemmBlockABlockBThreadC
         }
     }

+    // this should be optimized away if input is known
+    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
+                                                                      unsigned n_in_c)
+    {
+        return MatrixIndex{m_in_c, n_in_c};
+    }
+
     template <class FloatA, class FloatB, class FloatC, class Accumulator>
     __device__ void Run(FloatA* const p_a_block,
                         FloatB* const p_b_block,

@@ -420,3 +427,215 @@ struct BlockwiseGemmBlockABlockBThreadC
         }
     }
 };

The remainder of the hunk is the new blockwise GEMM, all added lines:
// if the following numbers are powers of 2, index calculation shall be greatly reduced:
//   MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
template <unsigned BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
          unsigned MPerThreadSubC,
          unsigned NPerThreadSubC,
          unsigned MLevel0Cluster,
          unsigned NLevel0Cluster,
          unsigned MLevel1Cluster,
          unsigned NLevel1Cluster,
          unsigned KPerThreadLoop>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
    struct MatrixIndex
    {
        unsigned row;
        unsigned col;
    };

    unsigned mMyThreadOffsetA;
    unsigned mMyThreadOffsetB;

    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
    {
        constexpr unsigned ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");

        const auto a_block_mtx  = BlockMatrixA{};  // constexpr doesn't compile
        const auto b_block_mtx  = BlockMatrixB{};  // constexpr doesn't compile
        const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile

        static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
        constexpr unsigned N = b_block_mtx.NCol();
        constexpr unsigned K = a_block_mtx.NRow();

        constexpr unsigned MPerThread = c_thread_mtx.NRow();
        constexpr unsigned NPerThread = c_thread_mtx.NCol();

        static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
                      "wrong! Cannot evenly divide thread work among repeat\n");

        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;

        static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
                      "wrong! Cannot evenly divide work among repeat\n");

        constexpr unsigned MPerLevel1Cluster = M / MRepeat;
        constexpr unsigned NPerLevel1Cluster = N / NRepeat;

        static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
                          (NPerLevel1Cluster % NLevel1Cluster == 0),
                      "wrong! Cannot evenly divide work among Level1Cluster\n");

        constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
        constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;

        static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
                          (NPerLevel0Cluster % NLevel0Cluster == 0),
                      "wrong! Cannot evenly divide work among Level0Cluster\n");

        static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
                          (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
                      "wrong! thread work size is wrong\n");

        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);
        mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
    }

    __device__ static MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id)
    {
        constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;

        unsigned level1_id   = thread_id / ThreadPerLevel0Cluster;
        unsigned level1_m_id = level1_id / NLevel1Cluster;
        unsigned level1_n_id = level1_id % NLevel1Cluster;

        unsigned level0_id   = thread_id % ThreadPerLevel0Cluster;
        unsigned level0_m_id = level0_id / NLevel0Cluster;
        unsigned level0_n_id = level0_id % NLevel0Cluster;

        constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
        constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

        return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

    // this should be optimized away if input is known
    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
                                                                      unsigned n_in_c)
    {
        const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile

        constexpr unsigned MPerThread = c_thread_mtx.NRow();
        constexpr unsigned NPerThread = c_thread_mtx.NCol();

        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;

        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        unsigned m_repeat = m_in_c / MPerThreadSubC;
        unsigned n_repeat = n_in_c / NPerThreadSubC;

        unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
        unsigned n_in_sub_c = n_in_c % NPerThreadSubC;

        return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }

    template <class FloatA, class FloatB, class FloatC, class Accumulator>
    __device__ void Run(FloatA* const p_a_block,
                        FloatB* const p_b_block,
                        FloatC* p_c_thread,
                        Accumulator f_accum) const
    {
        constexpr auto True  = Constant<bool, true>{};
        constexpr auto False = Constant<bool, false>{};

        const auto a_block_mtx  = BlockMatrixA{};  // constexpr doesn't compile
        const auto b_block_mtx  = BlockMatrixB{};  // constexpr doesn't compile
        const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile

        constexpr unsigned M = a_block_mtx.NCol();
        constexpr unsigned N = b_block_mtx.NCol();
        constexpr unsigned K = a_block_mtx.NRow();

        constexpr unsigned MPerThread = c_thread_mtx.NRow();
        constexpr unsigned NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        const auto a_thread_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThread>{}); // constexpr doesn't compile
        const auto b_thread_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThread>{}); // constexpr doesn't compile

        // thread A-sub, B-sub for copy
        const auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{},
            Number<MPerThreadSubC>{},
            Number<MPerThread>{}); // constexpr doesn't compile
        const auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{},
            Number<NPerThreadSubC>{},
            Number<NPerThread>{}); // constexpr doesn't compile

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;

        // loop over k
        for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
        {
            // copy A-sub to form A
            for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
            {
                threadwise_matrix_copy(a_block_mtx,
                                       p_a_block + mMyThreadOffsetA +
                                           k_begin * a_block_mtx.RowStride() +
                                           m_repeat * MPerLevel1Cluster,
                                       a_thread_sub_mtx,
                                       p_a_thread + m_repeat * MPerThreadSubC,
                                       a_thread_sub_mtx.GetLengths());
            }

            // copy B-sub to form B
            for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
                threadwise_matrix_copy(b_block_mtx,
                                       p_b_block + mMyThreadOffsetB +
                                           k_begin * b_block_mtx.RowStride() +
                                           n_repeat * NPerLevel1Cluster,
                                       b_thread_sub_mtx,
                                       p_b_thread + n_repeat * NPerThreadSubC,
                                       b_thread_sub_mtx.GetLengths());
            }

            // C = A * B
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread,
                            b_thread_mtx,
                            False,
                            p_b_thread,
                            c_thread_mtx,
                            False,
                            p_c_thread,
                            f_accum);
        }
    }
};
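The v2 thread-to-tile mapping is easiest to see with concrete numbers. Below is a host-side sketch of GetBeginOfThreadMatrixC using the "1x1, 28x28 try" configuration from the new driver file (GemmMPerThreadSubC = GemmNPerThreadSubC = 4, GemmMLevel0Cluster = 8, GemmNLevel0Cluster = 2, GemmMLevel1Cluster = 1, GemmNLevel1Cluster = 4, so BlockSize = 8 * 2 * 1 * 4 = 64). It only mirrors the device code above and is not part of the commit.

    // Host-side replica of BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2::
    // GetBeginOfThreadMatrixC for the "1x1, 28x28 try" parameters.
    #include <cstdio>

    constexpr unsigned MPerThreadSubC = 4, NPerThreadSubC = 4;
    constexpr unsigned MLevel0Cluster = 8, NLevel0Cluster = 2;
    constexpr unsigned MLevel1Cluster = 1, NLevel1Cluster = 4;

    struct MatrixIndex { unsigned row, col; };

    MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id)
    {
        constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster; // 16

        unsigned level1_id   = thread_id / ThreadPerLevel0Cluster;
        unsigned level1_m_id = level1_id / NLevel1Cluster;
        unsigned level1_n_id = level1_id % NLevel1Cluster;

        unsigned level0_id   = thread_id % ThreadPerLevel0Cluster;
        unsigned level0_m_id = level0_id / NLevel0Cluster;
        unsigned level0_n_id = level0_id % NLevel0Cluster;

        constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster; // 32
        constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster; // 8

        return {level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

    int main()
    {
        // The 64 threads tile a 32x32 corner of the block's C tile in 4x4 sub-tiles;
        // with MRepeat = NRepeat = 2 each thread revisits it at +32 row/col offsets.
        const unsigned ids[] = {0, 1, 15, 16, 63};
        for(unsigned t : ids)
        {
            MatrixIndex idx = GetBeginOfThreadMatrixC(t);
            std::printf("thread %2u -> row %2u, col %2u\n", t, idx.row, idx.col);
        }
        return 0;
    }

Thread 0 lands at (0, 0), thread 15 at (28, 4), thread 63 at (28, 28): the block covers rows and columns 0..31 in MPerThreadSubC x NPerThreadSubC pieces, which is why the constructor asserts BlockSize == ThreadPerLevel1Cluster.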
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh

@@ -208,13 +208,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
     }

     const auto matrix_c_index =
-        blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
+        blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

-    const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
-    const unsigned k_thread_data_begin  = matrix_c_index.row_begin;
-    const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
-    const unsigned n_thread_data_begin =
-        matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
+    const unsigned ho_thread_data_begin = matrix_c_index.batch;
+    const unsigned k_thread_data_begin  = matrix_c_index.row;
+    const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
+    const unsigned n_thread_data_begin  = matrix_c_index.col - wo_thread_data_begin * NPerBlock;

 #if 0
     printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
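The col / NPerBlock, col - wo * NPerBlock pair above is a plain div/mod decode of the GEMM column back into (wo, n). A one-line check, with NPerBlock = 32 chosen purely for illustration; the same decode is repeated in the _padded, _padded_lds_pipeline, nchw_kcsr (with NPerThread), and nchw_srck_nkhw hunks below.

    // Decode a GEMM column index into (wo, n); NPerBlock = 32 is an assumed value.
    #include <cstdio>

    int main()
    {
        constexpr unsigned NPerBlock = 32;   // illustrative only
        unsigned col = 70;
        unsigned wo  = col / NPerBlock;      // 2
        unsigned n   = col - wo * NPerBlock; // 6, since col = wo * NPerBlock + n
        std::printf("col %u -> wo %u, n %u\n", col, wo, n);
        return 0;
    }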
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh

@@ -262,13 +262,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restrict__
     }

     const auto matrix_c_index =
-        blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
+        blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

-    const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
-    const unsigned k_thread_data_begin  = matrix_c_index.row_begin;
-    const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
-    const unsigned n_thread_data_begin =
-        matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
+    const unsigned ho_thread_data_begin = matrix_c_index.batch;
+    const unsigned k_thread_data_begin  = matrix_c_index.row;
+    const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
+    const unsigned n_thread_data_begin  = matrix_c_index.col - wo_thread_data_begin * NPerBlock;

 #if 0
     printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_pipeline.cuh

@@ -318,13 +318,12 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_pipeline
     }

     const auto matrix_c_index =
-        blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
+        blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

-    const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
-    const unsigned k_thread_data_begin  = matrix_c_index.row_begin;
-    const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
-    const unsigned n_thread_data_begin =
-        matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
+    const unsigned ho_thread_data_begin = matrix_c_index.batch;
+    const unsigned k_thread_data_begin  = matrix_c_index.row;
+    const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
+    const unsigned n_thread_data_begin  = matrix_c_index.col - wo_thread_data_begin * NPerBlock;

 #if 0
     printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
src/include/gridwise_implicit_gemm_convolution_1_nchw_kcsr.cuh

@@ -228,15 +228,15 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr(InGlobalDesc,
     }

     const auto matrix_c_index =
-        blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
+        blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

 #if 0
-    printf("%u %u, %u %u %u\n",get_block_1d_id(), get_thread_local_1d_id(), matrix_c_index.batch_begin, matrix_c_index.row_begin, matrix_c_index.col_begin);
+    printf("%u %u, %u %u %u\n",get_block_1d_id(), get_thread_local_1d_id(), matrix_c_index.batch, matrix_c_index.row, matrix_c_index.col);
 #endif

-    const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
-    const unsigned k_thread_data_begin  = matrix_c_index.row_begin;
-    const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread;
+    const unsigned ho_thread_data_begin = matrix_c_index.batch;
+    const unsigned k_thread_data_begin  = matrix_c_index.row;
+    const unsigned wo_thread_data_begin = matrix_c_index.col / NPerThread;

 #if 1
     // output: register to global mem,
src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh

@@ -205,13 +205,12 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
     }

     const auto matrix_c_index =
-        blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
+        blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

-    const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
-    const unsigned k_thread_data_begin  = matrix_c_index.row_begin;
-    const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
-    const unsigned n_thread_data_begin =
-        matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
+    const unsigned ho_thread_data_begin = matrix_c_index.batch;
+    const unsigned k_thread_data_begin  = matrix_c_index.row;
+    const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
+    const unsigned n_thread_data_begin  = matrix_c_index.col - wo_thread_data_begin * NPerBlock;

     // output: register to global mem,
     // convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh

@@ -75,6 +75,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
     constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * S * R, K>{});

     // tensor view of blockwise input and weight
+    // be careful of alignment
     constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
         Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});

@@ -245,11 +246,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
     }

     // output: register to global mem,
-    const auto matrix_c_index =
-        blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
+    const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

-    const unsigned k_thread_data_begin = matrix_c_index.row_begin;
-    const unsigned b_thread_data_begin = matrix_c_index.col_begin;
+    const unsigned k_thread_data_begin = matrix_c_index.row;
+    const unsigned b_thread_data_begin = matrix_c_index.col;

     const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
     const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;

@@ -257,11 +257,11 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
 #if 0
     if(get_block_1d_id() == 0)
     {
-        printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
+        printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
                get_block_1d_id(),
                get_thread_local_1d_id(),
-               matrix_c_index.row_begin,
-               matrix_c_index.col_begin,
+               matrix_c_index.row,
+               matrix_c_index.col,
                k_data_begin,
                b_data_begin,
                p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
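The "be careful of alignment" note added above ties into how the LDS buffers are sized later in these kernels: element counts are rounded up to a multiple of max_align so that vectorized copies of InBlockCopyDataPerRead/WeiBlockCopyDataPerRead elements never run past the end of a buffer. A small sketch of that round-up (the element count is illustrative):

    // Round an element count up to a multiple of the larger copy vector width,
    // as the LDS buffers in these kernels do; 1234 is an illustrative size.
    #include <cstdio>

    int main()
    {
        constexpr unsigned InBlockCopyDataPerRead  = 4;
        constexpr unsigned WeiBlockCopyDataPerRead = 4;
        constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                           ? InBlockCopyDataPerRead
                                           : WeiBlockCopyDataPerRead;

        unsigned in_block_size = 1234;
        unsigned padded = max_align * ((in_block_size + max_align - 1) / max_align);
        std::printf("%u elements -> %u (multiple of %u)\n", in_block_size, padded, max_align);
        return 0;
    }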
src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh  (new file, 0 → 100644)

#pragma once
#include "common.cuh"
#include "ConstantTensorDescriptor.cuh"
#include "ConstantMatrixDescriptor.cuh"
#include "blockwise_4d_tensor_op.cuh"
#include "blockwise_2d_tensor_op.cuh"
#include "threadwise_2d_tensor_op.cuh"
#include "blockwise_gemm.cuh"

// define B = flatten(N, Hi, Wi)
template <unsigned GridSize,
          unsigned BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          unsigned BPerBlock,
          unsigned KPerBlock,
          unsigned CPerBlock,
          unsigned BPerThread,
          unsigned KPerThread,
          unsigned GemmThreadPerColumnPerCluster,
          unsigned GemmThreadPerRowPerCluster,
          unsigned GemmMPerThreadSubC,
          unsigned GemmNPerThreadSubC,
          unsigned GemmMLevel0Cluster,
          unsigned GemmNLevel0Cluster,
          unsigned GemmMLevel1Cluster,
          unsigned GemmNLevel1Cluster,
          unsigned GemmKPerThreadLoop,
          unsigned InBlockCopyThreadPerDim0,
          unsigned InBlockCopyThreadPerDim1,
          unsigned WeiBlockCopyThreadPerDim0,
          unsigned WeiBlockCopyThreadPerDim1,
          unsigned InBlockCopyDataPerRead,
          unsigned WeiBlockCopyDataPerRead>
__global__ void
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InGlobalDesc,
                                                           Float* const __restrict__ p_in_global,
                                                           WeiGlobalDesc,
                                                           Float* const __restrict__ p_wei_global,
                                                           OutGlobalDesc,
                                                           Float* __restrict__ p_out_global)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_cnhw_global_desc  = InGlobalDesc{};
    constexpr auto wei_csrk_global_desc = WeiGlobalDesc{};
    constexpr auto out_knhw_global_desc = OutGlobalDesc{};

    constexpr unsigned C  = in_cnhw_global_desc.GetLength(I0);
    constexpr unsigned N  = in_cnhw_global_desc.GetLength(I1);
    constexpr unsigned Hi = in_cnhw_global_desc.GetLength(I2);
    constexpr unsigned Wi = in_cnhw_global_desc.GetLength(I3);

    constexpr unsigned K  = out_knhw_global_desc.GetLength(I0);
    constexpr unsigned Ho = out_knhw_global_desc.GetLength(I2);
    constexpr unsigned Wo = out_knhw_global_desc.GetLength(I3);

    constexpr unsigned S = wei_csrk_global_desc.GetLength(I1);
    constexpr unsigned R = wei_csrk_global_desc.GetLength(I2);

    constexpr unsigned B = N * Hi * Wi;

    constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1);

    // divide block work by 2d: [K, B]
    constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
    constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock;

    const unsigned k_block_work_id = get_block_1d_id() / BBlockWork;
    const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;

    const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
    const unsigned b_block_data_begin = b_block_work_id * BPerBlock;
    // flattened (2d) tensor view of gridwise input
    constexpr auto in_cb_global_desc  = make_ConstantTensorDescriptor(Sequence<C, B>{});
    constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * S * R, K>{});

    // tensor view of blockwise input and weight
    // be careful of alignment
    constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});

    constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<CPerBlock * S * R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});

    constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<CPerBlock, S, R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});

    // tensor view of threadwise output in register
    constexpr auto out_kb_thread_desc =
        make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});

#if 0
    if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(in_cnhw_global_desc, "in_cnhw_global_desc");
        print_ConstantTensorDescriptor(wei_csrk_global_desc, "wei_csrk_global_desc");
        print_ConstantTensorDescriptor(out_knhw_global_desc, "out_knhw_global_desc");

        print_ConstantTensorDescriptor(in_cb_global_desc, "in_cb_global_desc");
        print_ConstantTensorDescriptor(wei_ek_global_desc, "wei_ek_global_desc");

        print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc");
        print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc");
        print_ConstantTensorDescriptor(wei_ek_block_desc, "wei_ek_block_desc");

        print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc");

        printf("KPerBlock %u\n", KPerBlock);
    }
#endif

    // blockwise in copy
    // format is [CPerBlock, BPerBlock + BGhostRead]
#if 0
    const auto blockwise_in_copy =
        Blockwise2dTensorCopy1<BlockSize,
                               Float,
                               decltype(in_cb_global_desc),
                               decltype(in_cb_block_desc),
                               decltype(in_cb_block_desc.GetLengths())>{};
#elif 0
    const auto blockwise_in_copy =
        Blockwise2dTensorCopy2<BlockSize,
                               Float,
                               decltype(in_cb_global_desc),
                               decltype(in_cb_block_desc),
                               decltype(in_cb_block_desc.GetLengths()),
                               InBlockCopyThreadPerDim0,
                               InBlockCopyThreadPerDim1>{};
#elif 1
    const auto blockwise_in_copy =
        Blockwise2dTensorCopy3<BlockSize,
                               Float,
                               decltype(in_cb_global_desc),
                               decltype(in_cb_block_desc),
                               decltype(in_cb_block_desc.GetLengths()),
                               InBlockCopyDataPerRead>{};
#endif

    // blockwise wei copy
    // format is [CPerBlock*S*R, KPerBlock]
#if 0
    const auto blockwise_wei_copy =
        Blockwise2dTensorCopy1<BlockSize,
                               Float,
                               decltype(wei_ek_global_desc),
                               decltype(wei_ek_block_desc),
                               decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
    const auto blockwise_wei_copy =
        Blockwise2dTensorCopy2<BlockSize,
                               Float,
                               decltype(wei_ek_global_desc),
                               decltype(wei_ek_block_desc),
                               decltype(wei_ek_block_desc.GetLengths()),
                               WeiBlockCopyThreadPerDim0,
                               WeiBlockCopyThreadPerDim1>{};
#elif 1
    const auto blockwise_wei_copy =
        Blockwise2dTensorCopy3<BlockSize,
                               Float,
                               decltype(wei_ek_global_desc),
                               decltype(wei_ek_block_desc),
                               decltype(wei_ek_block_desc.GetLengths()),
                               WeiBlockCopyDataPerRead>{};
#endif

    // a series of blockwise GEMM
    // c_mtx += transpose(a_mtx) * b_mtx
    // a_mtx and b_mtx saved in LDS, c_mtx saved in register
    // a_mtx[C,K] is a sub-matrix of wei_block[C,S,R,K]
    // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
    // c_mtx[K,B] is out_block[K,B]
    const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{},
        Number<KPerBlock>{},
        Number<wei_csrk_block_desc.GetStride(I0)>{}); // constexpr doesn't compile

    const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{},
        Number<BPerBlock>{},
        Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile

    const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
        Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile

#if 0
    const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC<BlockSize,
                                                                 decltype(a_cxk_block_mtx_desc),
                                                                 decltype(b_cxb_block_mtx_desc),
                                                                 decltype(c_kxb_thread_mtx_desc),
                                                                 true,
                                                                 false,
                                                                 false,
                                                                 GemmKPerThreadLoop,
                                                                 GemmThreadPerColumnPerCluster,
                                                                 GemmThreadPerRowPerCluster,
                                                                 true>{};
#else
    const auto blockwise_gemm =
        BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
                                                                decltype(a_cxk_block_mtx_desc),
                                                                decltype(b_cxb_block_mtx_desc),
                                                                decltype(c_kxb_thread_mtx_desc),
                                                                GemmMPerThreadSubC,
                                                                GemmNPerThreadSubC,
                                                                GemmMLevel0Cluster,
                                                                GemmNLevel0Cluster,
                                                                GemmMLevel1Cluster,
                                                                GemmNLevel1Cluster,
                                                                GemmKPerThreadLoop>{};
#endif

    // LDS: be careful of alignment
    constexpr unsigned in_block_size =
        in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
    constexpr unsigned wei_block_size =
        wei_csrk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

    constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                       ? InBlockCopyDataPerRead
                                       : WeiBlockCopyDataPerRead;

    __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
    __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];

    // register
    Float p_out_thread[out_kb_thread_desc.GetElementSpace()];

    // set threadwise output tensor to 0
    threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);

    Float* p_in_global_block_offset =
        p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);

    Float* p_wei_global_block_offset =
        p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);

    for(unsigned c_block_data_begin = 0; c_block_data_begin < C;
        c_block_data_begin += CPerBlock,
                 p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
                 p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0),
                 __syncthreads())
    {
        // input: global mem to LDS,
        blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);

        // weight: global mem to LDS,
        blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);

        __syncthreads();

        // a series of GEMM
        for(unsigned s = 0; s < S; ++s)
        {
            for(unsigned r = 0; r < R; ++r)
            {
                auto f_accum = [](auto& acc, const auto&& v) { acc += v; };

                blockwise_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
                                   p_in_block + s * Wi + r,
                                   p_out_thread,
                                   f_accum);
            }
        }
    }

    // output: register to global mem,
    const auto c_thread_mtx_begin =
        blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

    const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
    const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;

#if 0
    if(get_block_1d_id() == 0)
    {
        printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
               get_block_1d_id(),
               get_thread_local_1d_id(),
               matrix_c_index.row,
               matrix_c_index.col,
               k_data_begin,
               b_data_begin,
               p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
    }
#endif

    for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
    {
        for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
        {
            const auto c_thread_mtx_distance =
                blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);

            unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row;
            unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col;

            unsigned n_data = b_data / (Hi * Wi);
            unsigned itmp   = b_data - n_data * (Hi * Wi);
            unsigned h_data = itmp / Wi;
            unsigned w_data = itmp - h_data * Wi;

#if 0
            if(get_block_1d_id() == 0)
            {
                printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n",
                       get_block_1d_id(),
                       get_thread_local_1d_id(),
                       k,
                       b,
                       k_data,
                       n_data,
                       h_data,
                       w_data,
                       p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]);
            }
#endif

            if(n_data < N && h_data < Ho && w_data < Wo)
            {
                p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] =
                    p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)];
            }
        }
    }
}
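The BGhostRead margin in this kernel exists because, with B flattened as (n * Hi + h) * Wi + w, filter tap (s, r) for output column b reads input column b + s * Wi + r (the `p_in_block + s * Wi + r` offset above); the largest such shift is (S - 1) * Wi + (R - 1), so each block stages BPerBlock + BGhostRead input columns in LDS. A quick check of that arithmetic for the 3x3/34x34 case from the driver:

    // Confirm that the furthest input column touched by any filter tap is
    // exactly BGhostRead past the tile; Wi = 34, S = R = 3 as in the driver.
    #include <cstdio>

    int main()
    {
        constexpr unsigned Wi = 34, S = 3, R = 3;
        constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1); // 70

        unsigned max_shift = 0;
        for(unsigned s = 0; s < S; ++s)
            for(unsigned r = 0; r < R; ++r)
                if(s * Wi + r > max_shift)
                    max_shift = s * Wi + r;

        std::printf("max shift %u, BGhostRead %u\n", max_shift, BGhostRead); // both 70
        return 0;
    }

This is also why the driver reserves BGhostRead + BPerBlock extra elements when allocating in_cnhw_device_buf.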
src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline.cuh

@@ -290,11 +290,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
     }

     // output: register to global mem,
-    const auto matrix_c_index =
-        blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
+    const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

-    const unsigned k_thread_data_begin = matrix_c_index.row_begin;
-    const unsigned b_thread_data_begin = matrix_c_index.col_begin;
+    const unsigned k_thread_data_begin = matrix_c_index.row;
+    const unsigned b_thread_data_begin = matrix_c_index.col;

     const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
     const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;

@@ -302,11 +301,11 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
 #if 0
     if(get_block_1d_id() == 0)
     {
-        printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
+        printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
                get_block_1d_id(),
                get_thread_local_1d_id(),
-               matrix_c_index.row_begin,
-               matrix_c_index.col_begin,
+               matrix_c_index.row,
+               matrix_c_index.col,
                k_data_begin,
                b_data_begin,
                p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh

@@ -217,11 +217,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
     }

     // output: register to global mem,
-    const auto matrix_c_index =
-        blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
+    const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

-    const unsigned k_thread_data_begin = matrix_c_index.row_begin;
-    const unsigned b_thread_data_begin = matrix_c_index.col_begin;
+    const unsigned k_thread_data_begin = matrix_c_index.row;
+    const unsigned b_thread_data_begin = matrix_c_index.col;

     const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
     const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;

@@ -229,11 +228,11 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
 #if 0
     if(get_block_1d_id() == 0)
     {
-        printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
+        printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
                get_block_1d_id(),
                get_thread_local_1d_id(),
-               matrix_c_index.row_begin,
-               matrix_c_index.col_begin,
+               matrix_c_index.row,
+               matrix_c_index.col,
                k_data_begin,
                b_data_begin,
                p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.cuh

@@ -276,11 +276,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
     }

     // output: register to global mem,
-    const auto matrix_c_index =
-        blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
+    const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

-    const unsigned k_thread_data_begin = matrix_c_index.row_begin;
-    const unsigned b_thread_data_begin = matrix_c_index.col_begin;
+    const unsigned k_thread_data_begin = matrix_c_index.row;
+    const unsigned b_thread_data_begin = matrix_c_index.col;

     const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
     const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;

@@ -288,11 +287,11 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
 #if 0
     if(get_block_1d_id() == 0)
    {
-        printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
+        printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
                get_block_1d_id(),
                get_thread_local_1d_id(),
-               matrix_c_index.row_begin,
-               matrix_c_index.col_begin,
+               matrix_c_index.row,
+               matrix_c_index.col,
                k_data_begin,
                b_data_begin,
                p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);