Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3439e4b5
"...composable_kernel-1.git" did not exist on "451f1e3d653ddfdfc09983e345a72791d5d935c3"
Commit
3439e4b5
authored
Jan 25, 2019
by
Chao Liu
Browse files
padding works (sort of), but code looks ugly. Tuned some resnet configs
parent
8bd6ea1a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
785 additions
and
20 deletions
+785
-20
driver/conv.cu
driver/conv.cu
+109
-18
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh
...plicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh
+245
-0
src/include/blockwise_4d_tensor_op.cuh
src/include/blockwise_4d_tensor_op.cuh
+127
-0
src/include/conv_common.cuh
src/include/conv_common.cuh
+39
-2
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh
...plicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh
+265
-0
No files found.
driver/conv.cu
View file @
3439e4b5
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh"
#include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh"
#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh"
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh"
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh"
#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
//#include "device_winograd_convolution.cuh"
//#include "device_winograd_convolution.cuh"
...
@@ -107,20 +108,31 @@ auto make_TensorDescriptor(TConstTensorDesc)
...
@@ -107,20 +108,31 @@ auto make_TensorDescriptor(TConstTensorDesc)
return
TensorDescriptor
(
lengths
,
strides
);
return
TensorDescriptor
(
lengths
,
strides
);
}
}
template
<
class
T
>
template
<
class
T
,
class
LowerPads
,
class
UpperPads
>
void
host_direct_convolution
(
const
Tensor
<
T
>&
in_nchw
,
const
Tensor
<
T
>&
wei_kcsr
,
Tensor
<
T
>&
out
)
void
host_direct_convolution
(
const
Tensor
<
T
>&
in_nchw
,
const
Tensor
<
T
>&
wei_kcsr
,
Tensor
<
T
>&
out
,
LowerPads
,
UpperPads
)
{
{
unsigned
h_pad_low
=
LowerPads
{}.
Get
(
Number
<
0
>
{});
unsigned
w_pad_low
=
LowerPads
{}.
Get
(
Number
<
1
>
{});
unsigned
h_pad_up
=
UpperPads
{}.
Get
(
Number
<
0
>
{});
unsigned
w_pad_up
=
UpperPads
{}.
Get
(
Number
<
1
>
{});
auto
f
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
auto
f
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
double
v
=
0
;
double
v
=
0
;
for
(
int
c
=
0
;
c
<
wei_kcsr
.
mDesc
.
GetLengths
()[
1
];
++
c
)
for
(
int
c
=
0
;
c
<
wei_kcsr
.
mDesc
.
GetLengths
()[
1
];
++
c
)
{
{
for
(
int
y
=
0
;
y
<
wei_kcsr
.
mDesc
.
GetLengths
()[
2
];
++
y
)
for
(
int
y
=
0
;
y
<
wei_kcsr
.
mDesc
.
GetLengths
()[
2
];
++
y
)
{
{
int
hi
=
ho
+
y
;
int
hi
=
ho
+
y
-
h_pad_low
;
for
(
int
x
=
0
;
x
<
wei_kcsr
.
mDesc
.
GetLengths
()[
3
];
++
x
)
for
(
int
x
=
0
;
x
<
wei_kcsr
.
mDesc
.
GetLengths
()[
3
];
++
x
)
{
{
int
wi
=
wo
+
x
;
int
wi
=
wo
+
x
-
w_pad_low
;
v
+=
in_nchw
(
n
,
c
,
hi
,
wi
)
*
wei_kcsr
(
k
,
c
,
y
,
x
);
if
(
hi
>=
0
&&
hi
<
in_nchw
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
wi
<
in_nchw
.
mDesc
.
GetLengths
()[
3
])
{
v
+=
in_nchw
(
n
,
c
,
hi
,
wi
)
*
wei_kcsr
(
k
,
c
,
y
,
x
);
}
}
}
}
}
}
}
...
@@ -136,10 +148,9 @@ void host_direct_convolution(const Tensor<T>& in_nchw, const Tensor<T>& wei_kcsr
...
@@ -136,10 +148,9 @@ void host_direct_convolution(const Tensor<T>& in_nchw, const Tensor<T>& wei_kcsr
f_par
(
std
::
thread
::
hardware_concurrency
());
f_par
(
std
::
thread
::
hardware_concurrency
());
}
}
template
<
class
T
>
template
<
class
T
,
class
LowerPads
,
class
UpperPads
>
void
host_winograd_3x3_convolution
(
const
Tensor
<
T
>&
in_nchw
,
void
host_winograd_3x3_convolution
(
const
Tensor
<
T
>&
wei_kcsr
,
const
Tensor
<
T
>&
in_nchw
,
const
Tensor
<
T
>&
wei_kcsr
,
Tensor
<
T
>&
out
,
LowerPads
,
UpperPads
)
Tensor
<
T
>&
out
)
{
{
constexpr
std
::
size_t
OutTileSizeH
=
2
;
constexpr
std
::
size_t
OutTileSizeH
=
2
;
constexpr
std
::
size_t
OutTileSizeW
=
2
;
constexpr
std
::
size_t
OutTileSizeW
=
2
;
...
@@ -156,6 +167,12 @@ void host_winograd_3x3_convolution(const Tensor<T>& in_nchw,
...
@@ -156,6 +167,12 @@ void host_winograd_3x3_convolution(const Tensor<T>& in_nchw,
std
::
size_t
HO
=
out
.
mDesc
.
GetLengths
()[
2
];
std
::
size_t
HO
=
out
.
mDesc
.
GetLengths
()[
2
];
std
::
size_t
WO
=
out
.
mDesc
.
GetLengths
()[
3
];
std
::
size_t
WO
=
out
.
mDesc
.
GetLengths
()[
3
];
unsigned
h_pad_low
=
LowerPads
{}.
Get
(
Number
<
0
>
{});
unsigned
w_pad_low
=
LowerPads
{}.
Get
(
Number
<
1
>
{});
unsigned
h_pad_up
=
UpperPads
{}.
Get
(
Number
<
0
>
{});
unsigned
w_pad_up
=
UpperPads
{}.
Get
(
Number
<
1
>
{});
std
::
size_t
InTileSizeH
=
OutTileSizeH
+
S
-
1
;
std
::
size_t
InTileSizeH
=
OutTileSizeH
+
S
-
1
;
std
::
size_t
InTileSizeW
=
OutTileSizeW
+
R
-
1
;
std
::
size_t
InTileSizeW
=
OutTileSizeW
+
R
-
1
;
...
@@ -171,11 +188,20 @@ void host_winograd_3x3_convolution(const Tensor<T>& in_nchw,
...
@@ -171,11 +188,20 @@ void host_winograd_3x3_convolution(const Tensor<T>& in_nchw,
auto
f_in_hold
=
[
&
](
auto
n
,
auto
c
,
auto
y
,
auto
x
)
{
auto
f_in_hold
=
[
&
](
auto
n
,
auto
c
,
auto
y
,
auto
x
)
{
for
(
int
j
=
0
;
j
<
InTileSizeH
;
++
j
)
for
(
int
j
=
0
;
j
<
InTileSizeH
;
++
j
)
{
{
std
::
size_
t
hi
=
OutTileSizeH
*
y
+
j
;
in
t
hi
=
OutTileSizeH
*
y
+
j
-
h_pad_low
;
for
(
int
i
=
0
;
i
<
InTileSizeW
;
++
i
)
for
(
int
i
=
0
;
i
<
InTileSizeW
;
++
i
)
{
{
std
::
size_t
wi
=
OutTileSizeW
*
x
+
i
;
int
wi
=
OutTileSizeW
*
x
+
i
-
w_pad_low
;
in_hold
(
n
,
c
,
y
,
x
,
j
,
i
)
=
in_nchw
(
n
,
c
,
hi
,
wi
);
if
(
hi
>=
0
&&
hi
<
in_nchw
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
wi
<
in_nchw
.
mDesc
.
GetLengths
()[
3
])
{
in_hold
(
n
,
c
,
y
,
x
,
j
,
i
)
=
in_nchw
(
n
,
c
,
hi
,
wi
);
}
else
{
in_hold
(
n
,
c
,
y
,
x
,
j
,
i
)
=
T
(
0
);
}
}
}
}
}
};
};
...
@@ -406,7 +432,7 @@ int main()
...
@@ -406,7 +432,7 @@ int main()
constexpr
unsigned
K
=
64
;
constexpr
unsigned
K
=
64
;
constexpr
unsigned
S
=
7
;
constexpr
unsigned
S
=
7
;
constexpr
unsigned
R
=
7
;
constexpr
unsigned
R
=
7
;
#elif
1
#elif
0
// 3x3, 58x58
// 3x3, 58x58
constexpr
unsigned
N
=
16
;
constexpr
unsigned
N
=
16
;
constexpr
unsigned
C
=
128
;
constexpr
unsigned
C
=
128
;
...
@@ -415,12 +441,63 @@ int main()
...
@@ -415,12 +441,63 @@ int main()
constexpr
unsigned
K
=
256
;
constexpr
unsigned
K
=
256
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
R
=
3
;
constexpr
unsigned
R
=
3
;
#elif 0
// 3x3 filter, 58x58 image, 0x0 padding
constexpr
unsigned
N
=
16
;
constexpr
unsigned
C
=
128
;
constexpr
unsigned
HI
=
58
;
constexpr
unsigned
WI
=
58
;
constexpr
unsigned
K
=
256
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
R
=
3
;
constexpr
unsigned
HPad
=
0
;
constexpr
unsigned
WPad
=
0
;
#elif 1
// 3x3 filter, 56x56 image, 1x1 padding
constexpr
unsigned
N
=
16
;
constexpr
unsigned
C
=
128
;
constexpr
unsigned
HI
=
56
;
constexpr
unsigned
WI
=
56
;
constexpr
unsigned
K
=
256
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
R
=
3
;
constexpr
unsigned
HPad
=
1
;
constexpr
unsigned
WPad
=
1
;
#elif 0
// 3x3 filter, 28x28 image, 1x1 padding
constexpr
unsigned
N
=
16
;
constexpr
unsigned
C
=
256
;
constexpr
unsigned
HI
=
28
;
constexpr
unsigned
WI
=
28
;
constexpr
unsigned
K
=
512
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
R
=
3
;
constexpr
unsigned
HPad
=
1
;
constexpr
unsigned
WPad
=
1
;
#elif 0
// 3x3 filter, 20x84 image, 1x1 padding
constexpr
unsigned
N
=
16
;
constexpr
unsigned
C
=
256
;
constexpr
unsigned
HI
=
20
;
constexpr
unsigned
WI
=
84
;
constexpr
unsigned
K
=
256
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
R
=
3
;
constexpr
unsigned
HPad
=
1
;
constexpr
unsigned
WPad
=
1
;
#endif
#endif
auto
lower_pads
=
Sequence
<
HPad
,
WPad
>
{};
auto
upper_pads
=
Sequence
<
HPad
,
WPad
>
{};
auto
in_nchw_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
N
,
C
,
HI
,
WI
>
{});
auto
in_nchw_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
N
,
C
,
HI
,
WI
>
{});
auto
wei_kcsr_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
,
C
,
S
,
R
>
{});
auto
wei_kcsr_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
,
C
,
S
,
R
>
{});
auto
out_nkhw_desc
=
auto
out_nkhw_desc
=
get_convolution_with_padding_output_default_4d_tensor_descriptor
(
get_convolution_output_default_4d_tensor_descriptor
(
in_nchw_desc
,
wei_kcsr_desc
);
in_nchw_desc
,
wei_kcsr_desc
,
lower_pads
,
upper_pads
);
ostream_ConstantTensorDescriptor
(
in_nchw_desc
,
std
::
cout
<<
"in_nchw_desc: "
);
ostream_ConstantTensorDescriptor
(
in_nchw_desc
,
std
::
cout
<<
"in_nchw_desc: "
);
ostream_ConstantTensorDescriptor
(
wei_kcsr_desc
,
std
::
cout
<<
"wei_kcsr_desc: "
);
ostream_ConstantTensorDescriptor
(
wei_kcsr_desc
,
std
::
cout
<<
"wei_kcsr_desc: "
);
...
@@ -443,6 +520,7 @@ int main()
...
@@ -443,6 +520,7 @@ int main()
unsigned
nrepeat
=
50
;
unsigned
nrepeat
=
50
;
#if 0
#if 0
#if 0
device_direct_convolution_1
device_direct_convolution_1
#elif 0
#elif 0
...
@@ -451,7 +529,7 @@ int main()
...
@@ -451,7 +529,7 @@ int main()
device_implicit_gemm_convolution_1_nchw_kcsr
device_implicit_gemm_convolution_1_nchw_kcsr
#elif 0
#elif 0
device_implicit_gemm_convolution_1_nchw_srck_nkhw
device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif
1
#elif
0
device_implicit_gemm_convolution_1_chwn_csrk_khwn
device_implicit_gemm_convolution_1_chwn_csrk_khwn
#elif 0
#elif 0
device_implicit_gemm_convolution_2_cnhw_srck_knhw
device_implicit_gemm_convolution_2_cnhw_srck_knhw
...
@@ -459,15 +537,28 @@ int main()
...
@@ -459,15 +537,28 @@ int main()
device_winograd_convolution
device_winograd_convolution
#endif
#endif
(
in_nchw_desc
,
in_nchw
,
wei_kcsr_desc
,
wei_kcsr
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
(
in_nchw_desc
,
in_nchw
,
wei_kcsr_desc
,
wei_kcsr
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
#endif
#if 1
device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding
(
in_nchw_desc
,
in_nchw
,
wei_kcsr_desc
,
wei_kcsr
,
out_nkhw_desc
,
out_nkhw_device
,
lower_pads
,
upper_pads
,
nrepeat
);
#endif
#if 1
#if 1
if
(
S
==
3
&&
R
==
3
)
if
(
S
==
3
&&
R
==
3
)
{
{
host_winograd_3x3_convolution
(
in_nchw
,
wei_kcsr
,
out_nkhw_host
);
host_winograd_3x3_convolution
(
in_nchw
,
wei_kcsr
,
out_nkhw_host
,
lower_pads
,
upper_pads
);
}
}
else
else
{
{
host_direct_convolution
(
in_nchw
,
wei_kcsr
,
out_nkhw_host
);
host_direct_convolution
(
in_nchw
,
wei_kcsr
,
out_nkhw_host
,
lower_pads
,
upper_pads
);
}
}
check_error
(
out_nkhw_host
,
out_nkhw_device
);
check_error
(
out_nkhw_host
,
out_nkhw_device
);
#endif
#endif
...
...
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh
0 → 100644
View file @
3439e4b5
#pragma once
#include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh"
#include <unistd.h>
// Host-side driver for the padded implicit-GEMM convolution kernel.
//
// Steps performed (all visible below):
//   1. reorder the weight tensor  [K, C, S, R]  -> [C, S, R, K] on the host
//   2. reorder the input tensor   [N, C, Hi, Wi] -> [C, Hi, Wi, N] on the host
//   3. upload input / weight / output buffers to the device
//   4. launch the gridwise kernel `nrepeat` times, printing the elapsed time of each launch
//   5. download the [K, Ho, Wo, N] result and reorder it into `out_nkhw` ([N, K, Ho, Wo])
//
// Parameters:
//   InDesc / WeiDesc / OutDesc - compile-time descriptors of in_nchw / wei_kcsr / out_nkhw
//   in_nchw, wei_kcsr          - host input and weight tensors
//   out_nkhw                   - host output tensor, overwritten with the device result
//   LowerPads / UpperPads      - compile-time (H, W) padding, forwarded to the kernel
//   nrepeat                    - number of timed kernel launches
template <class T, class InDesc, class WeiDesc, class OutDesc, class LowerPads, class UpperPads>
void device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding(InDesc,
                                                                    const Tensor<T>& in_nchw,
                                                                    WeiDesc,
                                                                    const Tensor<T>& wei_kcsr,
                                                                    OutDesc,
                                                                    Tensor<T>& out_nkhw,
                                                                    LowerPads,
                                                                    UpperPads,
                                                                    unsigned nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcsr_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
    constexpr unsigned Wi = in_nchw_desc.GetLength(I3);

    constexpr unsigned N  = out_nkhw_desc.GetLength(I0);
    constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
    constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);

    constexpr unsigned K = wei_kcsr_desc.GetLength(I0);
    constexpr unsigned C = wei_kcsr_desc.GetLength(I1);
    constexpr unsigned S = wei_kcsr_desc.GetLength(I2);
    constexpr unsigned R = wei_kcsr_desc.GetLength(I3);

    // reorder weight: [K, C, S, R] -> [C, S, R, K]
    auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence<C, S, R, K>{});
    ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: ");

    Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));

    auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) {
        wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r);
    };

    make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, S, R)(
        std::thread::hardware_concurrency());

    // reorder input: [N, C, Hi, Wi] -> [C, Hi, Wi, N]
    auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
    ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");

    Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));

    auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
        in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
    };

    make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
        std::thread::hardware_concurrency());

    // output, stored as [K, Ho, Wo, N] on the device
    auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
    ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");

    Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));

    std::size_t data_sz = sizeof(T);
    DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
    DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace());
    DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());

    in_chwn_device_buf.ToDevice(in_chwn.mData.data());
    wei_csrk_device_buf.ToDevice(wei_csrk.mData.data());
    out_khwn_device_buf.ToDevice(out_khwn.mData.data());

    // Tuning configurations; exactly one branch must be enabled.
#if 0
    constexpr unsigned NPerBlock = 1;
    constexpr unsigned KPerBlock = 1;
    constexpr unsigned CPerBlock = 1;
    constexpr unsigned HoPerBlock = 2;
    constexpr unsigned WoPerBlock = 4;

    constexpr unsigned NPerThread = 1;
    constexpr unsigned KPerThread = 1;
    constexpr unsigned CPerThread = 1;
    constexpr unsigned HoPerThread = 1;
    constexpr unsigned WoPerThread = 1;

    constexpr unsigned BlockSize = 8;
#elif 0
    // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
    constexpr unsigned NPerBlock  = 16;
    constexpr unsigned KPerBlock  = 64;
    constexpr unsigned CPerBlock  = 4;
    constexpr unsigned HoPerBlock = 2;
    constexpr unsigned WoPerBlock = 4;

    constexpr unsigned NPerThread  = 4;
    constexpr unsigned KPerThread  = 16;
    constexpr unsigned CPerThread  = 1;
    constexpr unsigned HoPerThread = 1;
    constexpr unsigned WoPerThread = 1;

    constexpr unsigned BlockSize = 128;
#elif 0
    // 3x3 58x58, NKC = 16,256,128
    constexpr unsigned NPerBlock  = 8;
    constexpr unsigned KPerBlock  = 64;
    constexpr unsigned CPerBlock  = 2;
    constexpr unsigned HoPerBlock = 4;
    constexpr unsigned WoPerBlock = 4;

    constexpr unsigned NPerThread  = 4;
    constexpr unsigned KPerThread  = 16;
    constexpr unsigned CPerThread  = 1;
    constexpr unsigned HoPerThread = 1;
    constexpr unsigned WoPerThread = 1;

    constexpr unsigned BlockSize = 128;
