Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
3439e4b5
Commit
3439e4b5
authored
Jan 25, 2019
by
Chao Liu
Browse files
padding works (sort of), but code looks ugly. Tuned some resnet configs
parent
8bd6ea1a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
785 additions
and
20 deletions
+785
-20
driver/conv.cu
driver/conv.cu
+109
-18
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh
...plicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh
+245
-0
src/include/blockwise_4d_tensor_op.cuh
src/include/blockwise_4d_tensor_op.cuh
+127
-0
src/include/conv_common.cuh
src/include/conv_common.cuh
+39
-2
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh
...plicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh
+265
-0
No files found.
driver/conv.cu
View file @
3439e4b5
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh"
#include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh"
#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh"
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh"
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh"
#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
//#include "device_winograd_convolution.cuh"
//#include "device_winograd_convolution.cuh"
...
@@ -107,20 +108,31 @@ auto make_TensorDescriptor(TConstTensorDesc)
...
@@ -107,20 +108,31 @@ auto make_TensorDescriptor(TConstTensorDesc)
return
TensorDescriptor
(
lengths
,
strides
);
return
TensorDescriptor
(
lengths
,
strides
);
}
}
template
<
class
T
>
template
<
class
T
,
class
LowerPads
,
class
UpperPads
>
void
host_direct_convolution
(
const
Tensor
<
T
>&
in_nchw
,
const
Tensor
<
T
>&
wei_kcsr
,
Tensor
<
T
>&
out
)
void
host_direct_convolution
(
const
Tensor
<
T
>&
in_nchw
,
const
Tensor
<
T
>&
wei_kcsr
,
Tensor
<
T
>&
out
,
LowerPads
,
UpperPads
)
{
{
unsigned
h_pad_low
=
LowerPads
{}.
Get
(
Number
<
0
>
{});
unsigned
w_pad_low
=
LowerPads
{}.
Get
(
Number
<
1
>
{});
unsigned
h_pad_up
=
UpperPads
{}.
Get
(
Number
<
0
>
{});
unsigned
w_pad_up
=
UpperPads
{}.
Get
(
Number
<
1
>
{});
auto
f
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
auto
f
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
double
v
=
0
;
double
v
=
0
;
for
(
int
c
=
0
;
c
<
wei_kcsr
.
mDesc
.
GetLengths
()[
1
];
++
c
)
for
(
int
c
=
0
;
c
<
wei_kcsr
.
mDesc
.
GetLengths
()[
1
];
++
c
)
{
{
for
(
int
y
=
0
;
y
<
wei_kcsr
.
mDesc
.
GetLengths
()[
2
];
++
y
)
for
(
int
y
=
0
;
y
<
wei_kcsr
.
mDesc
.
GetLengths
()[
2
];
++
y
)
{
{
int
hi
=
ho
+
y
;
int
hi
=
ho
+
y
-
h_pad_low
;
for
(
int
x
=
0
;
x
<
wei_kcsr
.
mDesc
.
GetLengths
()[
3
];
++
x
)
for
(
int
x
=
0
;
x
<
wei_kcsr
.
mDesc
.
GetLengths
()[
3
];
++
x
)
{
{
int
wi
=
wo
+
x
;
int
wi
=
wo
+
x
-
w_pad_low
;
v
+=
in_nchw
(
n
,
c
,
hi
,
wi
)
*
wei_kcsr
(
k
,
c
,
y
,
x
);
if
(
hi
>=
0
&&
hi
<
in_nchw
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
wi
<
in_nchw
.
mDesc
.
GetLengths
()[
3
])
{
v
+=
in_nchw
(
n
,
c
,
hi
,
wi
)
*
wei_kcsr
(
k
,
c
,
y
,
x
);
}
}
}
}
}
}
}
...
@@ -136,10 +148,9 @@ void host_direct_convolution(const Tensor<T>& in_nchw, const Tensor<T>& wei_kcsr
...
@@ -136,10 +148,9 @@ void host_direct_convolution(const Tensor<T>& in_nchw, const Tensor<T>& wei_kcsr
f_par
(
std
::
thread
::
hardware_concurrency
());
f_par
(
std
::
thread
::
hardware_concurrency
());
}
}
template
<
class
T
>
template
<
class
T
,
class
LowerPads
,
class
UpperPads
>
void
host_winograd_3x3_convolution
(
const
Tensor
<
T
>&
in_nchw
,
void
host_winograd_3x3_convolution
(
const
Tensor
<
T
>&
wei_kcsr
,
const
Tensor
<
T
>&
in_nchw
,
const
Tensor
<
T
>&
wei_kcsr
,
Tensor
<
T
>&
out
,
LowerPads
,
UpperPads
)
Tensor
<
T
>&
out
)
{
{
constexpr
std
::
size_t
OutTileSizeH
=
2
;
constexpr
std
::
size_t
OutTileSizeH
=
2
;
constexpr
std
::
size_t
OutTileSizeW
=
2
;
constexpr
std
::
size_t
OutTileSizeW
=
2
;
...
@@ -156,6 +167,12 @@ void host_winograd_3x3_convolution(const Tensor<T>& in_nchw,
...
@@ -156,6 +167,12 @@ void host_winograd_3x3_convolution(const Tensor<T>& in_nchw,
std
::
size_t
HO
=
out
.
mDesc
.
GetLengths
()[
2
];
std
::
size_t
HO
=
out
.
mDesc
.
GetLengths
()[
2
];
std
::
size_t
WO
=
out
.
mDesc
.
GetLengths
()[
3
];
std
::
size_t
WO
=
out
.
mDesc
.
GetLengths
()[
3
];
unsigned
h_pad_low
=
LowerPads
{}.
Get
(
Number
<
0
>
{});
unsigned
w_pad_low
=
LowerPads
{}.
Get
(
Number
<
1
>
{});
unsigned
h_pad_up
=
UpperPads
{}.
Get
(
Number
<
0
>
{});
unsigned
w_pad_up
=
UpperPads
{}.
Get
(
Number
<
1
>
{});
std
::
size_t
InTileSizeH
=
OutTileSizeH
+
S
-
1
;
std
::
size_t
InTileSizeH
=
OutTileSizeH
+
S
-
1
;
std
::
size_t
InTileSizeW
=
OutTileSizeW
+
R
-
1
;
std
::
size_t
InTileSizeW
=
OutTileSizeW
+
R
-
1
;
...
@@ -171,11 +188,20 @@ void host_winograd_3x3_convolution(const Tensor<T>& in_nchw,
...
@@ -171,11 +188,20 @@ void host_winograd_3x3_convolution(const Tensor<T>& in_nchw,
auto
f_in_hold
=
[
&
](
auto
n
,
auto
c
,
auto
y
,
auto
x
)
{
auto
f_in_hold
=
[
&
](
auto
n
,
auto
c
,
auto
y
,
auto
x
)
{
for
(
int
j
=
0
;
j
<
InTileSizeH
;
++
j
)
for
(
int
j
=
0
;
j
<
InTileSizeH
;
++
j
)
{
{
std
::
size_
t
hi
=
OutTileSizeH
*
y
+
j
;
in
t
hi
=
OutTileSizeH
*
y
+
j
-
h_pad_low
;
for
(
int
i
=
0
;
i
<
InTileSizeW
;
++
i
)
for
(
int
i
=
0
;
i
<
InTileSizeW
;
++
i
)
{
{
std
::
size_t
wi
=
OutTileSizeW
*
x
+
i
;
int
wi
=
OutTileSizeW
*
x
+
i
-
w_pad_low
;
in_hold
(
n
,
c
,
y
,
x
,
j
,
i
)
=
in_nchw
(
n
,
c
,
hi
,
wi
);
if
(
hi
>=
0
&&
hi
<
in_nchw
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
wi
<
in_nchw
.
mDesc
.
GetLengths
()[
3
])
{
in_hold
(
n
,
c
,
y
,
x
,
j
,
i
)
=
in_nchw
(
n
,
c
,
hi
,
wi
);
}
else
{
in_hold
(
n
,
c
,
y
,
x
,
j
,
i
)
=
T
(
0
);
}
}
}
}
}
};
};
...
@@ -406,7 +432,7 @@ int main()
...
@@ -406,7 +432,7 @@ int main()
constexpr
unsigned
K
=
64
;
constexpr
unsigned
K
=
64
;
constexpr
unsigned
S
=
7
;
constexpr
unsigned
S
=
7
;
constexpr
unsigned
R
=
7
;
constexpr
unsigned
R
=
7
;
#elif
1
#elif
0
// 3x3, 58x58
// 3x3, 58x58
constexpr
unsigned
N
=
16
;
constexpr
unsigned
N
=
16
;
constexpr
unsigned
C
=
128
;
constexpr
unsigned
C
=
128
;
...
@@ -415,12 +441,63 @@ int main()
...
@@ -415,12 +441,63 @@ int main()
constexpr
unsigned
K
=
256
;
constexpr
unsigned
K
=
256
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
R
=
3
;
constexpr
unsigned
R
=
3
;
#elif 0
// 3x3 filter, 58x58 image, 0x0 padding
constexpr
unsigned
N
=
16
;
constexpr
unsigned
C
=
128
;
constexpr
unsigned
HI
=
58
;
constexpr
unsigned
WI
=
58
;
constexpr
unsigned
K
=
256
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
R
=
3
;
constexpr
unsigned
HPad
=
0
;
constexpr
unsigned
WPad
=
0
;
#elif 1
// 3x3 filter, 56x56 image, 1x1 padding
constexpr
unsigned
N
=
16
;
constexpr
unsigned
C
=
128
;
constexpr
unsigned
HI
=
56
;
constexpr
unsigned
WI
=
56
;
constexpr
unsigned
K
=
256
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
R
=
3
;
constexpr
unsigned
HPad
=
1
;
constexpr
unsigned
WPad
=
1
;
#elif 0
// 3x3 filter, 28x28 image, 1x1 padding
constexpr
unsigned
N
=
16
;
constexpr
unsigned
C
=
256
;
constexpr
unsigned
HI
=
28
;
constexpr
unsigned
WI
=
28
;
constexpr
unsigned
K
=
512
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
R
=
3
;
constexpr
unsigned
HPad
=
1
;
constexpr
unsigned
WPad
=
1
;
#elif 0
// 3x3 filter, 20x84 image, 1x1 padding
constexpr
unsigned
N
=
16
;
constexpr
unsigned
C
=
256
;
constexpr
unsigned
HI
=
20
;
constexpr
unsigned
WI
=
84
;
constexpr
unsigned
K
=
256
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
R
=
3
;
constexpr
unsigned
HPad
=
1
;
constexpr
unsigned
WPad
=
1
;
#endif
#endif
auto
lower_pads
=
Sequence
<
HPad
,
WPad
>
{};
auto
upper_pads
=
Sequence
<
HPad
,
WPad
>
{};
auto
in_nchw_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
N
,
C
,
HI
,
WI
>
{});
auto
in_nchw_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
N
,
C
,
HI
,
WI
>
{});
auto
wei_kcsr_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
,
C
,
S
,
R
>
{});
auto
wei_kcsr_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
,
C
,
S
,
R
>
{});
auto
out_nkhw_desc
=
auto
out_nkhw_desc
=
get_convolution_with_padding_output_default_4d_tensor_descriptor
(
get_convolution_output_default_4d_tensor_descriptor
(
in_nchw_desc
,
wei_kcsr_desc
);
in_nchw_desc
,
wei_kcsr_desc
,
lower_pads
,
upper_pads
);
ostream_ConstantTensorDescriptor
(
in_nchw_desc
,
std
::
cout
<<
"in_nchw_desc: "
);
ostream_ConstantTensorDescriptor
(
in_nchw_desc
,
std
::
cout
<<
"in_nchw_desc: "
);
ostream_ConstantTensorDescriptor
(
wei_kcsr_desc
,
std
::
cout
<<
"wei_kcsr_desc: "
);
ostream_ConstantTensorDescriptor
(
wei_kcsr_desc
,
std
::
cout
<<
"wei_kcsr_desc: "
);
...
@@ -443,6 +520,7 @@ int main()
...
@@ -443,6 +520,7 @@ int main()
unsigned
nrepeat
=
50
;
unsigned
nrepeat
=
50
;
#if 0
#if 0
#if 0
device_direct_convolution_1
device_direct_convolution_1
#elif 0
#elif 0
...
@@ -451,7 +529,7 @@ int main()
...
@@ -451,7 +529,7 @@ int main()
device_implicit_gemm_convolution_1_nchw_kcsr
device_implicit_gemm_convolution_1_nchw_kcsr
#elif 0
#elif 0
device_implicit_gemm_convolution_1_nchw_srck_nkhw
device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif
1
#elif
0
device_implicit_gemm_convolution_1_chwn_csrk_khwn
device_implicit_gemm_convolution_1_chwn_csrk_khwn
#elif 0
#elif 0
device_implicit_gemm_convolution_2_cnhw_srck_knhw
device_implicit_gemm_convolution_2_cnhw_srck_knhw
...
@@ -459,15 +537,28 @@ int main()
...
@@ -459,15 +537,28 @@ int main()
device_winograd_convolution
device_winograd_convolution
#endif
#endif
(
in_nchw_desc
,
in_nchw
,
wei_kcsr_desc
,
wei_kcsr
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
(
in_nchw_desc
,
in_nchw
,
wei_kcsr_desc
,
wei_kcsr
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
#endif
#if 1
device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding
(
in_nchw_desc
,
in_nchw
,
wei_kcsr_desc
,
wei_kcsr
,
out_nkhw_desc
,
out_nkhw_device
,
lower_pads
,
upper_pads
,
nrepeat
);
#endif
#if 1
#if 1
if
(
S
==
3
&&
R
==
3
)
if
(
S
==
3
&&
R
==
3
)
{
{
host_winograd_3x3_convolution
(
in_nchw
,
wei_kcsr
,
out_nkhw_host
);
host_winograd_3x3_convolution
(
in_nchw
,
wei_kcsr
,
out_nkhw_host
,
lower_pads
,
upper_pads
);
}
}
else
else
{
{
host_direct_convolution
(
in_nchw
,
wei_kcsr
,
out_nkhw_host
);
host_direct_convolution
(
in_nchw
,
wei_kcsr
,
out_nkhw_host
,
lower_pads
,
upper_pads
);
}
}
check_error
(
out_nkhw_host
,
out_nkhw_device
);
check_error
(
out_nkhw_host
,
out_nkhw_device
);
#endif
#endif
...
...
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh
0 → 100644
View file @
3439e4b5
#pragma once
#include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh"
#include <unistd.h>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
,
class
LowerPads
,
class
UpperPads
>
void
device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding
(
InDesc
,
const
Tensor
<
T
>&
in_nchw
,
WeiDesc
,
const
Tensor
<
T
>&
wei_kcsr
,
OutDesc
,
Tensor
<
T
>&
out_nkhw
,
LowerPads
,
UpperPads
,
unsigned
nrepeat
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_nchw_desc
=
InDesc
{};
constexpr
auto
wei_kcsr_desc
=
WeiDesc
{};
constexpr
auto
out_nkhw_desc
=
OutDesc
{};
constexpr
unsigned
Hi
=
in_nchw_desc
.
GetLength
(
I2
);
constexpr
unsigned
Wi
=
in_nchw_desc
.
GetLength
(
I3
);
constexpr
unsigned
N
=
out_nkhw_desc
.
GetLength
(
I0
);
constexpr
unsigned
Ho
=
out_nkhw_desc
.
GetLength
(
I2
);
constexpr
unsigned
Wo
=
out_nkhw_desc
.
GetLength
(
I3
);
constexpr
unsigned
K
=
wei_kcsr_desc
.
GetLength
(
I0
);
constexpr
unsigned
C
=
wei_kcsr_desc
.
GetLength
(
I1
);
constexpr
unsigned
S
=
wei_kcsr_desc
.
GetLength
(
I2
);
constexpr
unsigned
R
=
wei_kcsr_desc
.
GetLength
(
I3
);
// reorder weight
auto
wei_csrk_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
,
S
,
R
,
K
>
{});
ostream_ConstantTensorDescriptor
(
wei_csrk_desc
,
std
::
cout
<<
"wei_csrk_desc: "
);
Tensor
<
T
>
wei_csrk
(
make_TensorDescriptor
(
wei_csrk_desc
));
auto
f_reorder_kcsr2csrk
=
[
&
](
auto
k
,
auto
c
,
auto
s
,
auto
r
)
{
wei_csrk
(
c
,
s
,
r
,
k
)
=
wei_kcsr
(
k
,
c
,
s
,
r
);
};
make_ParallelTensorFunctor
(
f_reorder_kcsr2csrk
,
K
,
C
,
S
,
R
)(
std
::
thread
::
hardware_concurrency
());
// reorder input
auto
in_chwn_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
,
Hi
,
Wi
,
N
>
{});
ostream_ConstantTensorDescriptor
(
in_chwn_desc
,
std
::
cout
<<
"in_chwn_desc: "
);
Tensor
<
T
>
in_chwn
(
make_TensorDescriptor
(
in_chwn_desc
));
auto
f_reorder_nchw2chwn
=
[
&
](
auto
n
,
auto
c
,
auto
hi
,
auto
wi
)
{
in_chwn
(
c
,
hi
,
wi
,
n
)
=
in_nchw
(
n
,
c
,
hi
,
wi
);
};
make_ParallelTensorFunctor
(
f_reorder_nchw2chwn
,
N
,
C
,
Hi
,
Wi
)(
std
::
thread
::
hardware_concurrency
());
// output
auto
out_khwn_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
,
Ho
,
Wo
,
N
>
{});
ostream_ConstantTensorDescriptor
(
out_khwn_desc
,
std
::
cout
<<
"out_khwn_desc: "
);
Tensor
<
T
>
out_khwn
(
make_TensorDescriptor
(
out_khwn_desc
));
std
::
size_t
data_sz
=
sizeof
(
T
);
DeviceMem
in_chwn_device_buf
(
data_sz
*
in_chwn
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_csrk_device_buf
(
data_sz
*
wei_csrk
.
mDesc
.
GetElementSpace
());
DeviceMem
out_khwn_device_buf
(
data_sz
*
out_khwn
.
mDesc
.
GetElementSpace
());
in_chwn_device_buf
.
ToDevice
(
in_chwn
.
mData
.
data
());
wei_csrk_device_buf
.
ToDevice
(
wei_csrk
.
mData
.
data
());
out_khwn_device_buf
.
ToDevice
(
out_khwn
.
mData
.
data
());
#if 0
constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 1;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 1;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 8;
#elif
0
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
constexpr
unsigned
NPerBlock
=
16
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
4
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
#elif 0
// 3x3 58x58, NKC = 16,256,128
constexpr
unsigned
NPerBlock
=
8
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
4
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
#elif 0
// for 5x5, 36x36
constexpr
unsigned
NPerBlock
=
16
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
#elif 0
// for 7x7, 38x38
constexpr
unsigned
NPerBlock
=
8
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
4
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
#elif 0
// for 3x3, 56x56
constexpr
unsigned
NPerBlock
=
32
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
4
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
2
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
#elif 1
// 3x3 56x56, NKC = 16,256,128, with padding
// 3x3 28x28, NKC = 16,512,256, with padding
// 3x3 20x84, NKC = 16,256,256, with padding
constexpr
unsigned
NPerBlock
=
16
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
#endif
constexpr
unsigned
GridSize
=
((
N
+
NPerBlock
-
1
)
/
NPerBlock
)
*
((
K
+
KPerBlock
-
1
)
/
KPerBlock
)
*
((
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
)
*
((
Wo
+
WoPerBlock
-
1
)
/
WoPerBlock
);
dim3
block_dim
(
BlockSize
);
dim3
grid_dim
(
GridSize
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
for
(
unsigned
i
=
0
;
i
<
nrepeat
;
++
i
)
{
cudaEvent_t
start
,
stop
;
float
elapsedTime
;
cudaEventCreate
(
&
start
);
cudaEventRecord
(
start
,
0
);
gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding
<
GridSize
,
BlockSize
,
T
,
decltype
(
in_chwn_desc
),
decltype
(
wei_csrk_desc
),
decltype
(
out_khwn_desc
),
LowerPads
,
UpperPads
,
NPerBlock
,
KPerBlock
,
CPerBlock
,
HoPerBlock
,
WoPerBlock
,
NPerThread
,
KPerThread
,
CPerThread
,
HoPerThread
,
WoPerThread
>
<<<
grid_dim
,
block_dim
>>>
(
static_cast
<
T
*>
(
in_chwn_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_csrk_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_khwn_device_buf
.
GetDeviceBuffer
()));
cudaEventCreate
(
&
stop
);
cudaEventRecord
(
stop
,
0
);
cudaEventSynchronize
(
stop
);
cudaEventElapsedTime
(
&
elapsedTime
,
start
,
stop
);
printf
(
"Elapsed time : %f ms
\n
"
,
elapsedTime
);
usleep
(
10000
);
}
checkCudaErrors
(
cudaGetLastError
());
out_khwn_device_buf
.
FromDevice
(
out_khwn
.
mData
.
data
());
// reorder output
auto
f_reorder_khwn2nkhw
=
[
&
](
auto
k
,
auto
ho
,
auto
wo
,
auto
n
)
{
out_nkhw
(
n
,
k
,
ho
,
wo
)
=
out_khwn
(
k
,
ho
,
wo
,
n
);
};
make_ParallelTensorFunctor
(
f_reorder_khwn2nkhw
,
K
,
Ho
,
Wo
,
N
)(
std
::
thread
::
hardware_concurrency
());
}
src/include/blockwise_4d_tensor_op.cuh
View file @
3439e4b5
...
@@ -211,6 +211,133 @@ struct blockwise_4d_tensor_copy_1
...
@@ -211,6 +211,133 @@ struct blockwise_4d_tensor_copy_1
}
}
};
};
template
<
unsigned
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
DstOpLengths
,
class
GlobalLowerPads
>
struct
blockwise_chwn_tensor_copy_with_padding
{
__device__
void
run
(
Float
*
const
__restrict__
p_src
,
unsigned
c_block_data_begin
,
unsigned
ho_block_data_begin
,
unsigned
wo_block_data_begin
,
unsigned
n_block_data_begin
,
Float
*
__restrict__
p_dst
,
unsigned
h_block_pad_low
,
unsigned
w_block_pad_low
,
unsigned
h_block_pad_up
,
unsigned
w_block_pad_up
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
src_desc
=
SrcDesc
{};
constexpr
auto
dst_desc
=
DstDesc
{};
constexpr
auto
ref_desc
=
make_ConstantTensorDescriptor
(
DstOpLengths
{});
constexpr
auto
h_global_pad_low
=
GlobalLowerPads
{}.
Get
(
I0
);
constexpr
auto
w_global_pad_low
=
GlobalLowerPads
{}.
Get
(
I1
);
constexpr
unsigned
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
Float
*
const
p_src_tmp
=
p_src
+
src_desc
.
Get1dIndex
(
c_block_data_begin
,
(
ho_block_data_begin
+
h_block_pad_low
)
-
h_global_pad_low
,
(
wo_block_data_begin
+
w_block_pad_low
)
-
w_global_pad_low
,
n_block_data_begin
);
#if 0
if(get_thread_local_1d_id() == 0)
{
print_ConstantTensorDescriptor(src_desc, "src_desc: ");
print_ConstantTensorDescriptor(dst_desc, "dst_desc: ");
print_ConstantTensorDescriptor(ref_desc, "ref_desc: ");
printf("%u %u, \t"
"h_global_pad_low %u w_global_pad_low %u \t"
"h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u \t"
"\n",
get_block_1d_id(),
get_thread_local_1d_id(),
h_global_pad_low,
w_global_pad_low,
h_block_pad_low,
w_block_pad_low,
h_block_pad_up,
w_block_pad_up);
}
#endif
for
(
unsigned
iloop
=
0
;
iloop
<
NLoop
;
++
iloop
)
{
unsigned
is
=
threadIdx
.
x
+
iloop
*
BlockSize
;
unsigned
did
[
4
];
did
[
0
]
=
is
/
ref_desc
.
GetStride
(
I0
);
is
-=
did
[
0
]
*
ref_desc
.
GetStride
(
I0
);
did
[
1
]
=
is
/
ref_desc
.
GetStride
(
I1
);
is
-=
did
[
1
]
*
ref_desc
.
GetStride
(
I1
);
did
[
2
]
=
is
/
ref_desc
.
GetStride
(
I2
);
is
-=
did
[
2
]
*
ref_desc
.
GetStride
(
I2
);
did
[
3
]
=
is
/
ref_desc
.
GetStride
(
I3
);
const
unsigned
bindex
=
dst_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
]);
p_dst
[
bindex
]
=
(
did
[
1
]
<
h_block_pad_low
||
did
[
1
]
+
h_block_pad_up
>=
ref_desc
.
GetLength
(
I1
)
||
did
[
2
]
<
w_block_pad_low
||
did
[
2
]
+
w_block_pad_up
>=
ref_desc
.
GetLength
(
I2
))
?
Float
(
0
)
:
p_src_tmp
[
src_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
])];
}
constexpr
bool
has_tail
=
(
ref_desc
.
GetElementSize
()
>
NLoop
*
BlockSize
);
if
(
has_tail
)
{
unsigned
is
=
threadIdx
.
x
+
NLoop
*
BlockSize
;
if
(
is
<
ref_desc
.
GetElementSize
())
{
unsigned
did
[
4
];
did
[
0
]
=
is
/
ref_desc
.
GetStride
(
I0
);
is
-=
did
[
0
]
*
ref_desc
.
GetStride
(
I0
);
did
[
1
]
=
is
/
ref_desc
.
GetStride
(
I1
);
is
-=
did
[
1
]
*
ref_desc
.
GetStride
(
I1
);
did
[
2
]
=
is
/
ref_desc
.
GetStride
(
I2
);
is
-=
did
[
2
]
*
ref_desc
.
GetStride
(
I2
);
did
[
3
]
=
is
/
ref_desc
.
GetStride
(
I3
);
const
unsigned
bindex
=
dst_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
]);
p_dst
[
bindex
]
=
(
did
[
1
]
<
h_block_pad_low
||
did
[
1
]
+
h_block_pad_up
>=
ref_desc
.
GetLength
(
I1
)
||
did
[
2
]
<
w_block_pad_low
||
did
[
2
]
+
w_block_pad_up
>=
ref_desc
.
GetLength
(
I2
))
?
Float
(
0
)
:
p_src_tmp
[
src_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
])];
}
}
}
};
template
<
unsigned
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
SrcOpLengths
>
template
<
unsigned
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
SrcOpLengths
>
struct
blockwise_4d_tensor_copy_dummy
struct
blockwise_4d_tensor_copy_dummy
{
{
...
...
src/include/conv_common.cuh
View file @
3439e4b5
...
@@ -27,8 +27,45 @@ __host__ __device__ constexpr auto get_convolution_output_default_4d_tensor_desc
...
@@ -27,8 +27,45 @@ __host__ __device__ constexpr auto get_convolution_output_default_4d_tensor_desc
constexpr
auto
S
=
wei_desc
.
GetLength
(
I2
);
constexpr
auto
S
=
wei_desc
.
GetLength
(
I2
);
constexpr
auto
R
=
wei_desc
.
GetLength
(
I3
);
constexpr
auto
R
=
wei_desc
.
GetLength
(
I3
);
constexpr
auto
HO
=
HI
-
S
+
1
;
constexpr
auto
HO
=
HI
+
1
-
S
;
constexpr
auto
WO
=
WI
-
R
+
1
;
constexpr
auto
WO
=
WI
+
1
-
R
;
return
make_ConstantTensorDescriptor
(
Sequence
<
N
,
K
,
HO
,
WO
>
{});
}
template
<
class
InDesc
,
class
WeiDesc
,
class
LowerPads
,
class
UpperPads
>
__host__
__device__
constexpr
auto
get_convolution_with_padding_output_default_4d_tensor_descriptor
(
InDesc
,
WeiDesc
,
LowerPads
,
UpperPads
)
{
constexpr
auto
in_desc
=
InDesc
{};
constexpr
auto
wei_desc
=
WeiDesc
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
static_assert
(
in_desc
.
GetDimension
()
==
4
,
"input nDim is not 4"
);
static_assert
(
wei_desc
.
GetDimension
()
==
4
,
"weight nDim is not 4"
);
static_assert
(
in_desc
.
GetLength
(
I1
)
==
wei_desc
.
GetLength
(
I1
),
"input & weight dimension not consistent"
);
constexpr
auto
N
=
in_desc
.
GetLength
(
I0
);
constexpr
auto
HI
=
in_desc
.
GetLength
(
I2
);
constexpr
auto
WI
=
in_desc
.
GetLength
(
I3
);
constexpr
auto
K
=
wei_desc
.
GetLength
(
I0
);
constexpr
auto
S
=
wei_desc
.
GetLength
(
I2
);
constexpr
auto
R
=
wei_desc
.
GetLength
(
I3
);
constexpr
auto
HPadLow
=
LowerPads
{}.
Get
(
I0
);
constexpr
auto
WPadLow
=
LowerPads
{}.
Get
(
I1
);
constexpr
auto
HPadUp
=
UpperPads
{}.
Get
(
I0
);
constexpr
auto
WPadUp
=
UpperPads
{}.
Get
(
I1
);
constexpr
auto
HO
=
HI
+
HPadLow
+
HPadUp
+
1
-
S
;
constexpr
auto
WO
=
WI
+
WPadLow
+
WPadUp
+
1
-
R
;
return
make_ConstantTensorDescriptor
(
Sequence
<
N
,
K
,
HO
,
WO
>
{});
return
make_ConstantTensorDescriptor
(
Sequence
<
N
,
K
,
HO
,
WO
>
{});
}
}
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh
0 → 100644
View file @
3439e4b5
#pragma once
#include "common.cuh"
#include "ConstantTensorDescriptor.cuh"
#include "ConstantMatrixDescriptor.cuh"
#include "blockwise_4d_tensor_op.cuh"
#include "threadwise_4d_tensor_op.cuh"
#include "gemm.cuh"
template
<
unsigned
GridSize
,
unsigned
BlockSize
,
class
Float
,
class
InGlobalDesc
,
class
WeiGlobalDesc
,
class
OutGlobalDesc
,
class
LowerPads
,
class
UpperPads
,
unsigned
NPerBlock
,
unsigned
KPerBlock
,
unsigned
CPerBlock
,
unsigned
HoPerBlock
,
unsigned
WoPerBlock
,
unsigned
NPerThread
,
unsigned
KPerThread
,
unsigned
CPerThread
,
unsigned
HoPerThread
,
unsigned
WoPerThread
>
__global__
void
gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding
(
Float
*
const
__restrict__
p_in_global
,
Float
*
const
__restrict__
p_wei_global
,
Float
*
__restrict__
p_out_global
)
{
// NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
// for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
// if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock
static_assert
(
NPerBlock
%
NPerThread
==
0
,
"wrong! NPerBlock % NPerThread !=0"
);
static_assert
((
NPerThread
<
NPerBlock
&&
WoPerThread
==
1
)
||
NPerThread
==
NPerBlock
,
"wrong!"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_chwn_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_csrk_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_khwn_global_desc
=
OutGlobalDesc
{};
constexpr
unsigned
C
=
in_chwn_global_desc
.
GetLength
(
I0
);
constexpr
unsigned
K
=
out_khwn_global_desc
.
GetLength
(
I0
);
constexpr
unsigned
Ho
=
out_khwn_global_desc
.
GetLength
(
I1
);
constexpr
unsigned
Wo
=
out_khwn_global_desc
.
GetLength
(
I2
);
constexpr
unsigned
N
=
out_khwn_global_desc
.
GetLength
(
I3
);
constexpr
unsigned
S
=
wei_csrk_global_desc
.
GetLength
(
I1
);
constexpr
unsigned
R
=
wei_csrk_global_desc
.
GetLength
(
I2
);
constexpr
unsigned
HPadLow
=
LowerPads
{}.
Get
(
I0
);
constexpr
unsigned
WPadLow
=
LowerPads
{}.
Get
(
I1
);
constexpr
unsigned
HPadUp
=
UpperPads
{}.
Get
(
I0
);
constexpr
unsigned
WPadUp
=
UpperPads
{}.
Get
(
I1
);
constexpr
unsigned
HiPerBlock
=
HoPerBlock
+
S
-
1
;
constexpr
unsigned
WiPerBlock
=
WoPerBlock
+
R
-
1
;
// divide block work: [K, Ho, Wo, N]
constexpr
unsigned
KBlockWork
=
(
K
+
KPerBlock
-
1
)
/
KPerBlock
;
constexpr
unsigned
HBlockWork
=
(
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
;
constexpr
unsigned
WBlockWork
=
(
Wo
+
WoPerBlock
-
1
)
/
WoPerBlock
;
constexpr
unsigned
NBlockWork
=
(
N
+
NPerBlock
-
1
)
/
NPerBlock
;
const
unsigned
k_block_work_id
=
get_block_1d_id
()
/
(
HBlockWork
*
WBlockWork
*
NBlockWork
);
unsigned
itmp
=
get_block_1d_id
()
-
k_block_work_id
*
(
HBlockWork
*
WBlockWork
*
NBlockWork
);
const
unsigned
h_block_work_id
=
itmp
/
(
WBlockWork
*
NBlockWork
);
itmp
-=
h_block_work_id
*
(
WBlockWork
*
NBlockWork
);
const
unsigned
w_block_work_id
=
itmp
/
NBlockWork
;
const
unsigned
n_block_work_id
=
itmp
-
w_block_work_id
*
NBlockWork
;
const
unsigned
k_block_data_begin
=
k_block_work_id
*
KPerBlock
;
const
unsigned
ho_block_data_begin
=
h_block_work_id
*
HoPerBlock
;
const
unsigned
wo_block_data_begin
=
w_block_work_id
*
WoPerBlock
;
const
unsigned
n_block_data_begin
=
n_block_work_id
*
NPerBlock
;
// tensor view of blockwise input and weight in LDS
constexpr
auto
in_chwn_block_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
CPerBlock
,
HiPerBlock
,
WiPerBlock
,
NPerBlock
>
{});
constexpr
auto
wei_csrk_block_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
CPerBlock
,
S
,
R
,
KPerBlock
>
{});
// tensor view of threadwise output in register
constexpr
auto
out_hkwn_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
HoPerThread
,
KPerThread
,
WoPerThread
,
NPerThread
>
{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc");
print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
}
#endif
// blockwise copy
// input: format is [C, Hi, Wi, N]
const
unsigned
h_block_pad_low
=
h_block_work_id
==
0
?
HPadLow
:
0
;
const
unsigned
w_block_pad_low
=
w_block_work_id
==
0
?
WPadLow
:
0
;
const
unsigned
h_block_pad_up
=
h_block_work_id
==
HBlockWork
-
1
?
HPadUp
:
0
;
const
unsigned
w_block_pad_up
=
w_block_work_id
==
WBlockWork
-
1
?
WPadUp
:
0
;
#if 0
if(get_thread_local_1d_id() == 0)
;
{
printf(
"%u %u, h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u\n",
get_block_1d_id(),
get_thread_local_1d_id(),
h_block_pad_low,
w_block_pad_low,
h_block_pad_up,
w_block_pad_up);
}
#endif
constexpr
auto
blockwise_in_copy
=
blockwise_chwn_tensor_copy_with_padding
<
BlockSize
,
Float
,
decltype
(
in_chwn_global_desc
),
decltype
(
in_chwn_block_desc
),
decltype
(
in_chwn_block_desc
.
GetLengths
()),
LowerPads
>
{};
// weight: format is [S,R,C,K]
constexpr
auto
blockwise_wei_copy
=
blockwise_4d_tensor_copy_1
<
BlockSize
,
Float
,
decltype
(
wei_csrk_global_desc
),
decltype
(
wei_csrk_block_desc
),
decltype
(
wei_csrk_block_desc
.
GetLengths
())
>
{};
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[S,R,C,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N]
const
auto
a_cxk_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
wei_csrk_block_desc
.
GetStride
(
I0
)
>
{});
// constexpr doesn't compile
const
auto
b_cxwn_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
WoPerBlock
*
NPerBlock
>
{},
Number
<
in_chwn_block_desc
.
GetStride
(
I0
)
>
{});
// constexpr doesn't compile
const
auto
c_kxwn_thread_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThread
>
{},
Number
<
WoPerThread
*
NPerThread
>
{});
// constexpr doesn't compile
const
auto
blockwise_batch_gemm
=
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
<
BlockSize
,
decltype
(
a_cxk_block_mtx_desc
),
decltype
(
b_cxwn_block_mtx_desc
),
decltype
(
c_kxwn_thread_mtx_desc
),
true
,
false
,
false
,
0
,
in_chwn_block_desc
.
GetStride
(
I1
),
out_hkwn_thread_desc
.
GetStride
(
I0
),
HoPerBlock
,
HoPerThread
,
CPerThread
,
true
>
{};
// LDS
constexpr
unsigned
in_block_size
=
in_chwn_block_desc
.
GetElementSpace
();
constexpr
unsigned
wei_block_size
=
wei_csrk_block_desc
.
GetElementSpace
();
__shared__
Float
p_in_block
[
in_block_size
];
__shared__
Float
p_wei_block
[
wei_block_size
];
// register
Float
p_out_thread
[
out_hkwn_thread_desc
.
GetElementSpace
()];
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero
(
out_hkwn_thread_desc
,
p_out_thread
);
for
(
unsigned
c_block_data_begin
=
0
;
c_block_data_begin
<
C
;
c_block_data_begin
+=
CPerBlock
,
__syncthreads
())
{
#if 1
// input: global mem to LDS,
blockwise_in_copy
.
run
(
p_in_global
,
c_block_data_begin
,
ho_block_data_begin
,
wo_block_data_begin
,
n_block_data_begin
,
p_in_block
,
h_block_pad_low
,
w_block_pad_low
,
h_block_pad_up
,
w_block_pad_up
);
#endif
#if 1
// weight: global mem to LDS,
blockwise_wei_copy
.
run
(
p_wei_global
+
wei_csrk_global_desc
.
Get1dIndex
(
c_block_data_begin
,
0
,
0
,
k_block_data_begin
),
p_wei_block
);
#endif
__syncthreads
();
// a series of batched GEMM
for
(
unsigned
s
=
0
;
s
<
S
;
++
s
)
{
for
(
unsigned
r
=
0
;
r
<
R
;
++
r
)
{
auto
f_accum
=
[](
auto
&
acc
,
const
auto
&&
v
)
{
acc
+=
v
;
};
blockwise_batch_gemm
.
run
(
p_wei_block
+
wei_csrk_block_desc
.
Get1dIndex
(
0
,
s
,
r
,
0
),
p_in_block
+
in_chwn_block_desc
.
Get1dIndex
(
0
,
s
,
r
,
0
),
p_out_thread
,
f_accum
);
}
}
}
const
auto
matrix_c_index
=
blockwise_batch_gemm
.
CalculateThreadMatrixCIndex
(
get_thread_local_1d_id
());
const
unsigned
ho_thread_data_begin
=
matrix_c_index
.
batch_begin
;
const
unsigned
k_thread_data_begin
=
matrix_c_index
.
row_begin
;
const
unsigned
wo_thread_data_begin
=
matrix_c_index
.
col_begin
/
NPerBlock
;
const
unsigned
n_thread_data_begin
=
matrix_c_index
.
col_begin
-
wo_thread_data_begin
*
NPerBlock
;
#if 0
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
get_block_1d_id(), get_thread_local_1d_id(),
ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin,
ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin,
p_out_thread[0]);
#endif
// output: register to global mem,
// convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
constexpr
auto
reorder_khwn_from_hkwn
=
Sequence
<
1
,
0
,
2
,
3
>
{};
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src
(
out_hkwn_thread_desc
,
p_out_thread
,
out_khwn_global_desc
,
p_out_global
+
out_khwn_global_desc
.
Get1dIndex
(
k_block_data_begin
+
k_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
,
n_block_data_begin
+
n_thread_data_begin
),
out_hkwn_thread_desc
.
GetLengths
(),
reorder_khwn_from_hkwn
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment