yangql / composable_kernel-1

Commit 6614729a, authored Feb 05, 2019 by Chao Liu
add another version of blockwise 2d copy, refactor
Parent: 4b616aad
Showing 15 changed files with 462 additions and 453 deletions (+462 -453).
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh  +6 -1
src/include/blockwise_2d_tensor_op.cuh  +97 -77
src/include/blockwise_4d_tensor_op.cuh  +5 -34
src/include/blockwise_gemm.cuh  +6 -6
src/include/gridwise_direct_convolution_1.cuh  +18 -18
src/include/gridwise_direct_convolution_2.cuh  +12 -12
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh  +27 -28
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh  +40 -42
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_pipeline.cuh  +39 -41
src/include/gridwise_implicit_gemm_convolution_1_nchw_kcsr.cuh  +25 -26
src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh  +21 -22
src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh  +51 -40
src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline.cuh  +50 -37
src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh  +31 -33
src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.cuh  +34 -36
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh

@@ -86,6 +86,9 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
     constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
     constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+    constexpr unsigned InBlockCopyDataPerRead  = 4;
+    constexpr unsigned WeiBlockCopyDataPerRead = 4;
     constexpr unsigned BlockSize = 64;
 #endif
...
@@ -137,7 +140,9 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
                                                         InBlockCopyThreadPerDim0,
                                                         InBlockCopyThreadPerDim1,
                                                         WeiBlockCopyThreadPerDim0,
-                                                        WeiBlockCopyThreadPerDim1>
+                                                        WeiBlockCopyThreadPerDim1,
+                                                        InBlockCopyDataPerRead,
+                                                        WeiBlockCopyDataPerRead>
         <<<grid_dim, block_dim>>>(in_cnhw_desc,
                                   static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
                                   wei_csrk_desc,
...
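The two new constants choose how many floats each thread moves per read (1, 2, or 4, i.e. scalar, float2 or float4 loads). A minimal host-side sketch of how a driver could derive a legal value from a tile's contiguous length, assuming the hypothetical helper name below (not part of this commit); the hard-coded 4 above is what it would yield for any contiguous length that is a multiple of 4:

    // Hypothetical helper, not part of this commit: pick the widest vector
    // read (4, 2, or 1 floats) that evenly divides the tile's contiguous
    // dimension. Blockwise2dTensorCopy3 static_asserts exactly this condition.
    constexpr unsigned ChooseDataPerRead(unsigned contiguous_length)
    {
        return (contiguous_length % 4 == 0) ? 4 : (contiguous_length % 2 == 0) ? 2 : 1;
    }

    // e.g. for a 64-element contiguous dimension (illustrative value only):
    static_assert(ChooseDataPerRead(64) == 4, "a 64-wide row allows float4 reads");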
src/include/blockwise_2d_tensor_op.cuh

@@ -162,9 +162,9 @@ blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
 }

 template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
-struct blockwise_2d_tensor_copy_1
+struct Blockwise2dTensorCopy1
 {
-    __device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
+    __device__ void Run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
     {
         constexpr auto dst_from_src_reorder = Sequence<0, 1>{};
...
@@ -173,6 +173,8 @@ struct blockwise_2d_tensor_copy_1
     }
 };

+// need to be aligned to float4 and float2
+// stride1 need to be 1 for both source and destination
 template <unsigned BlockSize,
           class Float,
           class SrcDesc,
...
@@ -180,21 +182,27 @@ template <unsigned BlockSize,
           class SrcOpLengths,
           unsigned ThreadPerDim0,
           unsigned ThreadPerDim1>
-struct blockwise_2d_tensor_copy_2
+struct Blockwise2dTensorCopy2
 {
     unsigned mThreadId0;
     unsigned mThreadId1;

-    __device__ blockwise_2d_tensor_copy_2()
+    __device__ Blockwise2dTensorCopy2()
     {
+        static_assert(is_same<Float, float>::value, "wrong! type is not float!\n");
+
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+
+        static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
+                      "wrong! stride is not 1!\n");
+
         mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
         mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
     }

-    __device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
+    __device__ void Run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
     {
+        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
+
         if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
             return;
...
@@ -227,22 +235,12 @@ struct blockwise_2d_tensor_copy_2
         for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
         {
             unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;

-#if 1
             const unsigned sindex = src_desc.Get1dIndex(did0, did1);
             const unsigned dindex = dst_desc.Get1dIndex(did0, did1);

             *(reinterpret_cast<float4*>(p_dst + dindex)) =
                 *(reinterpret_cast<float4*>(p_src + sindex));
-#else
-            for(unsigned i = 0; i < 4; ++i)
-            {
-                const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i);
-                const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i);
-
-                p_dst[dindex] = p_src[sindex];
-            }
-#endif
         }

         // v2
...
@@ -251,22 +249,11 @@ struct blockwise_2d_tensor_copy_2
             unsigned did1 =
                 Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;

-#if 1
             const unsigned sindex = src_desc.Get1dIndex(did0, did1);
             const unsigned dindex = dst_desc.Get1dIndex(did0, did1);

             *(reinterpret_cast<float2*>(p_dst + dindex)) =
                 *(reinterpret_cast<float2*>(p_src + sindex));
-#else
-            for(unsigned i = 0; i < 2; ++i)
-            {
-                const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i);
-                const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i);
-
-                p_dst[dindex] = p_src[sindex];
-            }
-#endif
         }

         // v1
...
@@ -310,22 +297,11 @@ struct blockwise_2d_tensor_copy_2
         {
             unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;

-#if 1
             const unsigned sindex = src_desc.Get1dIndex(did0, did1);
             const unsigned dindex = dst_desc.Get1dIndex(did0, did1);

             *(reinterpret_cast<float4*>(p_dst + dindex)) =
                 *(reinterpret_cast<float4*>(p_src + sindex));
-#else
-            for(unsigned i = 0; i < 4; ++i)
-            {
-                const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i);
-                const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i);
-
-                p_dst[dindex] = p_src[sindex];
-            }
-#endif
         }

         // v2
...
@@ -334,22 +310,11 @@ struct blockwise_2d_tensor_copy_2
             unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
                             2 * mThreadId1;

-#if 1
             const unsigned sindex = src_desc.Get1dIndex(did0, did1);
             const unsigned dindex = dst_desc.Get1dIndex(did0, did1);

             *(reinterpret_cast<float2*>(p_dst + dindex)) =
                 *(reinterpret_cast<float2*>(p_src + sindex));
-#else
-            for(unsigned i = 0; i < 2; ++i)
-            {
-                const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i);
-                const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i);
-
-                p_dst[dindex] = p_src[sindex];
-            }
-#endif
         }

         // v1
...
@@ -385,49 +350,104 @@ struct blockwise_2d_tensor_copy_2
     }
 };

-template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
-struct blockwise_2d_tensor_copy_dummy_1
-{
-    unsigned mBegin;
-
-    __device__ blockwise_2d_tensor_copy_dummy_1()
-    {
-        constexpr unsigned n_total =
-            make_ConstantTensorDescriptor(SrcOpLengths{}).GetElementSpace();
-
-        constexpr unsigned n_per_thread = n_total / BlockSize;
-
-        mBegin = n_per_thread * get_thread_local_1d_id();
-    }
-
-    __device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
-    {
-        constexpr unsigned n_total =
-            make_ConstantTensorDescriptor(SrcOpLengths{}).GetElementSpace();
-
-        constexpr unsigned n_per_thread = n_total / BlockSize;
-
-        for(unsigned i = 0; i < n_per_thread; ++i)
-        {
-            p_dst[mBegin + i] = p_src[mBegin + i];
-        }
-    }
-};
-
-template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
-struct blockwise_2d_tensor_copy_dummy_2
-{
-    __device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
-    {
-        constexpr unsigned n_total =
-            make_ConstantTensorDescriptor(SrcOpLengths{}).GetElementSpace();
-
-        constexpr unsigned n_per_thread = n_total / BlockSize;
-
-        for(unsigned i = 0; i < n_per_thread; ++i)
-        {
-            unsigned index = get_thread_local_1d_id() + BlockSize * i;
-
-            p_dst[index] = p_src[index];
-        }
-    }
-};
+// starting point need to be aligned to float4 or float2 or float
+// stride1 need to be 1 for both source and destination
+template <unsigned BlockSize,
+          class Float,
+          class SrcDesc,
+          class DstDesc,
+          class SrcOpLengths,
+          unsigned DataPerRead>
+struct Blockwise2dTensorCopy3
+{
+    unsigned mSrcMyThreadOffset;
+    unsigned mDstMyThreadOffset;
+
+    __device__ Blockwise2dTensorCopy3()
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+
+        static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
+                      "wrong! only support stride1 == 1!\n");
+
+        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
+                      "wrong! only support DataPerRead == 1, 2 or 4!\n");
+
+        constexpr unsigned L0 = SrcOpLengths{}.Get(I0);
+        constexpr unsigned L1 = SrcOpLengths{}.Get(I1);
+
+        static_assert(L1 % DataPerRead == 0, "wrong! only support mod(L1, DataPerRead) == 0\n");
+
+        constexpr unsigned thread_per_d1 = L1 / DataPerRead;
+        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
+
+        static_assert(thread_per_d1 <= BlockSize,
+                      "wrong! not enough threads to cover L1 dimension\n");
+
+        const unsigned thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
+        const unsigned thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
+
+        mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
+        mDstMyThreadOffset = DstDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
+    }
+
+    __device__ void Run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
+    {
+        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
+
+        using Float2 = float2;
+        using Float4 = float4;
+
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+
+        constexpr unsigned L0 = SrcOpLengths{}.Get(I0);
+        constexpr unsigned L1 = SrcOpLengths{}.Get(I1);
+
+        constexpr unsigned thread_per_d1 = L1 / DataPerRead;
+        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
+
+        constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
+
+        if(BlockSize > num_active_thread)
+        {
+            if(get_thread_local_1d_id() > num_active_thread)
+            {
+                return;
+            }
+        }
+
+        constexpr unsigned nloop_d0 = L0 / thread_per_d0;
+
+        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
+
+        constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
+        constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
+
+        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
+        {
+            if(DataPerRead == 1)
+            {
+                p_dst[mDstMyThreadOffset + iloop * dst_loop_stride] =
+                    p_src[mSrcMyThreadOffset + iloop * src_loop_stride];
+            }
+            else if(DataPerRead == 2)
+            {
+                *(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
+                    *(reinterpret_cast<Float2*>(p_src + mSrcMyThreadOffset + iloop * src_loop_stride));
+            }
+            else if(DataPerRead == 4)
+            {
+                *(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
+                    *(reinterpret_cast<Float4*>(p_src + mSrcMyThreadOffset + iloop * src_loop_stride));
+            }
+            else
+            {
+                assert(false);
+            }
+        }
+    }
+};
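The added Blockwise2dTensorCopy3 splits the BlockSize threads into a thread_per_d0 x thread_per_d1 grid, gives each thread DataPerRead consecutive elements along the unit-stride dimension, and then strides the whole grid along dimension 0. A self-contained sketch of the same mapping for a plain row-major tile, assuming DataPerRead = 4 and a 16-byte-aligned base pointer; the helper below is illustrative only, not code from this file:

    // Minimal sketch of the Blockwise2dTensorCopy3 mapping for a row-major
    // [L0][L1] float tile with 4-wide (float4) reads. Assumes the base
    // pointers are 16-byte aligned; names here are illustrative.
    template <unsigned BlockSize, unsigned L0, unsigned L1>
    __device__ void copy_tile_float4(const float* __restrict__ p_src, float* __restrict__ p_dst)
    {
        constexpr unsigned DataPerRead = 4;
        static_assert(L1 % DataPerRead == 0, "row length must be a multiple of the vector width");

        constexpr unsigned thread_per_d1 = L1 / DataPerRead;          // threads needed across one row
        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1; // rows covered per pass

        const unsigned tid = threadIdx.x;
        if(tid >= thread_per_d0 * thread_per_d1)
            return; // surplus threads sit out

        const unsigned d0 = tid / thread_per_d1;
        const unsigned d1 = (tid - d0 * thread_per_d1) * DataPerRead;

        // each pass moves thread_per_d0 rows; advance until all L0 rows are copied
        for(unsigned row = d0; row < L0; row += thread_per_d0)
        {
            *reinterpret_cast<float4*>(p_dst + row * L1 + d1) =
                *reinterpret_cast<const float4*>(p_src + row * L1 + d1);
        }
    }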
src/include/blockwise_4d_tensor_op.cuh

@@ -200,9 +200,9 @@ blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
 }

 template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
-struct blockwise_4d_tensor_copy_1
+struct Blockwise4dTensorCopy1
 {
-    __device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
+    __device__ void Run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
     {
         constexpr auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};
...
@@ -217,9 +217,9 @@ template <unsigned BlockSize,
           class DstDesc,
           class DstOpLengths,
           class GlobalLowerPads>
-struct blockwise_chwn_tensor_copy_with_padding
+struct BlockwiseChwnTensorCopyPadded
 {
-    __device__ void run(Float* const __restrict__ p_src,
+    __device__ void Run(Float* const __restrict__ p_src,
                         unsigned c_block_data_begin,
                         unsigned ho_block_data_begin,
                         unsigned wo_block_data_begin,
...
@@ -336,33 +336,4 @@ struct blockwise_chwn_tensor_copy_with_padding
             }
         }
     }
 };
\ No newline at end of file
-
-template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
-struct blockwise_4d_tensor_copy_dummy
-{
-    unsigned mBegin;
-
-    __device__ blockwise_4d_tensor_copy_dummy()
-    {
-        constexpr unsigned n_total =
-            make_ConstantTensorDescriptor(SrcOpLengths{}).GetElementSpace();
-
-        constexpr unsigned n_per_thread = n_total / BlockSize;
-
-        mBegin = n_per_thread * get_thread_local_1d_id();
-    }
-
-    __device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
-    {
-        constexpr unsigned n_total =
-            make_ConstantTensorDescriptor(SrcOpLengths{}).GetElementSpace();
-
-        constexpr unsigned n_per_thread = n_total / BlockSize;
-
-        for(unsigned i = 0; i < n_per_thread; ++i)
-        {
-            p_dst[mBegin + i] = p_src[mBegin + i];
-        }
-    }
-};
src/include/blockwise_gemm.cuh

@@ -15,7 +15,7 @@ template <unsigned BlockSize,
           unsigned BatchPerThread,
           unsigned KPerThreadLoop,
           bool DistributeThreadAlongColumnFirst>
-struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
+struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
 {
     unsigned mMyThreadOffsetA = 0;
     unsigned mMyThreadOffsetB = 0;
...
@@ -27,7 +27,7 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
         unsigned col_begin;
     };

-    __device__ blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c()
+    __device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
     {
         const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
         const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
...
@@ -117,7 +117,7 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
     }

     template <class FloatA, class FloatB, class FloatC, class Accumulator>
-    __device__ void run(FloatA* const p_a_block,
+    __device__ void Run(FloatA* const p_a_block,
                         FloatB* const p_b_block,
                         FloatC* p_c_thread,
                         Accumulator f_accum) const
...
@@ -230,7 +230,7 @@ template <unsigned BlockSize,
           unsigned MThreadPerCluster,
           unsigned NThreadPerCluster,
           bool DistributeThreadAlongColumnFirst>
-struct blockwise_gemm_block_a_block_b_thread_c
+struct BlockwiseGemmBlockABlockBThreadC
 {
     unsigned mMyThreadOffsetA = 0;
     unsigned mMyThreadOffsetB = 0;
...
@@ -241,7 +241,7 @@ struct blockwise_gemm_block_a_block_b_thread_c
         unsigned col_begin;
     };

-    __device__ blockwise_gemm_block_a_block_b_thread_c()
+    __device__ BlockwiseGemmBlockABlockBThreadC()
     {
         const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
         const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
...
@@ -360,7 +360,7 @@ struct blockwise_gemm_block_a_block_b_thread_c
     }

     template <class FloatA, class FloatB, class FloatC, class Accumulator>
-    __device__ void run(FloatA* const p_a_block,
+    __device__ void Run(FloatA* const p_a_block,
                         FloatB* const p_b_block,
                         FloatC* p_c_thread,
                         Accumulator f_accum) const
...
src/include/gridwise_direct_convolution_1.cuh

@@ -122,25 +122,25 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc,
 #endif

     constexpr auto blockwise_in_copy =
-        blockwise_4d_tensor_copy_1<BlockSize,
+        Blockwise4dTensorCopy1<BlockSize,
                                    Float,
                                    decltype(in_block_global_desc),
                                    decltype(in_block_desc),
                                    decltype(in_block_desc.GetLengths())>{};

     constexpr auto blockwise_wei_copy =
-        blockwise_4d_tensor_copy_1<BlockSize,
+        Blockwise4dTensorCopy1<BlockSize,
                                    Float,
                                    decltype(wei_block_global_desc),
                                    decltype(wei_block_desc),
                                    decltype(wei_block_desc.GetLengths())>{};

     constexpr auto blockwise_out_copy =
-        blockwise_4d_tensor_copy_1<BlockSize,
+        Blockwise4dTensorCopy1<BlockSize,
                                    Float,
                                    decltype(out_block_desc),
                                    decltype(out_block_global_desc),
                                    decltype(out_block_desc.GetLengths())>{};

     // set output tensor in LDS to 0
     blockwise_4d_tensor_set_zero<BlockSize>(out_block_desc, p_out_block);
...
@@ -149,14 +149,14 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc,
         c_block_work_begin += CPerBlock)
     {
         // copy input tensor to LDS
-        blockwise_in_copy.run(p_in_global + in_global_desc.Get1dIndex(n_block_work_begin,
+        blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_work_begin,
                                                                       c_block_work_begin,
                                                                       hi_block_work_begin,
                                                                       wi_block_work_begin),
                               p_in_block);

         // copy weight tensor to LDS
-        blockwise_wei_copy.run(
+        blockwise_wei_copy.Run(
             p_wei_global + wei_global_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0),
             p_wei_block);
...
@@ -179,7 +179,7 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc,
     }

     // copy output tensor from LDS to device mem
-    blockwise_out_copy.run(p_out_block,
+    blockwise_out_copy.Run(p_out_block,
                            p_out_global + out_global_desc.Get1dIndex(n_block_work_begin,
                                                                      k_block_work_begin,
                                                                      ho_block_work_begin,
...
src/include/gridwise_direct_convolution_2.cuh

@@ -145,18 +145,18 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
 #endif

     constexpr auto blockwise_in_copy =
-        blockwise_4d_tensor_copy_1<BlockSize,
+        Blockwise4dTensorCopy1<BlockSize,
                                    Float,
                                    decltype(in_global_desc),
                                    decltype(in_block_desc),
                                    decltype(in_block_desc.GetLengths())>{};

     constexpr auto blockwise_wei_copy =
-        blockwise_4d_tensor_copy_1<BlockSize,
+        Blockwise4dTensorCopy1<BlockSize,
                                    Float,
                                    decltype(wei_global_desc),
                                    decltype(wei_block_desc),
                                    decltype(wei_block_desc.GetLengths())>{};

     // set threadwise output tensor to 0
     threadwise_4d_tensor_set_zero(out_thread_desc, p_out_thread);
...
@@ -165,14 +165,14 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
         c_block_data_begin += CPerBlock, __syncthreads())
     {
         // copy input tensor to LDS
-        blockwise_in_copy.run(p_in_global + in_global_desc.Get1dIndex(n_block_data_begin,
+        blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_data_begin,
                                                                       c_block_data_begin,
                                                                       hi_block_data_begin,
                                                                       wi_block_data_begin),
                               p_in_block);

         // copy weight tensor to LDS
-        blockwise_wei_copy.run(
+        blockwise_wei_copy.Run(
             p_wei_global + wei_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
             p_wei_block);
...
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh

@@ -106,19 +106,19 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
     // blockwise copy
     // input: format is [C, Hi, Wi, N]
     constexpr auto blockwise_in_copy =
-        blockwise_4d_tensor_copy_1<BlockSize,
+        Blockwise4dTensorCopy1<BlockSize,
                                    Float,
                                    decltype(in_chwn_global_desc),
                                    decltype(in_chwn_block_desc),
                                    decltype(in_chwn_block_desc.GetLengths())>{};

     // weight: format is [S,R,C,K]
     constexpr auto blockwise_wei_copy =
-        blockwise_4d_tensor_copy_1<BlockSize,
+        Blockwise4dTensorCopy1<BlockSize,
                                    Float,
                                    decltype(wei_csrk_global_desc),
                                    decltype(wei_csrk_block_desc),
                                    decltype(wei_csrk_block_desc.GetLengths())>{};

     // a series of blockwise batched GEMM
     // C_matrix += transpose(A_matrix) * B_matrix
...
@@ -140,21 +140,20 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
         Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile

     const auto blockwise_batch_gemm =
-        blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
+        Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
                                    decltype(a_cxk_block_mtx_desc),
                                    decltype(b_cxwn_block_mtx_desc),
                                    decltype(c_kxwn_thread_mtx_desc),
                                    true,
                                    false,
                                    false,
                                    0,
                                    in_chwn_block_desc.GetStride(I1),
                                    out_hkwn_thread_desc.GetStride(
                                        I0),
                                    HoPerBlock,
                                    HoPerThread,
                                    CPerThread,
                                    true>{};

     // LDS
     constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
...
@@ -183,12 +182,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
     {
 #if 1
         // input: global mem to LDS,
-        blockwise_in_copy.run(p_in_global_block_begin, p_in_block);
+        blockwise_in_copy.Run(p_in_global_block_begin, p_in_block);
 #endif

 #if 1
         // weight: global mem to LDS,
-        blockwise_wei_copy.run(p_wei_global_block_begin, p_wei_block);
+        blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block);
 #endif

         __syncthreads();
...
@@ -200,7 +199,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
         {
             auto f_accum = [](auto& acc, const auto&& v) { acc += v; };

-            blockwise_batch_gemm.run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
+            blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
                                      p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
                                      p_out_thread,
                                      f_accum);
...
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh

@@ -136,39 +136,38 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restri
 #endif

     constexpr auto blockwise_in_copy =
-        blockwise_chwn_tensor_copy_with_padding<BlockSize,
+        BlockwiseChwnTensorCopyPadded<BlockSize,
                                    Float,
                                    decltype(in_chwn_global_desc),
                                    decltype(in_chwn_block_desc),
                                    decltype(in_chwn_block_desc.GetLengths()),
                                    LowerPads>{};

 #if 1
     // weight: format is [C,S,R,K]
     constexpr auto blockwise_wei_copy =
-        blockwise_4d_tensor_copy_1<BlockSize,
+        Blockwise4dTensorCopy1<BlockSize,
                                    Float,
                                    decltype(wei_csrk_global_desc),
                                    decltype(wei_csrk_block_desc),
                                    decltype(wei_csrk_block_desc.GetLengths())>{};
 #elif 1
     // weight: format is [C*S*R,K]
     constexpr auto blockwise_wei_copy =
-        blockwise_2d_tensor_copy_1<BlockSize,
+        Blockwise2dTensorCopy1<BlockSize,
                                    Float,
                                    decltype(wei_ek_global_desc),
                                    decltype(wei_ek_block_desc),
                                    decltype(wei_ek_block_desc.GetLengths())>{};
 #elif 1
     // weight: format is [C*S*R,K]
     const auto blockwise_wei_copy =
-        blockwise_2d_tensor_copy_2<BlockSize,
+        Blockwise2dTensorCopy2<BlockSize,
                                    Float,
                                    decltype(wei_ek_global_desc),
                                    decltype(wei_ek_block_desc),
                                    decltype(wei_ek_block_desc.GetLengths()),
                                    WeiBlockCopyThreadPerDim0,
                                    WeiBlockCopyThreadPerDim1>{};
 #endif

     // a series of blockwise batched GEMM
...
@@ -191,21 +190,20 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restri
         Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile

     const auto blockwise_batch_gemm =
-        blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
+        Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
                                    decltype(a_cxk_block_mtx_desc),
                                    decltype(b_cxwn_block_mtx_desc),
                                    decltype(c_kxwn_thread_mtx_desc),
                                    true,
                                    false,
                                    false,
                                    0,
                                    in_chwn_block_desc.GetStride(I1),
                                    out_hkwn_thread_desc.GetStride(
                                        I0),
                                    HoPerBlock,
                                    HoPerThread,
                                    CPerThread,
                                    true>{};

     // LDS
     constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
...
@@ -229,7 +227,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restri
     {
 #if 1
         // input: global mem to LDS,
-        blockwise_in_copy.run(p_in_global,
+        blockwise_in_copy.Run(p_in_global,
                               c_block_data_begin,
                               ho_block_data_begin,
                               wo_block_data_begin,
...
@@ -243,7 +241,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restri
 #if 1
         // weight: global mem to LDS,
-        blockwise_wei_copy.run(p_wei_global_block_begin, p_wei_block);
+        blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block);
 #endif

         __syncthreads();
...
@@ -255,7 +253,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restri
         {
             auto f_accum = [](auto& acc, const auto&& v) { acc += v; };

-            blockwise_batch_gemm.run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
+            blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
                                      p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
                                      p_out_thread,
                                      f_accum);
...
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_pipeline.cuh

@@ -136,17 +136,17 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
 #endif

     constexpr auto blockwise_in_copy =
-        blockwise_chwn_tensor_copy_with_padding<BlockSize,
+        BlockwiseChwnTensorCopyPadded<BlockSize,
                                    Float,
                                    decltype(in_chwn_global_desc),
                                    decltype(in_chwn_block_desc),
                                    decltype(in_chwn_block_desc.GetLengths()),
                                    LowerPads>{};

 #if 0
     // weight: format is [C,S,R,K]
     constexpr auto blockwise_wei_copy =
-        blockwise_4d_tensor_copy_1<BlockSize,
+        Blockwise4dTensorCopy1<BlockSize,
                                    Float,
                                    decltype(wei_csrk_global_desc),
                                    decltype(wei_csrk_block_desc),
...
@@ -154,21 +154,20 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
 #elif 0
     // weight: format is [C*S*R,K]
     constexpr auto blockwise_wei_copy =
-        blockwise_2d_tensor_copy_1<BlockSize,
+        Blockwise2dTensorCopy1<BlockSize,
                                    Float,
                                    decltype(wei_ek_global_desc),
                                    decltype(wei_ek_block_desc),
                                    decltype(wei_ek_block_desc.GetLengths())>{};
 #elif 1
     // weight: format is [C*S*R,K]
     const auto blockwise_wei_copy =
-        blockwise_2d_tensor_copy_2<BlockSize,
+        Blockwise2dTensorCopy2<BlockSize,
                                    Float,
                                    decltype(wei_ek_global_desc),
                                    decltype(wei_ek_block_desc),
                                    decltype(wei_ek_block_desc.GetLengths()),
                                    WeiBlockCopyThreadPerDim0,
                                    WeiBlockCopyThreadPerDim1>{};
 #endif

     // a series of blockwise batched GEMM
...
@@ -191,21 +190,20 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
         Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile

     const auto blockwise_batch_gemm =
-        blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
+        Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
                                    decltype(a_cxk_block_mtx_desc),
                                    decltype(b_cxwn_block_mtx_desc),
                                    decltype(c_kxwn_thread_mtx_desc),
                                    true,
                                    false,
                                    false,
                                    0,
                                    in_chwn_block_desc.GetStride(I1),
                                    out_hkwn_thread_desc.GetStride(
                                        I0),
                                    HoPerBlock,
                                    HoPerThread,
                                    CPerThread,
                                    true>{};

     // LDS
     constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
...
@@ -229,7 +227,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
     // prelog: load data
     // input: global mem to LDS,
-    blockwise_in_copy.run(p_in_global,
+    blockwise_in_copy.Run(p_in_global,
                           0,
                           ho_block_data_begin,
                           wo_block_data_begin,
...
@@ -241,7 +239,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
                           w_block_pad_up);

     // weight: global mem to LDS,
-    blockwise_wei_copy.run(p_wei_global_block_begin, p_wei_block_0);
+    blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block_0);

     p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0);
...
@@ -263,7 +261,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
         // preload next data
 #if 1
         // input: global mem to LDS,
-        blockwise_in_copy.run(p_in_global,
+        blockwise_in_copy.Run(p_in_global,
                               c_block_data_begin,
                               ho_block_data_begin,
                               wo_block_data_begin,
...
@@ -277,7 +275,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
 #if 1
         // weight: global mem to LDS,
-        blockwise_wei_copy.run(p_wei_global_block_begin, p_wei_block_next);
+        blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block_next);
 #endif

         // a series of batched GEMM
...
@@ -287,7 +285,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
         {
             auto f_accum = [](auto& acc, const auto&& v) { acc += v; };

-            blockwise_batch_gemm.run(p_wei_block_now +
+            blockwise_batch_gemm.Run(p_wei_block_now +
                                          wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
                                      p_in_block_now + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
                                      p_out_thread,
...
@@ -310,7 +308,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
         {
             auto f_accum = [](auto& acc, const auto&& v) { acc += v; };

-            blockwise_batch_gemm.run(p_wei_block_now +
+            blockwise_batch_gemm.Run(p_wei_block_now +
                                          wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
                                      p_in_block_now + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
                                      p_out_thread,
...
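The *_lds_pipeline kernels above hide global-memory latency by ping-ponging between two LDS buffers (p_wei_block_0 / p_wei_block_next / p_wei_block_now in the diff): while the GEMM consumes one buffer, the copy functor fills the other. A schematic of that double-buffering loop in plain CUDA; TileSize, load_tile and compute_tile are illustrative placeholders, not code from this repository:

    // Schematic double-buffered (ping-pong) LDS pipeline: prefetch the next
    // tile into one half of LDS while computing on the other half.
    template <unsigned TileSize, class LoadTile, class ComputeTile>
    __device__ void lds_pipeline(unsigned num_iterations, LoadTile load_tile, ComputeTile compute_tile)
    {
        __shared__ float lds_buf[2][TileSize];

        float* p_now  = lds_buf[0];
        float* p_next = lds_buf[1];

        load_tile(p_now, 0);                  // prologue: fetch the first tile
        __syncthreads();

        for(unsigned it = 0; it < num_iterations; ++it)
        {
            if(it + 1 < num_iterations)
                load_tile(p_next, it + 1);    // prefetch the next tile

            compute_tile(p_now);              // consume the tile loaded earlier

            __syncthreads();                  // wait until the prefetch has landed in LDS

            float* p_tmp = p_now;             // swap the two buffers
            p_now        = p_next;
            p_next       = p_tmp;
        }
    }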
src/include/gridwise_implicit_gemm_convolution_1_nchw_kcsr.cuh

@@ -127,21 +127,20 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr(InGlobalDesc,
         Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile

     const auto blockwise_batch_gemm =
-        blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
+        Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
                                    decltype(a_cxk_block_mtx_desc),
                                    decltype(b_cxwn_block_mtx_desc),
                                    decltype(c_kxwn_thread_mtx_desc),
                                    true,
                                    false,
                                    false,
                                    0,
                                    in_chwn_block_desc.GetStride(I1),
                                    out_hkwn_thread_desc.GetStride(
                                        I0),
                                    HoPerBlock,
                                    HoPerThread,
                                    CPerThread,
                                    true>{};

     // LDS
     constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
...
@@ -175,15 +174,15 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr(InGlobalDesc,
 #else
         // input: global mem to LDS,
         // no format conversion, this is wrong, for performance study only!
-        blockwise_4d_tensor_copy<BlockSize>(in_nchw_global_desc,
+        Blockwise4dTensorCopy<BlockSize>(in_nchw_global_desc,
                                             p_in_global +
                                                 in_nchw_global_desc.Get1dIndex(n_block_data_begin,
                                                                                c_block_data_begin,
                                                                                hi_block_data_begin,
                                                                                wi_block_data_begin),
                                             in_nchw_block_desc,
                                             p_in_block,
                                             in_nchw_block_desc.GetLengths());
 #endif

 #if 1
...
@@ -200,7 +199,7 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr(InGlobalDesc,
 #else
         // weight: global mem to LDS,
         // no format conversion, this is wrong, for performance study only!
-        blockwise_4d_tensor_copy<BlockSize>(
+        Blockwise4dTensorCopy<BlockSize>(
             wei_kcsr_global_desc,
             p_wei_global +
                 wei_kcsr_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
...
@@ -219,7 +218,7 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr(InGlobalDesc,
         {
             auto f_accum = [](auto& c, const auto&& ab) { c += ab; };

-            blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
+            blockwise_batch_gemm.Run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
                                      p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
                                      p_out_thread,
                                      f_accum);
...
src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh

@@ -109,11 +109,11 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
     // blockwise copy
     // wei: format is [S,R,C,K], no conversion needed
     constexpr auto blockwise_wei_copy =
-        blockwise_4d_tensor_copy_1<BlockSize,
+        Blockwise4dTensorCopy1<BlockSize,
                                    Float,
                                    decltype(wei_srck_global_desc),
                                    decltype(wei_srck_block_desc),
                                    decltype(wei_srck_block_desc.GetLengths())>{};

     // a series of blockwise batched GEMM
     // C_matrix += transpose(A_matrix) * B_matrix
...
@@ -133,21 +133,20 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
         Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile

     const auto blockwise_batch_gemm =
-        blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
+        Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
                                    decltype(a_cxk_block_mtx_desc),
                                    decltype(b_cxwn_block_mtx_desc),
                                    decltype(c_kxwn_thread_mtx_desc),
                                    true,
                                    false,
                                    false,
                                    0,
                                    in_chwn_block_desc.GetStride(I1),
                                    out_hkwn_thread_desc.GetStride(
                                        I0),
                                    HoPerBlock,
                                    HoPerThread,
                                    CPerThread,
                                    true>{};

     // LDS
     constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
...
@@ -183,7 +182,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
 #if 1
         // weight: global mem to LDS,
         // format is [S,R,C,K], no conversion needed
-        blockwise_wei_copy.run(p_wei_global + wei_srck_global_desc.Get1dIndex(
+        blockwise_wei_copy.Run(p_wei_global + wei_srck_global_desc.Get1dIndex(
                                                   0, 0, c_block_data_begin, k_block_data_begin),
                                p_wei_block);
 #endif
...
@@ -197,7 +196,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
         {
             auto f_accum = [](auto& c, const auto&& ab) { c += ab; };

-            blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
+            blockwise_batch_gemm.Run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
                                      p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
                                      p_out_thread,
                                      f_accum);
...
src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh
View file @
6614729a
...
@@ -25,7 +25,9 @@ template <unsigned GridSize,
...
@@ -25,7 +25,9 @@ template <unsigned GridSize,
unsigned
InBlockCopyThreadPerDim0
,
unsigned
InBlockCopyThreadPerDim0
,
unsigned
InBlockCopyThreadPerDim1
,
unsigned
InBlockCopyThreadPerDim1
,
unsigned
WeiBlockCopyThreadPerDim0
,
unsigned
WeiBlockCopyThreadPerDim0
,
unsigned
WeiBlockCopyThreadPerDim1
>
unsigned
WeiBlockCopyThreadPerDim1
,
unsigned
InBlockCopyDataPerRead
,
unsigned
WeiBlockCopyDataPerRead
>
__global__
void
__global__
void
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw
(
InGlobalDesc
,
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw
(
InGlobalDesc
,
Float
*
const
__restrict__
p_in_global
,
Float
*
const
__restrict__
p_in_global
,
...
@@ -117,40 +119,52 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
...
@@ -117,40 +119,52 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
// formmat is [CPerBlock,BPerBlock + BGhostRead]
// formmat is [CPerBlock,BPerBlock + BGhostRead]
#if 0
#if 0
const auto blockwise_in_copy =
const auto blockwise_in_copy =
b
lockwise
_
2d
_t
ensor
_c
opy
_
1<BlockSize,
B
lockwise2d
T
ensor
C
opy1<BlockSize,
Float,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths())>{};
decltype(in_cb_block_desc.GetLengths())>{};
#elif
0
const
auto
blockwise_in_copy
=
Blockwise2dTensorCopy2
<
BlockSize
,
Float
,
decltype
(
in_cb_global_desc
),
decltype
(
in_cb_block_desc
),
decltype
(
in_cb_block_desc
.
GetLengths
()),
InBlockCopyThreadPerDim0
,
InBlockCopyThreadPerDim1
>
{};
#elif 1
#elif 1
const
auto
blockwise_in_copy
=
const
auto
blockwise_in_copy
=
Blockwise2dTensorCopy3
<
BlockSize
,
blockwise_2d_tensor_copy_2
<
BlockSize
,
Float
,
Float
,
decltype
(
in_cb_global_desc
),
decltype
(
in_cb_global_desc
),
decltype
(
in_cb_block_desc
),
decltype
(
in_cb_block_desc
),
decltype
(
in_cb_block_desc
.
GetLengths
()),
decltype
(
in_cb_block_desc
.
GetLengths
()),
InBlockCopyDataPerRead
>
{};
InBlockCopyThreadPerDim0
,
InBlockCopyThreadPerDim1
>
{};
#endif
#endif
// blockwise wei copy
// blockwise wei copy
// format is [CPerBlock*S*R,KPerBlock]
// format is [CPerBlock*S*R,KPerBlock]
#if 0
#if 0
const auto blockwise_wei_copy =
const auto blockwise_wei_copy =
blockwise_2d_tensor_copy_1<BlockSize,
Blockwise2dTensorCopy1<BlockSize,
Float,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths())>{};
decltype(wei_ek_block_desc.GetLengths())>{};
#elif
0
const
auto
blockwise_wei_copy
=
Blockwise2dTensorCopy2
<
BlockSize
,
Float
,
decltype
(
wei_ek_global_desc
),
decltype
(
wei_ek_block_desc
),
decltype
(
wei_ek_block_desc
.
GetLengths
()),
WeiBlockCopyThreadPerDim0
,
WeiBlockCopyThreadPerDim1
>
{};
#elif 1
#elif 1
const
auto
blockwise_wei_copy
=
const
auto
blockwise_wei_copy
=
Blockwise2dTensorCopy3
<
BlockSize
,
blockwise_2d_tensor_copy_2
<
BlockSize
,
Float
,
Float
,
decltype
(
wei_ek_global_desc
),
decltype
(
wei_ek_global_desc
),
decltype
(
wei_ek_block_desc
),
decltype
(
wei_ek_block_desc
),
decltype
(
wei_ek_block_desc
.
GetLengths
()),
decltype
(
wei_ek_block_desc
.
GetLengths
()),
WeiBlockCopyDataPerRead
>
{};
WeiBlockCopyThreadPerDim0
,
WeiBlockCopyThreadPerDim1
>
{};
#endif
#endif
// a series of blockwise GEMM
// a series of blockwise GEMM
...
@@ -170,18 +184,17 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
...
@@ -170,18 +184,17 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
const
auto
c_kxb_thread_mtx_desc
=
make_ConstantMatrixDescriptor
(
const
auto
c_kxb_thread_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThread
>
{},
Number
<
BPerThread
>
{});
// constexpr doesn't compile
Number
<
KPerThread
>
{},
Number
<
BPerThread
>
{});
// constexpr doesn't compile
const
auto
blockwise_gemm
=
const
auto
blockwise_gemm
=
BlockwiseGemmBlockABlockBThreadC
<
BlockSize
,
blockwise_gemm_block_a_block_b_thread_c
<
BlockSize
,
decltype
(
a_cxk_block_mtx_desc
),
decltype
(
a_cxk_block_mtx_desc
),
decltype
(
b_cxb_block_mtx_desc
),
decltype
(
b_cxb_block_mtx_desc
),
decltype
(
c_kxb_thread_mtx_desc
),
decltype
(
c_kxb_thread_mtx_desc
),
true
,
true
,
false
,
false
,
false
,
false
,
CPerThread
,
CPerThread
,
GemmThreadPerClusterRow
,
GemmThreadPerClusterRow
,
GemmThreadPerClusterColumn
,
GemmThreadPerClusterColumn
,
true
>
{};
true
>
{};
// LDS
// LDS
constexpr
unsigned
in_block_size
=
in_cb_block_desc
.
GetElementSpace
();
constexpr
unsigned
in_block_size
=
in_cb_block_desc
.
GetElementSpace
();
...
@@ -208,10 +221,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
...
@@ -208,10 +221,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
__syncthreads
())
__syncthreads
())
{
{
// input: global mem to LDS,
// input: global mem to LDS,
blockwise_in_copy
.
r
un
(
p_in_global_block_offset
,
p_in_block
);
blockwise_in_copy
.
R
un
(
p_in_global_block_offset
,
p_in_block
);
// weight: global mem to LDS,
// weight: global mem to LDS,
blockwise_wei_copy
.
r
un
(
p_wei_global_block_offset
,
p_wei_block
);
blockwise_wei_copy
.
R
un
(
p_wei_global_block_offset
,
p_wei_block
);
__syncthreads
();
__syncthreads
();
...
@@ -222,7 +235,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
            {
                auto f_accum = [](auto& c, const auto&& ab) { c += ab; };

-               blockwise_gemm.run(p_wei_block + wei_csrk_block_desc.Get1dIndex(s, r, 0, 0),
+               blockwise_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(s, r, 0, 0),
                                   p_in_block + s * Wi + r,
                                   p_out_thread,
                                   f_accum);
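Every blockwise GEMM call receives the same tiny lambda, c += ab, as its accumulation step. The kernel's own GEMM machinery is considerably more elaborate, but the shape of the idea, passing the update as a functor so one loop nest can serve different accumulation rules, can be sketched in a few lines; the loop order, layouts and names below are mine, chosen only for illustration, not the repository's BlockwiseGemmBlockABlockBThreadC.

// Sketch of a transposed-A thread-level GEMM that delegates the update to a
// caller-supplied functor: C[m][n] <- f_accum(C[m][n], A[k][m] * B[k][n]).
template <unsigned M, unsigned N, unsigned K, class FAccum>
__device__ void threadwise_gemm_transA(const float* p_a, // K x M, row-major
                                       const float* p_b, // K x N, row-major
                                       float* p_c,       // M x N, row-major
                                       FAccum f_accum)
{
    for(unsigned k = 0; k < K; ++k)
        for(unsigned m = 0; m < M; ++m)
            for(unsigned n = 0; n < N; ++n)
                f_accum(p_c[m * N + n], p_a[k * M + m] * p_b[k * N + n]);
}

With the c += ab lambda this is plain c_mtx += transpose(a_mtx) * b_mtx; swapping in a different functor (saturating add, max-accumulate, ...) would not require touching the loop nest.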
...
@@ -283,10 +296,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
#endif
                if(n_data < N && h_data < Ho && w_data < Wo)
                {
-#if 1
                    p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] =
                        p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)];
-#endif
                }
            }
        }
...
src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline.cuh
View file @ 6614729a
...
@@ -25,7 +25,9 @@ template <unsigned GridSize,
          unsigned InBlockCopyThreadPerDim0,
          unsigned InBlockCopyThreadPerDim1,
          unsigned WeiBlockCopyThreadPerDim0,
-         unsigned WeiBlockCopyThreadPerDim1>
+         unsigned WeiBlockCopyThreadPerDim1,
+         unsigned InBlockCopyDataPerRead,
+         unsigned WeiBlockCopyDataPerRead>
__global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline(InGlobalDesc,
                                                                                  Float* const __restrict__ p_in_global,
...
@@ -117,40 +119,52 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
    // formmat is [CPerBlock,BPerBlock + BGhostRead]
#if 0
    const auto blockwise_in_copy =
-       blockwise_2d_tensor_copy_1<BlockSize,
+       Blockwise2dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(in_cb_global_desc),
                                   decltype(in_cb_block_desc),
                                   decltype(in_cb_block_desc.GetLengths())>{};
+#elif 0
+   const auto blockwise_in_copy =
+       Blockwise2dTensorCopy2<BlockSize,
+                              Float,
+                              decltype(in_cb_global_desc),
+                              decltype(in_cb_block_desc),
+                              decltype(in_cb_block_desc.GetLengths()),
+                              InBlockCopyThreadPerDim0,
+                              InBlockCopyThreadPerDim1>{};
#elif 1
    const auto blockwise_in_copy =
-       blockwise_2d_tensor_copy_2<BlockSize,
+       Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(in_cb_global_desc),
                                   decltype(in_cb_block_desc),
                                   decltype(in_cb_block_desc.GetLengths()),
-                                  InBlockCopyThreadPerDim0,
-                                  InBlockCopyThreadPerDim1>{};
+                                  InBlockCopyDataPerRead>{};
#endif

    // blockwise wei copy
    // format is [CPerBlock*S*R,KPerBlock]
#if 0
    const auto blockwise_wei_copy =
-       blockwise_2d_tensor_copy_1<BlockSize,
+       Blockwise2dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(wei_ek_global_desc),
                                   decltype(wei_ek_block_desc),
                                   decltype(wei_ek_block_desc.GetLengths())>{};
+#elif 0
+   const auto blockwise_wei_copy =
+       Blockwise2dTensorCopy2<BlockSize,
+                              Float,
+                              decltype(wei_ek_global_desc),
+                              decltype(wei_ek_block_desc),
+                              decltype(wei_ek_block_desc.GetLengths()),
+                              WeiBlockCopyThreadPerDim0,
+                              WeiBlockCopyThreadPerDim1>{};
#elif 1
    const auto blockwise_wei_copy =
-       blockwise_2d_tensor_copy_2<BlockSize,
+       Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_ek_global_desc),
                                   decltype(wei_ek_block_desc),
                                   decltype(wei_ek_block_desc.GetLengths()),
-                                  WeiBlockCopyThreadPerDim0,
-                                  WeiBlockCopyThreadPerDim1>{};
+                                  WeiBlockCopyDataPerRead>{};
#endif
    // a series of blockwise GEMM
...
@@ -170,18 +184,17 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
    const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
        Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile

    const auto blockwise_gemm =
-       blockwise_gemm_block_a_block_b_thread_c<BlockSize,
+       BlockwiseGemmBlockABlockBThreadC<BlockSize,
                                         decltype(a_cxk_block_mtx_desc),
                                         decltype(b_cxb_block_mtx_desc),
                                         decltype(c_kxb_thread_mtx_desc),
                                         true,
                                         false,
                                         false,
                                         CPerThread,
                                         GemmThreadPerClusterRow,
                                         GemmThreadPerClusterColumn,
                                         true>{};

    // LDS
    constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace();
...
@@ -205,10 +218,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
    // prelog : preload data
    // input: global mem to LDS,
-   blockwise_in_copy.run(p_in_global_block_offset, p_in_block_0);
+   blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0);

    // weight: global mem to LDS,
-   blockwise_wei_copy.run(p_wei_global_block_offset, p_wei_block_0);
+   blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0);

    p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
...
@@ -234,10 +247,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
        Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0;

        // input: global mem to LDS,
-       blockwise_in_copy.run(p_in_global_block_offset, p_in_block_next);
+       blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);

        // weight: global mem to LDS,
-       blockwise_wei_copy.run(p_wei_global_block_offset, p_wei_block_next);
+       blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next);

        // a series of GEMM
        for(unsigned s = 0; s < S; ++s)
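This hunk shows the part that makes the _lds_pipeline variant a pipeline: the GEMMs below consume the *_block_now buffers while the two copies above are already filling the *_block_next buffers, and even_loop swaps the roles each trip around the outer loop. Stripped of everything convolution-specific, the ping-pong structure looks roughly like the sketch below; the helper functors and loop bounds are placeholders, not the kernel's actual code.

// Double-buffered LDS pipeline sketch: compute on one buffer while the next
// tile is copied into the other, then flip. FLoad stands in for the blockwise
// copies, FCompute for the blockwise-GEMM loop over (s, r).
template <class FLoad, class FCompute>
__device__ void lds_pipeline_sketch(
    float* lds_buf_0, float* lds_buf_1, unsigned num_tiles, FLoad load_tile_to, FCompute compute_from)
{
    // prologue: bring the first tile into buffer 0
    load_tile_to(lds_buf_0, 0u);
    __syncthreads();

    bool even_loop = true;

    for(unsigned tile = 0; tile + 1 < num_tiles; ++tile, even_loop = !even_loop)
    {
        float* p_now  = even_loop ? lds_buf_0 : lds_buf_1;
        float* p_next = even_loop ? lds_buf_1 : lds_buf_0;

        load_tile_to(p_next, tile + 1); // prefetch the next tile into the idle buffer
        compute_from(p_now);            // GEMMs read the tile loaded last iteration

        __syncthreads(); // p_next must be complete before it becomes "now"
    }

    // epilogue: the last tile sits in whichever buffer the loop left as "next"
    compute_from(even_loop ? lds_buf_0 : lds_buf_1);
}

The payoff is that the global-memory reads for tile i+1 overlap the math on tile i; the cost is a second copy of each LDS tile (the *_block_0 / *_block_1 pairs).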
...
@@ -246,7 +259,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
            {
                auto f_accum = [](auto& c, const auto&& ab) { c += ab; };

-               blockwise_gemm.run(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(s, r, 0, 0),
+               blockwise_gemm.Run(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(s, r, 0, 0),
                                   p_in_block_now + s * Wi + r,
                                   p_out_thread,
                                   f_accum);
...
@@ -268,7 +281,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
            {
                auto f_accum = [](auto& c, const auto&& ab) { c += ab; };

-               blockwise_gemm.run(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(s, r, 0, 0),
+               blockwise_gemm.Run(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(s, r, 0, 0),
                                   p_in_block_now + s * Wi + r,
                                   p_out_thread,
                                   f_accum);
...
src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh
View file @ 6614729a
...
@@ -110,30 +110,29 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
    // formmat is [CPerBlock,BPerBlock + BGhostRead]
#if 1
    const auto blockwise_in_copy =
-       blockwise_2d_tensor_copy_1<BlockSize,
+       Blockwise2dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(in_cb_global_desc),
                                   decltype(in_cb_block_desc),
                                   decltype(in_cb_block_desc.GetLengths())>{};
#elif 1
    const auto blockwise_in_copy =
-       blockwise_2d_tensor_copy_2<BlockSize,
+       Blockwise2dTensorCopy2<BlockSize,
                                   Float,
                                   decltype(in_cb_global_desc),
                                   decltype(in_cb_block_desc),
                                   decltype(in_cb_block_desc.GetLengths()),
                                   InBlockCopyThreadPerDim0,
                                   InBlockCopyThreadPerDim1>{};
#endif
    // blockwise wei copy
    // format is [S,R,CPerBlock,KPerBlock]
    const auto blockwise_wei_copy =
-       blockwise_4d_tensor_copy_1<BlockSize,
+       Blockwise4dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(wei_srck_global_desc),
                                   decltype(wei_srck_block_desc),
                                   decltype(wei_srck_block_desc.GetLengths())>{};

    // a series of blockwise GEMM
    // c_mtx += transpose(a_mtx) * b_mtx
...
@@ -152,18 +151,17 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
    const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
        Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile

    const auto blockwise_gemm =
-       blockwise_gemm_block_a_block_b_thread_c<BlockSize,
+       BlockwiseGemmBlockABlockBThreadC<BlockSize,
                                         decltype(a_cxk_block_mtx_desc),
                                         decltype(b_cxb_block_mtx_desc),
                                         decltype(c_kxb_thread_mtx_desc),
                                         true,
                                         false,
                                         false,
                                         CPerThread,
                                         GemmThreadPerClusterRow,
                                         GemmThreadPerClusterColumn,
                                         true>{};

    // LDS
    constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace();
...
@@ -191,12 +189,12 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
    {
#if 1
        // input: global mem to LDS,
-       blockwise_in_copy.run(p_in_global_block_offset, p_in_block);
+       blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
#endif

#if 1
        // weight: global mem to LDS,
-       blockwise_wei_copy.run(p_wei_global_block_offset, p_wei_block);
+       blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
#endif

        __syncthreads();
...
@@ -209,7 +207,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
            {
                auto f_accum = [](auto& c, const auto&& ab) { c += ab; };

-               blockwise_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
+               blockwise_gemm.Run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
                                   p_in_block + s * Wi + r,
                                   p_out_thread,
                                   f_accum);
...
src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.cuh
View file @ 6614729a
...
@@ -110,20 +110,19 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
    // formmat is [CPerBlock,BPerBlock + BGhostRead]
#if 1
    const auto blockwise_in_copy =
-       blockwise_2d_tensor_copy_1<BlockSize,
+       Blockwise2dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(in_cb_global_desc),
                                   decltype(in_cb_block_desc),
                                   decltype(in_cb_block_desc.GetLengths())>{};
#elif 1
    const auto blockwise_in_copy =
-       blockwise_2d_tensor_copy_2<BlockSize,
+       Blockwise2dTensorCopy2<BlockSize,
                                   Float,
                                   decltype(in_cb_global_desc),
                                   decltype(in_cb_block_desc),
                                   decltype(in_cb_block_desc.GetLengths()),
                                   InBlockCopyThreadPerDim0,
                                   InBlockCopyThreadPerDim1>{};
#elif 0
    const auto blockwise_in_copy =
        blockwise_2d_tensor_copy_dummy_2<BlockSize,
...
@@ -137,11 +136,11 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
    // format is [S,R,CPerBlock,KPerBlock]
#if 1
    const auto blockwise_wei_copy =
-       blockwise_4d_tensor_copy_1<BlockSize,
+       Blockwise4dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(wei_srck_global_desc),
                                   decltype(wei_srck_block_desc),
                                   decltype(wei_srck_block_desc.GetLengths())>{};
#else
    const auto blockwise_wei_copy =
        blockwise_4d_tensor_copy_dummy<BlockSize,
...
@@ -168,18 +167,17 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
    const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
        Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile

    const auto blockwise_gemm =
-       blockwise_gemm_block_a_block_b_thread_c<BlockSize,
+       BlockwiseGemmBlockABlockBThreadC<BlockSize,
                                         decltype(a_cxk_block_mtx_desc),
                                         decltype(b_cxb_block_mtx_desc),
                                         decltype(c_kxb_thread_mtx_desc),
                                         true,
                                         false,
                                         false,
                                         CPerThread,
                                         GemmRowThreadPerCluster,
                                         GemmColumnThreadPerCluster,
                                         true>{};

    // LDS
    constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace();
...
@@ -201,13 +199,13 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
    // prelog: load data
#if 1
    // input: global mem to LDS,
-   blockwise_in_copy.run(p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin),
-                         p_in_block_0);
+   blockwise_in_copy.Run(p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin),
+                         p_in_block_0);
#endif

#if 1
    // weight: global mem to LDS,
-   blockwise_wei_copy.run(
-       p_wei_global + wei_srck_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin), p_wei_block_0);
+   blockwise_wei_copy.Run(
+       p_wei_global + wei_srck_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin), p_wei_block_0);
#endif
...
@@ -227,14 +225,14 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
#if 1
        // preload next data
        // input: global mem to LDS,
-       blockwise_in_copy.run(p_in_global + in_cb_global_desc.Get1dIndex(
-                                 c_block_data_begin + CPerBlock, b_block_data_begin),
-                             p_in_block_next);
+       blockwise_in_copy.Run(p_in_global + in_cb_global_desc.Get1dIndex(
+                                 c_block_data_begin + CPerBlock, b_block_data_begin),
+                             p_in_block_next);
#endif

#if 1
        // weight: global mem to LDS,
-       blockwise_wei_copy.run(p_wei_global +
-                                  wei_srck_global_desc.Get1dIndex(
-                                      0, 0, c_block_data_begin + CPerBlock, k_block_data_begin),
-                              p_wei_block_next);
+       blockwise_wei_copy.Run(p_wei_global +
+                                  wei_srck_global_desc.Get1dIndex(
+                                      0, 0, c_block_data_begin + CPerBlock, k_block_data_begin),
+                              p_wei_block_next);
...
@@ -247,7 +245,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
            {
                auto f_accum = [](auto& c, const auto&& ab) { c += ab; };

-               blockwise_gemm.run(p_wei_block_now + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
+               blockwise_gemm.Run(p_wei_block_now + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
                                   p_in_block_now + s * Wi + r,
                                   p_out_thread,
                                   f_accum);
...
@@ -269,7 +267,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
            {
                auto f_accum = [](auto& c, const auto&& ab) { c += ab; };

-               blockwise_gemm.run(p_wei_block_now + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
+               blockwise_gemm.Run(p_wei_block_now + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
                                   p_in_block_now + s * Wi + r,
                                   p_out_thread,
                                   f_accum);
...