#elif 0
    // for 5x5, 36x36
    constexpr unsigned NPerBlock  = 16;
    constexpr unsigned KPerBlock  = 64;
    constexpr unsigned CPerBlock  = 2;
    constexpr unsigned HoPerBlock = 2;
    constexpr unsigned WoPerBlock = 4;

    constexpr unsigned NPerThread  = 4;
    constexpr unsigned KPerThread  = 16;
    constexpr unsigned CPerThread  = 1;
    constexpr unsigned HoPerThread = 1;
    constexpr unsigned WoPerThread = 1;

    constexpr unsigned BlockSize = 128;
#elif 0
    // for 7x7, 38x38
    constexpr unsigned NPerBlock  = 8;
    constexpr unsigned KPerBlock  = 64;
    constexpr unsigned CPerBlock  = 2;
    constexpr unsigned HoPerBlock = 4;
    constexpr unsigned WoPerBlock = 4;

    constexpr unsigned NPerThread  = 4;
    constexpr unsigned KPerThread  = 16;
    constexpr unsigned CPerThread  = 1;
    constexpr unsigned HoPerThread = 1;
    constexpr unsigned WoPerThread = 1;

    constexpr unsigned BlockSize = 128;
#elif 0
    // for 3x3, 56x56
    constexpr unsigned NPerBlock  = 32;
    constexpr unsigned KPerBlock  = 64;
    constexpr unsigned CPerBlock  = 4;
    constexpr unsigned HoPerBlock = 2;
    constexpr unsigned WoPerBlock = 2;

    constexpr unsigned NPerThread  = 4;
    constexpr unsigned KPerThread  = 16;
    constexpr unsigned CPerThread  = 1;
    constexpr unsigned HoPerThread = 1;
    constexpr unsigned WoPerThread = 1;

    constexpr unsigned BlockSize = 128;
#elif 1
    // 3x3 56x56, NKC = 16,256,128, with padding
    // 3x3 28x28, NKC = 16,512,256, with padding
    // 3x3 20x84, NKC = 16,256,256, with padding
    constexpr unsigned NPerBlock  = 16;
    constexpr unsigned KPerBlock  = 64;
    constexpr unsigned CPerBlock  = 2;
    constexpr unsigned HoPerBlock = 2;
    constexpr unsigned WoPerBlock = 4;

    constexpr unsigned NPerThread  = 4;
    constexpr unsigned KPerThread  = 16;
    constexpr unsigned CPerThread  = 1;
    constexpr unsigned HoPerThread = 1;
    constexpr unsigned WoPerThread = 1;

    constexpr unsigned BlockSize = 128;
#endif

    // one thread block per [N, K, Ho, Wo] tile (ceil-division in every dimension)
    constexpr unsigned GridSize =
        ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
        ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);

    dim3 block_dim(BlockSize);
    dim3 grid_dim(GridSize);

    printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

    // The timing events are created once and destroyed after the loop; creating a fresh pair
    // on every iteration without destroying them leaked 2 * nrepeat events.
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    for(unsigned i = 0; i < nrepeat; ++i)
    {
        float elapsedTime = 0;

        cudaEventRecord(start, 0);

        gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding<GridSize,
                                                                         BlockSize,
                                                                         T,
                                                                         decltype(in_chwn_desc),
                                                                         decltype(wei_csrk_desc),
                                                                         decltype(out_khwn_desc),
                                                                         LowerPads,
                                                                         UpperPads,
                                                                         NPerBlock,
                                                                         KPerBlock,
                                                                         CPerBlock,
                                                                         HoPerBlock,
                                                                         WoPerBlock,
                                                                         NPerThread,
                                                                         KPerThread,
                                                                         CPerThread,
                                                                         HoPerThread,
                                                                         WoPerThread>
            <<<grid_dim, block_dim>>>(static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
                                      static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer()),
                                      static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));

        cudaEventRecord(stop, 0);
        cudaEventSynchronize(stop);

        cudaEventElapsedTime(&elapsedTime, start, stop);
        printf("Elapsed time : %f ms\n", elapsedTime);

        // pause between launches (10 ms)
        usleep(10000);
    }

    cudaEventDestroy(start);
    cudaEventDestroy(stop);

    checkCudaErrors(cudaGetLastError());
    out_khwn_device_buf.FromDevice(out_khwn.mData.data());

    // reorder output: [K, Ho, Wo, N] -> [N, K, Ho, Wo]
    auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
        out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
    };

    make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
        std::thread::hardware_concurrency());
}
src/include/blockwise_4d_tensor_op.cuh
View file @
3439e4b5
...
@@ -211,6 +211,133 @@ struct blockwise_4d_tensor_copy_1
...
@@ -211,6 +211,133 @@ struct blockwise_4d_tensor_copy_1
}
}
};
};
// Blockwise copy of a [C, H, W, N] tile from a source tensor into a destination tile,
// writing Float(0) for the elements that fall into the padding region of this block.
//
// Template parameters:
//   BlockSize       - number of threads cooperating on the copy (indexing uses threadIdx.x)
//   SrcDesc/DstDesc - compile-time descriptors of the source tensor and destination tile
//   DstOpLengths    - lengths [C, H, W, N] of the tile being written
//   GlobalLowerPads - compile-time (H, W) lower padding of the whole tensor
template <unsigned BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class DstOpLengths,
          class GlobalLowerPads>
struct blockwise_chwn_tensor_copy_with_padding
{
    // p_src                 - base pointer of the source tensor
    // *_block_data_begin    - origin of this block's tile; the H/W origins are given in
    //                         padded coordinates (the global lower pad is subtracted here)
    // p_dst                 - destination tile
    // h/w_block_pad_low/up  - number of padded rows/columns at each border of this tile
    __device__ void run(Float* const __restrict__ p_src,
                        unsigned c_block_data_begin,
                        unsigned ho_block_data_begin,
                        unsigned wo_block_data_begin,
                        unsigned n_block_data_begin,
                        Float* __restrict__ p_dst,
                        unsigned h_block_pad_low,
                        unsigned w_block_pad_low,
                        unsigned h_block_pad_up,
                        unsigned w_block_pad_up) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};
        constexpr auto ref_desc = make_ConstantTensorDescriptor(DstOpLengths{});

        constexpr auto h_global_pad_low = GlobalLowerPads{}.Get(I0);
        constexpr auto w_global_pad_low = GlobalLowerPads{}.Get(I1);

        constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;

        // Points at the first *valid* (non-padded) source element of this tile. The block pad
        // is added before the global pad is subtracted so the unsigned index cannot wrap on
        // the low border.
        Float* const p_src_tmp =
            p_src + src_desc.Get1dIndex(c_block_data_begin,
                                        (ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
                                        (wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
                                        n_block_data_begin);

#if 0
        if(get_thread_local_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(src_desc, "src_desc: ");
            print_ConstantTensorDescriptor(dst_desc, "dst_desc: ");
            print_ConstantTensorDescriptor(ref_desc, "ref_desc: ");

            printf("%u %u, \t"
                   "h_global_pad_low %u w_global_pad_low %u \t"
                   "h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u  w_block_pad_up %u \t"
                   "\n",
                   get_block_1d_id(),
                   get_thread_local_1d_id(),
                   h_global_pad_low,
                   w_global_pad_low,
                   h_block_pad_low,
                   w_block_pad_low,
                   h_block_pad_up,
                   w_block_pad_up);
        }
#endif

        // Copies (or zero-fills) the tile element with linear index `is` in ref_desc.
        auto copy_element = [&](unsigned is) {
            unsigned did[4];

            // unpack the linear index into [C, H, W, N] coordinates
            did[0] = is / ref_desc.GetStride(I0);

            is -= did[0] * ref_desc.GetStride(I0);

            did[1] = is / ref_desc.GetStride(I1);

            is -= did[1] * ref_desc.GetStride(I1);

            did[2] = is / ref_desc.GetStride(I2);

            is -= did[2] * ref_desc.GetStride(I2);

            did[3] = is / ref_desc.GetStride(I3);

            const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);

            const bool in_padding =
                did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
                did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2);

            // p_src_tmp already addresses tile row h_block_pad_low / column w_block_pad_low,
            // so the block pad has to be removed from the tile coordinate before indexing.
            // Using did[1] / did[2] directly shifted the data of low-border blocks by the pad.
            // The subtraction is only evaluated when did[] >= pad, so it cannot wrap.
            p_dst[bindex] = in_padding
                                ? Float(0)
                                : p_src_tmp[src_desc.Get1dIndex(did[0],
                                                                did[1] - h_block_pad_low,
                                                                did[2] - w_block_pad_low,
                                                                did[3])];
        };

        // full passes: every thread has an element
        for(unsigned iloop = 0; iloop < NLoop; ++iloop)
        {
            copy_element(threadIdx.x + iloop * BlockSize);
        }

        // tail pass: only the threads that still map to an element take part
        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            const unsigned is = threadIdx.x + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                copy_element(is);
            }
        }
    }
};
template
<
unsigned
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
SrcOpLengths
>
template
<
unsigned
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
SrcOpLengths
>
struct
blockwise_4d_tensor_copy_dummy
struct
blockwise_4d_tensor_copy_dummy
{
{
...
...
src/include/conv_common.cuh
View file @
3439e4b5
...
@@ -27,8 +27,45 @@ __host__ __device__ constexpr auto get_convolution_output_default_4d_tensor_desc
...
@@ -27,8 +27,45 @@ __host__ __device__ constexpr auto get_convolution_output_default_4d_tensor_desc
constexpr
auto
S
=
wei_desc
.
GetLength
(
I2
);
constexpr
auto
S
=
wei_desc
.
GetLength
(
I2
);
constexpr
auto
R
=
wei_desc
.
GetLength
(
I3
);
constexpr
auto
R
=
wei_desc
.
GetLength
(
I3
);
constexpr
auto
HO
=
HI
-
S
+
1
;
constexpr
auto
HO
=
HI
+
1
-
S
;
constexpr
auto
WO
=
WI
-
R
+
1
;
constexpr
auto
WO
=
WI
+
1
-
R
;
return
make_ConstantTensorDescriptor
(
Sequence
<
N
,
K
,
HO
,
WO
>
{});
}
// Builds the default (packed) 4-d output descriptor [N, K, HO, WO] of a convolution with
// explicit lower/upper padding in H and W:
//   HO = HI + HPadLow + HPadUp + 1 - S
//   WO = WI + WPadLow + WPadUp + 1 - R
// where the input descriptor is [N, C, HI, WI] and the weight descriptor is [K, C, S, R].
// All arguments are tag objects; only their types are used.
template <class InDesc, class WeiDesc, class LowerPads, class UpperPads>
__host__ __device__ constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(
    InDesc, WeiDesc, LowerPads, UpperPads)
{
    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    static_assert(in_desc.GetDimension() == 4, "input nDim is not 4");
    static_assert(wei_desc.GetDimension() == 4, "weight nDim is not 4");
    static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
                  "input & weight dimension not consistent");

    constexpr auto N  = in_desc.GetLength(I0);
    constexpr auto HI = in_desc.GetLength(I2);
    constexpr auto WI = in_desc.GetLength(I3);

    constexpr auto K = wei_desc.GetLength(I0);
    constexpr auto S = wei_desc.GetLength(I2);
    constexpr auto R = wei_desc.GetLength(I3);

    constexpr auto HPadLow = LowerPads{}.Get(I0);
    constexpr auto WPadLow = LowerPads{}.Get(I1);

    constexpr auto HPadUp = UpperPads{}.Get(I0);
    constexpr auto WPadUp = UpperPads{}.Get(I1);

    // The filter must fit into the padded input, otherwise the output size below is not
    // a valid length.
    // NOTE(review): lengths look unsigned elsewhere in this project, in which case the
    // subtraction would silently wrap instead of going negative - confirm.
    static_assert(HI + HPadLow + HPadUp >= S, "wrong! filter is taller than the padded input");
    static_assert(WI + WPadLow + WPadUp >= R, "wrong! filter is wider than the padded input");

    constexpr auto HO = HI + HPadLow + HPadUp + 1 - S;
    constexpr auto WO = WI + WPadLow + WPadUp + 1 - R;

    return make_ConstantTensorDescriptor(Sequence<N, K, HO, WO>{});
}
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh
0 → 100644
View file @
3439e4b5
#pragma once
#include "common.cuh"
#include "ConstantTensorDescriptor.cuh"
#include "ConstantMatrixDescriptor.cuh"
#include "blockwise_4d_tensor_op.cuh"
#include "threadwise_4d_tensor_op.cuh"
#include "gemm.cuh"
// Gridwise implicit-GEMM convolution kernel with zero padding.
//
// Tensor layouts, as read from the descriptors below:
//   p_in_global  : [C, Hi, Wi, N]   (InGlobalDesc)
//   p_wei_global : [C, S, R, K]     (WeiGlobalDesc)
//   p_out_global : [K, Ho, Wo, N]   (OutGlobalDesc)
//
// LowerPads / UpperPads carry the (H, W) padding at compile time.
// Each thread block computes one [KPerBlock, HoPerBlock, WoPerBlock, NPerBlock] output tile:
// it loops over C in steps of CPerBlock, stages the input and weight tiles in __shared__
// memory, and for every filter tap (s, r) runs a blockwise batched GEMM that accumulates into
// a per-thread output array, which is finally written to p_out_global.
template <unsigned GridSize,
          unsigned BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          class LowerPads,
          class UpperPads,
          unsigned NPerBlock,
          unsigned KPerBlock,
          unsigned CPerBlock,
          unsigned HoPerBlock,
          unsigned WoPerBlock,
          unsigned NPerThread,
          unsigned KPerThread,
          unsigned CPerThread,
          unsigned HoPerThread,
          unsigned WoPerThread>
__global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding(
    Float* const __restrict__ p_in_global,
    Float* const __restrict__ p_wei_global,
    Float* __restrict__ p_out_global)
{
    // The asserts below require NPerBlock to be a multiple of NPerThread, and allow
    // NPerThread < NPerBlock only when WoPerThread == 1; otherwise NPerThread must equal
    // NPerBlock.
    // Original author's rationale: the input in LDS is [C,Hi,Wi,N]; for the GEMM
    // trans([C,K]) * [C,Wo*N] a thread needs to do all the "N". If [C,Hi,N,Wi,N] were used
    // in LDS, NPerThread could differ from NPerBlock.
    static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0");
    static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock,
                  "wrong!");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_chwn_global_desc  = InGlobalDesc{};
    constexpr auto wei_csrk_global_desc = WeiGlobalDesc{};
    constexpr auto out_khwn_global_desc = OutGlobalDesc{};

    constexpr unsigned C = in_chwn_global_desc.GetLength(I0);

    constexpr unsigned K  = out_khwn_global_desc.GetLength(I0);
    constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1);
    constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2);
    constexpr unsigned N  = out_khwn_global_desc.GetLength(I3);

    constexpr unsigned S = wei_csrk_global_desc.GetLength(I1);
    constexpr unsigned R = wei_csrk_global_desc.GetLength(I2);

    constexpr unsigned HPadLow = LowerPads{}.Get(I0);
    constexpr unsigned WPadLow = LowerPads{}.Get(I1);

    constexpr unsigned HPadUp = UpperPads{}.Get(I0);
    constexpr unsigned WPadUp = UpperPads{}.Get(I1);

    // input tile size needed to produce a HoPerBlock x WoPerBlock output tile
    constexpr unsigned HiPerBlock = HoPerBlock + S - 1;
    constexpr unsigned WiPerBlock = WoPerBlock + R - 1;

    // divide block work: [K, Ho, Wo, N]
    constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
    constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
    constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
    constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock;

    // decompose the 1-d block id into (k, h, w, n) tile ids, n varying fastest
    const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
    unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
    const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork);
    itmp -= h_block_work_id * (WBlockWork * NBlockWork);
    const unsigned w_block_work_id = itmp / NBlockWork;
    const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork;

    // origin of this block's output tile
    const unsigned k_block_data_begin  = k_block_work_id * KPerBlock;
    const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
    const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
    const unsigned n_block_data_begin  = n_block_work_id * NPerBlock;

    // tensor view of blockwise input and weight in LDS
    constexpr auto in_chwn_block_desc =
        make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});

    constexpr auto wei_csrk_block_desc =
        make_ConstantTensorDescriptor(Sequence<CPerBlock, S, R, KPerBlock>{});

    // tensor view of threadwise output in register
    constexpr auto out_hkwn_thread_desc =
        make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});

    // disabled debug output
#if 0
    if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
        print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc");
        print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
    }
#endif

    // blockwise copy
    // input: format is [C, Hi, Wi, N]
    // Only tiles on the border of the H/W block grid carry padding: the first tile gets the
    // lower pad, the last tile gets the upper pad.
    const unsigned h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0;
    const unsigned w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0;

    const unsigned h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0;
    const unsigned w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0;

    // disabled debug output
    // NOTE(review): if this is ever enabled, the stray ';' after the `if` makes the printf
    // block run on every thread, not only on thread 0.
#if 0
    if(get_thread_local_1d_id() == 0)
        ;
    {
        printf(
            "%u %u, h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u\n",
            get_block_1d_id(),
            get_thread_local_1d_id(),
            h_block_pad_low,
            w_block_pad_low,
            h_block_pad_up,
            w_block_pad_up);
    }
#endif

    constexpr auto blockwise_in_copy =
        blockwise_chwn_tensor_copy_with_padding<BlockSize,
                                                Float,
                                                decltype(in_chwn_global_desc),
                                                decltype(in_chwn_block_desc),
                                                decltype(in_chwn_block_desc.GetLengths()),
                                                LowerPads>{};

    // weight: format is [C,S,R,K]
    constexpr auto blockwise_wei_copy =
        blockwise_4d_tensor_copy_1<BlockSize,
                                   Float,
                                   decltype(wei_csrk_global_desc),
                                   decltype(wei_csrk_block_desc),
                                   decltype(wei_csrk_block_desc.GetLengths())>{};

    // a series of blockwise batched GEMM
    // C_matrix += transpose(A_matrix) * B_matrix
    //   A_matrix and B_matrix saved in LDS, C_matrix saved in register
    //   A_matrix[C,K] is a sub-matrix of wei_block[C,S,R,K]
    //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
    //   C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N]
    const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{},
        Number<KPerBlock>{},
        Number<wei_csrk_block_desc.GetStride(I0)>{}); // constexpr doesn't compile

    const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{},
        Number<WoPerBlock * NPerBlock>{},
        Number<in_chwn_block_desc.GetStride(I0)>{}); // constexpr doesn't compile

    const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor(
        Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile

    // NOTE(review): the meaning of the bool / 0 template arguments is defined by the GEMM
    // helper, which is not visible here; the two strides passed are the H stride of the
    // input tile and the Ho stride of the thread output.
    const auto blockwise_batch_gemm =
        blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<
            BlockSize,
            decltype(a_cxk_block_mtx_desc),
            decltype(b_cxwn_block_mtx_desc),
            decltype(c_kxwn_thread_mtx_desc),
            true,
            false,
            false,
            0,
            in_chwn_block_desc.GetStride(I1),
            out_hkwn_thread_desc.GetStride(I0),
            HoPerBlock,
            HoPerThread,
            CPerThread,
            true>{};

    // LDS
    constexpr unsigned in_block_size  = in_chwn_block_desc.GetElementSpace();
    constexpr unsigned wei_block_size = wei_csrk_block_desc.GetElementSpace();

    __shared__ Float p_in_block[in_block_size];
    __shared__ Float p_wei_block[wei_block_size];

    // register
    Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];

    // set threadwise output tensor to 0
    threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);

    // Loop over the input channels in chunks of CPerBlock. The __syncthreads() in the
    // increment expression runs after each iteration, i.e. before the shared tiles are
    // overwritten by the next one.
    for(unsigned c_block_data_begin = 0; c_block_data_begin < C;
        c_block_data_begin += CPerBlock, __syncthreads())
    {
#if 1
        // input: global mem to LDS,
        blockwise_in_copy.run(p_in_global,
                              c_block_data_begin,
                              ho_block_data_begin,
                              wo_block_data_begin,
                              n_block_data_begin,
                              p_in_block,
                              h_block_pad_low,
                              w_block_pad_low,
                              h_block_pad_up,
                              w_block_pad_up);
#endif

#if 1
        // weight: global mem to LDS,
        blockwise_wei_copy.run(
            p_wei_global +
                wei_csrk_global_desc.Get1dIndex(c_block_data_begin, 0, 0, k_block_data_begin),
            p_wei_block);
#endif

        // both tiles must be fully written before any thread reads them
        __syncthreads();

        // a series of batched GEMM
        for(unsigned s = 0; s < S; ++s)
        {
            for(unsigned r = 0; r < R; ++r)
            {
                auto f_accum = [](auto& acc, const auto&& v) { acc += v; };

                // the (s, r) filter tap selects the weight slice and shifts the input window
                blockwise_batch_gemm.run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
                                         p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
                                         p_out_thread,
                                         f_accum);
            }
        }
    }

    // position of this thread's results inside the block tile
    const auto matrix_c_index =
        blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());

    const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
    const unsigned k_thread_data_begin  = matrix_c_index.row_begin;
    // the GEMM column index is split into (wo, n) with n varying fastest
    const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
    const unsigned n_thread_data_begin =
        matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;

    // disabled debug output
#if 0
    printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
        get_block_1d_id(), get_thread_local_1d_id(),
        ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin,
        ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin,
        p_out_thread[0]);
#endif

    // output: register to global mem,
    // convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
    constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{};

    threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
        out_hkwn_thread_desc,
        p_out_thread,
        out_khwn_global_desc,
        p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
                                                       ho_block_data_begin + ho_thread_data_begin,
                                                       wo_block_data_begin + wo_thread_data_begin,
                                                       n_block_data_begin + n_thread_data_begin),
        out_hkwn_thread_desc.GetLengths(),
        reorder_khwn_from_hkwn);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment