gaoqiong / composable_kernel

Commit 569941a7, authored Apr 03, 2019 by Chao Liu
Parent: 6166233e

    create mini code

Showing 20 changed files with 2 additions and 5079 deletions (+2, -5079)
Changed files:

    driver/device_direct_convolution_1.hpp                                                       +0 -87
    driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp                                        +0 -111
    driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp                             +0 -211
    driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp                                 +0 -341
    driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp                          +0 -293
    driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp                                 +0 -1
    driver/driver.hip.cpp                                                                        +2 -262
    src/include/blockwise_4d_tensor_op.hip.hpp                                                   +0 -587
    src/include/blockwise_batched_gemm.hip.hpp                                                   +0 -347
    src/include/blockwise_direct_convolution.hip.hpp                                             +0 -134
    src/include/blockwise_gemm.hip.hpp                                                           +0 -593
    src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp                     +0 -3
    src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp   +0 -385
    src/include/gridwise_direct_convolution_1.hip.hpp                                            +0 -152
    src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp                             +0 -237
    src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp                  +0 -252
    src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp                      +0 -318
    src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp               +0 -293
    src/include/threadwise_4d_tensor_op.hip.hpp                                                  +0 -258
    src/include/threadwise_direct_convolution.hip.hpp                                            +0 -214
driver/device_direct_convolution_1.hpp  (deleted, 100644 → 0)
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "gridwise_direct_convolution_1.hip.hpp"

template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_direct_convolution_1(
    InDesc, const Tensor<T>& in, WeiDesc, const Tensor<T>& wei, OutDesc, Tensor<T>& out, index_t nrepeat)
{
    std::size_t data_sz = sizeof(T);
    DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
    DeviceMem wei_device_buf(data_sz * wei.mDesc.GetElementSpace());
    DeviceMem out_device_buf(data_sz * out.mDesc.GetElementSpace());

    int num_thread = std::thread::hardware_concurrency();

    in_device_buf.ToDevice(in.mData.data());
    wei_device_buf.ToDevice(wei.mData.data());
    out_device_buf.ToDevice(out.mData.data());

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

#if 1
    // 3x3, 34x34
    constexpr index_t NPerBlock = 2;
    constexpr index_t KPerBlock = 16;
    constexpr index_t CPerBlock = 2;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 32;

    constexpr index_t NPerThread = 2;
    constexpr index_t KPerThread = 4;
    constexpr index_t CPerThread = 2;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;

    constexpr index_t BlockSize = 128;
#endif

    constexpr index_t GridSize = (out_desc.GetLength(I0) / NPerBlock) *
                                 (out_desc.GetLength(I1) / KPerBlock) *
                                 (out_desc.GetLength(I2) / HoPerBlock) *
                                 (out_desc.GetLength(I3) / WoPerBlock);

    printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

    for(index_t i = 0; i < nrepeat; ++i)
    {
        float time = launch_kernel(
            gridwise_direct_convolution_1<T, InDesc, WeiDesc, OutDesc,
                                          NPerBlock, KPerBlock, CPerBlock, HoPerBlock, WoPerBlock,
                                          NPerThread, KPerThread, CPerThread, HoPerThread, WoPerThread,
                                          BlockSize, GridSize>,
            dim3(GridSize),
            dim3(BlockSize),
            static_cast<T*>(in_device_buf.GetDeviceBuffer()),
            static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
            static_cast<T*>(out_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms\n", time);
        usleep(std::min(time * 1000, float(10000)));
    }

    out_device_buf.FromDevice(out.mData.data());
}
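The grid decomposition above tiles the output NKHW tensor with (NPerBlock, KPerBlock, HoPerBlock, WoPerBlock) and launches one 128-thread workgroup per tile. A minimal stand-alone sketch of the GridSize arithmetic, using the "3x3, 34x34" problem that driver.hip.cpp used to set up (N = 64, C = 256, K = 64, so a 32x32 output); the numbers are only an illustration, not part of this header:

// Hypothetical worked example of the GridSize formula used above,
// assuming output lengths N=64, K=64, Ho=32, Wo=32 and the tile sizes
// from the #if 1 block of device_direct_convolution_1.
#include <cstdio>

int main()
{
    constexpr unsigned N = 64, K = 64, Ho = 32, Wo = 32;
    constexpr unsigned NPerBlock = 2, KPerBlock = 16, HoPerBlock = 4, WoPerBlock = 32;

    // (64/2) * (64/16) * (32/4) * (32/32) = 32 * 4 * 8 * 1 = 1024 workgroups
    constexpr unsigned GridSize =
        (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);

    std::printf("GridSize = %u\n", GridSize); // prints 1024
    return 0;
}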
driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp  (deleted, 100644 → 0)
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp"

template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_direct_convolution_2_nchw_kcyx_nkhw(
    InDesc, const Tensor<T>& in, WeiDesc, const Tensor<T>& wei, OutDesc, Tensor<T>& out, index_t nrepeat)
{
    std::size_t data_sz = sizeof(T);
    DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
    DeviceMem wei_device_buf(data_sz * wei.mDesc.GetElementSpace());
    DeviceMem out_device_buf(data_sz * out.mDesc.GetElementSpace());

    int num_thread = std::thread::hardware_concurrency();

    in_device_buf.ToDevice(in.mData.data());
    wei_device_buf.ToDevice(wei.mData.data());
    out_device_buf.ToDevice(out.mData.data());

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

#if 1
    // 3x3, 34x34, 128 thread
    constexpr index_t NPerBlock = 2;
    constexpr index_t KPerBlock = 32;
    constexpr index_t CPerBlock = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 32;

    constexpr index_t NPerThread = 2;
    constexpr index_t KPerThread = 4;
    constexpr index_t CPerThread = 2;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;

    constexpr index_t InBlockCopyDataPerRead = 2;
    constexpr index_t WeiBlockCopyDataPerRead = 4;

    constexpr index_t BlockSize = 128;
#elif 1
    // 3x3, 34x34, 128 thread, fp16
    constexpr index_t NPerBlock = 2;
    constexpr index_t KPerBlock = 32;
    constexpr index_t CPerBlock = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 32;

    constexpr index_t NPerThread = 2;
    constexpr index_t KPerThread = 4;
    constexpr index_t CPerThread = 2;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;

    constexpr index_t InBlockCopyDataPerRead = 2;
    constexpr index_t WeiBlockCopyDataPerRead = 4;

    constexpr index_t BlockSize = 128;
#endif

    constexpr index_t GridSize = (out_desc.GetLength(I0) / NPerBlock) *
                                 (out_desc.GetLength(I1) / KPerBlock) *
                                 (out_desc.GetLength(I2) / HoPerBlock) *
                                 (out_desc.GetLength(I3) / WoPerBlock);

    printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

    for(index_t i = 0; i < nrepeat; ++i)
    {
        float time = launch_kernel(
            gridwise_direct_convolution_2_nchw_kcyx_nkhw<T, InDesc, WeiDesc, OutDesc,
                                                         NPerBlock, KPerBlock, CPerBlock, HoPerBlock, WoPerBlock,
                                                         NPerThread, KPerThread, CPerThread, HoPerThread, WoPerThread,
                                                         InBlockCopyDataPerRead, WeiBlockCopyDataPerRead,
                                                         BlockSize, GridSize>,
            dim3(GridSize),
            dim3(BlockSize),
            static_cast<T*>(in_device_buf.GetDeviceBuffer()),
            static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
            static_cast<T*>(out_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms\n", time);
        usleep(std::min(time * 1000, float(10000)));
    }

    out_device_buf.FromDevice(out.mData.data());
}
driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp  (deleted, 100644 → 0)
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp"

template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
    InDesc, const Tensor<TInWei>& in_nchw, WeiDesc, const Tensor<TInWei>& wei_kcyx,
    OutDesc, Tensor<TOut>& out_nkhw, index_t nrepeat)
{
    // this suppose in / wei data type is int8x4
    constexpr index_t NVector = 4;

    using accum_t = int32_t;

    using vector_t     = vector_type<TInWei, NVector>;
    using vector_mem_t = typename vector_t::MemoryType;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcyx_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

    // vectorized input
    auto in_nchw_vec_desc = make_ConstantTensorDescriptor(Sequence<N, C / NVector, Hi, Wi>{});
    ostream_ConstantTensorDescriptor(in_nchw_vec_desc, std::cout << "in_nchw_vec_desc: ");

    Tensor<vector_mem_t> in_nchw_vec(make_TensorDescriptor(in_nchw_vec_desc));

    auto f_vectorized_nchw = [&](auto n, auto c, auto h, auto w) {
#if 0
        in_nchw_vec(n, c, h, w) = in_nchw(n, c, h, w);
#elif 0
        in_nchw_vec(n, c, h, w) =
            vector_t::Pack(in_nchw(n, 2 * c, h, w), in_nchw(n, 2 * c + 1, h, w));
#elif 1
        in_nchw_vec(n, c, h, w) = vector_t::Pack(in_nchw(n, 4 * c, h, w),
                                                 in_nchw(n, 4 * c + 1, h, w),
                                                 in_nchw(n, 4 * c + 2, h, w),
                                                 in_nchw(n, 4 * c + 3, h, w));
#endif
    };

    make_ParallelTensorFunctor(f_vectorized_nchw, N, C / NVector, Hi, Wi)(
        std::thread::hardware_concurrency());

    // vectorize weight
    auto wei_kcyx_vec_desc = make_ConstantTensorDescriptor(Sequence<K, C / NVector, Y, X>{});
    ostream_ConstantTensorDescriptor(wei_kcyx_vec_desc, std::cout << "wei_kcyx_vec_desc: ");

    Tensor<vector_mem_t> wei_kcyx_vec(make_TensorDescriptor(wei_kcyx_vec_desc));

    auto f_vectorized_kcyx = [&](auto k, auto c, auto y, auto x) {
#if 0
        wei_kcyx_vec(k, c, y, x) = wei_kcyx(k, c, y, x);
#elif 0
        wei_kcyx_vec(k, c, y, x) =
            vector_t::Pack(wei_kcyx(k, 2 * c, y, x), wei_kcyx(k, 2 * c + 1, y, x));
#elif 1
        wei_kcyx_vec(k, c, y, x) = vector_t::Pack(wei_kcyx(k, 4 * c, y, x),
                                                  wei_kcyx(k, 4 * c + 1, y, x),
                                                  wei_kcyx(k, 4 * c + 2, y, x),
                                                  wei_kcyx(k, 4 * c + 3, y, x));
#endif
    };

    make_ParallelTensorFunctor(f_vectorized_kcyx, K, C / NVector, Y, X)(
        std::thread::hardware_concurrency());

    //
    DeviceMem in_nchw_vec_device_buf(sizeof(vector_mem_t) * in_nchw_vec.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_vec_device_buf(sizeof(vector_mem_t) * wei_kcyx_vec.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(sizeof(TOut) * out_nkhw.mDesc.GetElementSpace());

    in_nchw_vec_device_buf.ToDevice(in_nchw_vec.mData.data());
    wei_kcyx_vec_device_buf.ToDevice(wei_kcyx_vec.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

#if 0
    // 3x3, 34x34, 128 thread, fp32, vector = 1
    constexpr index_t NPerBlock = 2;
    constexpr index_t KPerBlock = 32;
    constexpr index_t CPerBlock = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 32;
    constexpr index_t NPerThread = 2;
    constexpr index_t KPerThread = 4;
    constexpr index_t CPerThread = 2;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;
    constexpr index_t InBlockCopyDataPerRead = 2;
    constexpr index_t WeiBlockCopyDataPerRead = 2;
    constexpr index_t BlockSize = 128;
#elif 0
    // 3x3, 34x34, 128 thread, fp32, vector = 2
    constexpr index_t NPerBlock = 2;
    constexpr index_t KPerBlock = 32;
    constexpr index_t CPerBlock = 2;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 32;
    constexpr index_t NPerThread = 2;
    constexpr index_t KPerThread = 4;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;
    constexpr index_t InBlockCopyDataPerRead = 2;
    constexpr index_t WeiBlockCopyDataPerRead = 2;
    constexpr index_t BlockSize = 128;
#elif 0
    // 3x3, 34x34, 128 thread, int8, vector = 4
    constexpr index_t NPerBlock = 2;
    constexpr index_t KPerBlock = 32;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 32;
    constexpr index_t NPerThread = 1;
    constexpr index_t KPerThread = 8;
    constexpr index_t CPerThread = 2;
    constexpr index_t HoPerThread = 4;
    constexpr index_t WoPerThread = 2;
    constexpr index_t InBlockCopyDataPerRead = 2;
    constexpr index_t WeiBlockCopyDataPerRead = 2;
    constexpr index_t BlockSize = 128;
#elif 1
    // 1x1, 32x32, 128 thread, int8, vector = 4
    constexpr index_t NPerBlock = 1;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 16;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 32;
    constexpr index_t NPerThread = 1;
    constexpr index_t KPerThread = 8;
    constexpr index_t CPerThread = 2;
    constexpr index_t HoPerThread = 4;
    constexpr index_t WoPerThread = 2;
    constexpr index_t InBlockCopyDataPerRead = 2;
    constexpr index_t WeiBlockCopyDataPerRead = 2;
    constexpr index_t BlockSize = 128;
#endif

    constexpr index_t GridSize =
        (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);

    printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

    for(index_t i = 0; i < nrepeat; ++i)
    {
        float time = launch_kernel(
            gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<
                TInWei, TOut, accum_t,
                decltype(in_nchw_vec_desc), decltype(wei_kcyx_vec_desc), decltype(out_nkhw_desc),
                NVector,
                NPerBlock, KPerBlock, CPerBlock, HoPerBlock, WoPerBlock,
                NPerThread, KPerThread, CPerThread, HoPerThread, WoPerThread,
                InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, BlockSize, GridSize>,
            dim3(GridSize),
            dim3(BlockSize),
            static_cast<TInWei*>(in_nchw_vec_device_buf.GetDeviceBuffer()),
            static_cast<TInWei*>(wei_kcyx_vec_device_buf.GetDeviceBuffer()),
            static_cast<TInWei*>(out_nkhw_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms\n", time);
        usleep(std::min(time * 1000, float(10000)));
    }

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
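The host-side vectorization step packs four consecutive input channels into one 32-bit lane before uploading, so the kernel reads int8x4 values. A minimal stand-alone sketch of what the NVector = 4 `vector_t::Pack` path amounts to; the byte layout shown is an assumption for illustration, the real `vector_type` lives in the library headers:

// Hypothetical illustration of packing 4 int8 channel values into one int32,
// mirroring the f_vectorized_nchw lambda above (C is split into C/4 groups of 4).
#include <cstdint>
#include <cstdio>

static std::uint32_t pack_int8x4(std::int8_t a, std::int8_t b, std::int8_t c, std::int8_t d)
{
    // assume little-endian byte order: channel 4*c goes into the lowest byte
    return std::uint32_t(std::uint8_t(a)) | (std::uint32_t(std::uint8_t(b)) << 8) |
           (std::uint32_t(std::uint8_t(c)) << 16) | (std::uint32_t(std::uint8_t(d)) << 24);
}

int main()
{
    // in_nchw(n, 4*c .. 4*c+3, h, w) -> one packed element of in_nchw_vec(n, c, h, w)
    std::printf("0x%08x\n", pack_int8x4(1, 2, 3, 4)); // prints 0x04030201
    return 0;
}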
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp  (deleted, 100644 → 0)
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp"

template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(
    InDesc, const Tensor<T>& in_nchw, WeiDesc, const Tensor<T>& wei_kcyx,
    OutDesc, Tensor<T>& out_nkhw, index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcyx_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

    // reorder weight
    auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
    ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");

    Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));

    auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
        wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
    };

    make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
        std::thread::hardware_concurrency());

    // reorder input
    auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
    ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");

    Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));

    auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
        in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
    };

    make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
        std::thread::hardware_concurrency());

    // output
    auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
    ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");

    Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));

    std::size_t data_sz = sizeof(T);
    DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
    DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
    DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());

    in_chwn_device_buf.ToDevice(in_chwn.mData.data());
    wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
    out_khwn_device_buf.ToDevice(out_khwn.mData.data());

#if 0
    // for 3x3, 34x34
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 8;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t InBlockCopy_ThreadPerDimC = 4;
    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
    constexpr index_t InBlockCopyDataPerRead = 4;
    constexpr index_t WeiBlockCopyDataPerRead = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t OutThreadCopyDataPerWrite = 2;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 5x5, 36x36
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 2;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 8;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
    constexpr index_t InBlockCopy_ThreadPerDimC = 2;
    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
    constexpr index_t InBlockCopyDataPerRead = 4;
    constexpr index_t WeiBlockCopyDataPerRead = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t OutThreadCopyDataPerWrite = 2;
    constexpr index_t BlockSize = 128;
#elif 0
    // 3x3 58x58, NKC = 64, 64, 256
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
    constexpr index_t InBlockCopyDataPerRead = 2;
    // not used, yet
    constexpr index_t WeiBlockCopyDataPerRead = 4;
    constexpr index_t BlockSize = 128;
#elif 0
    // 3x3 58x58, NKC = 16,256,128
    constexpr index_t NPerBlock = 8;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 2;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 7x7, 38x38
    constexpr index_t NPerBlock = 8;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 1;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
    constexpr index_t InBlockCopyDataPerRead = 4;
    // not used, yet
    constexpr index_t WeiBlockCopyDataPerRead = 4;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 3x3, 56x56
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 1x1, 28x28
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t InBlockCopy_ThreadPerDimC = 8;
    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
    constexpr index_t InBlockCopyDataPerRead = 4;
    constexpr index_t WeiBlockCopyDataPerRead = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t OutThreadCopyDataPerWrite = 2;
    constexpr index_t BlockSize = 128;
#elif 1
    // for 1x1, 14x14
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t InBlockCopy_ThreadPerDimC = 8;
    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
    constexpr index_t InBlockCopyDataPerRead = 4;
    constexpr index_t WeiBlockCopyDataPerRead = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t OutThreadCopyDataPerWrite = 2;
    constexpr index_t BlockSize = 128;
#endif

    constexpr index_t GridSize = ((N + NPerBlock - 1) / NPerBlock) *
                                 ((K + KPerBlock - 1) / KPerBlock) *
                                 ((Ho + HoPerBlock - 1) / HoPerBlock) *
                                 ((Wo + WoPerBlock - 1) / WoPerBlock);

    printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

    for(index_t i = 0; i < nrepeat; ++i)
    {
        float time = launch_kernel(
            gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn<
                GridSize, BlockSize, T,
                decltype(in_chwn_desc), decltype(wei_cyxk_desc), decltype(out_khwn_desc),
                NPerBlock, KPerBlock, CPerBlock, HoPerBlock, WoPerBlock,
                NPerThread, KPerThread, HoPerThread, WoPerThread,
                Sequence<InBlockCopy_ThreadPerDimC, InBlockCopy_ThreadPerDimH,
                         InBlockCopy_ThreadPerDimW, InBlockCopy_ThreadPerDimN>,
                InBlockCopyDataPerRead, WeiBlockCopyDataPerRead,
                GemmMPerThreadSubC, GemmNPerThreadSubC,
                GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
                GemmKPerThreadLoop, OutThreadCopyDataPerWrite>,
            dim3(GridSize),
            dim3(BlockSize),
            static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
            static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
            static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms\n", time);
        usleep(std::min(time * 1000, float(10000)));
    }

    out_khwn_device_buf.FromDevice(out_khwn.mData.data());

    // reorder output
    auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
        out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
    };

    make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
        std::thread::hardware_concurrency());
}
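Before the kernel runs, the host reorders the layouts: weights KCYX to CYXK and activations NCHW to CHWN, and the KHWN result is transposed back to NKHW afterwards. A small self-contained sketch of the NCHW to CHWN index mapping performed by the `f_reorder_nchw2chwn` lambda; the flat-offset helper and the std::vector storage are illustrative assumptions, not the library's Tensor class:

// Hypothetical plain-loop version of the NCHW -> CHWN reorder done above with
// make_ParallelTensorFunctor: element a[n][c][h][w] lands at b[c][h][w][n].
#include <cstddef>
#include <vector>

void reorder_nchw_to_chwn(const std::vector<float>& in_nchw, std::vector<float>& in_chwn,
                          std::size_t N, std::size_t C, std::size_t H, std::size_t W)
{
    for(std::size_t n = 0; n < N; ++n)
        for(std::size_t c = 0; c < C; ++c)
            for(std::size_t h = 0; h < H; ++h)
                for(std::size_t w = 0; w < W; ++w)
                    in_chwn[((c * H + h) * W + w) * N + n] =
                        in_nchw[((n * C + c) * H + h) * W + w];
}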
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp  (deleted, 100644 → 0)
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp"

template <class T, class InDesc, class WeiDesc, class OutDesc, class LowerPads, class UpperPads>
void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
    InDesc, const Tensor<T>& in_nchw, WeiDesc, const Tensor<T>& wei_kcyx,
    OutDesc, Tensor<T>& out_nkhw, LowerPads, UpperPads, index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcyx_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

    // reorder weight
    auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
    ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");

    Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));

    auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
        wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
    };

    make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
        std::thread::hardware_concurrency());

    // reorder input
    auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
    ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");

    Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));

    auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
        in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
    };

    make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
        std::thread::hardware_concurrency());

    // output
    auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
    ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");

    Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));

    std::size_t data_sz = sizeof(T);
    DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
    DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
    DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());

    in_chwn_device_buf.ToDevice(in_chwn.mData.data());
    wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
    out_khwn_device_buf.ToDevice(out_khwn.mData.data());

#if 0
    constexpr index_t NPerBlock = 1;
    constexpr index_t KPerBlock = 1;
    constexpr index_t CPerBlock = 1;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 1;
    constexpr index_t KPerThread = 1;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t WeiBlockCopyThreadPerDim0 = 1;
    constexpr index_t WeiBlockCopyThreadPerDim1 = 1;
    constexpr index_t BlockSize = 8;
#elif 1
    // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
    constexpr index_t BlockSize = 128;
#elif 0
    // 3x3 58x58, NKC = 16,256,128
    constexpr index_t NPerBlock = 8;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 2;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 5x5, 36x36
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 2;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 7x7, 38x38
    constexpr index_t NPerBlock = 8;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 2;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 3x3, 56x56
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t BlockSize = 128;
#elif 1
    // 3x3 56x56, NKC = 16,256,128, with padding
    // 3x3 28x28, NKC = 16,512,256, with padding
    // 3x3 20x84, NKC = 16,256,256, with padding
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 2;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t WeiBlockCopyThreadPerDim0 = 2;
    constexpr index_t WeiBlockCopyThreadPerDim1 = 64;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 5x5 filter, 20x84 image, 1x1 padding
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 1;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t BlockSize = 128;
#elif 0
    // 5x5 filter, 28x28 image, 2x2 padding
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 32;
    constexpr index_t CPerBlock = 2;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 1x1, 28x28
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 2;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
    constexpr index_t BlockSize = 128;
#endif

    constexpr index_t GridSize = ((N + NPerBlock - 1) / NPerBlock) *
                                 ((K + KPerBlock - 1) / KPerBlock) *
                                 ((Ho + HoPerBlock - 1) / HoPerBlock) *
                                 ((Wo + WoPerBlock - 1) / WoPerBlock);

    printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

    for(index_t i = 0; i < nrepeat; ++i)
    {
        float time = launch_kernel(
            gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded<
                GridSize, BlockSize, T,
                decltype(in_chwn_desc), decltype(wei_cyxk_desc), decltype(out_khwn_desc),
                LowerPads, UpperPads,
                NPerBlock, KPerBlock, CPerBlock, HoPerBlock, WoPerBlock,
                NPerThread, KPerThread, CPerThread, HoPerThread, WoPerThread,
                WeiBlockCopyThreadPerDim0, WeiBlockCopyThreadPerDim1>,
            dim3(GridSize),
            dim3(BlockSize),
            static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
            static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
            static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms\n", time);
        usleep(std::min(time * 1000, float(10000)));
    }

    out_khwn_device_buf.FromDevice(out_khwn.mData.data());

    // reorder output
    auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
        out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
    };

    make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
        std::thread::hardware_concurrency());
}
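Unlike device_direct_convolution_1, this padded driver sizes the grid with a round-up so that partial tiles at the image borders still get a workgroup. The `(X + PerBlock - 1) / PerBlock` idiom is plain integer ceiling division; a tiny illustration with made-up values:

// Integer ceiling division as used in the padded GridSize computation above.
#include <cstdio>

constexpr unsigned ceil_div(unsigned x, unsigned per_block) { return (x + per_block - 1) / per_block; }

int main()
{
    std::printf("%u\n", ceil_div(56, 4)); // 14 : divides evenly
    std::printf("%u\n", ceil_div(58, 4)); // 15 : plain 58/4 would truncate to 14
    return 0;
}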
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp  (modified)

@@ -3,7 +3,6 @@
 #include "device.hpp"
 #include "gridwise_convolution_wrapper.hip.hpp"
 #include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp"
-#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"

 template <class T, class InDesc, class WeiDesc, class OutDesc>
 void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 ...
driver/driver.hip.cpp  (modified)

@@ -7,22 +7,8 @@
 #include "tensor.hpp"
 #include "ConstantTensorDescriptor.hip.hpp"
 #include "conv_common.hip.hpp"
-//#include "device_direct_convolution_1.hpp"
-#include "device_direct_convolution_2_nchw_kcyx_nkhw.hpp"
-//#include "device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp"
-#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp"
-//#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp"
 #include "device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp"

-struct GeneratorTensor_1
-{
-    template <class... Is>
-    double operator()(Is... is)
-    {
-        return 1;
-    }
-};
-
 struct GeneratorTensor_2
 {
     int min_value = 0;

@@ -35,21 +21,6 @@ struct GeneratorTensor_2
     }
 };

-struct GeneratorTensor_Checkboard
-{
-    template <class... Ts>
-    double operator()(Ts... Xs) const
-    {
-        std::array<index_t, sizeof...(Ts)> dims = {{Xs...}};
-        return std::accumulate(dims.begin(),
-                               dims.end(),
-                               true,
-                               [](bool init, index_t x) -> int { return init != (x % 2); })
-                   ? 1
-                   : -1;
-    }
-};
-
 // this is ugly, only for 4d
 template <class TConstTensorDesc>
 void ostream_ConstantTensorDescriptor(TConstTensorDesc, std::ostream& os = std::cout)

@@ -398,201 +369,6 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
 int main(int argc, char* argv[])
 {
-#if 0
-    constexpr index_t N = 1;
-    constexpr index_t C = 1;
-    constexpr index_t HI = 28;
-    constexpr index_t WI = 28;
-    constexpr index_t K = 1;
-    constexpr index_t Y = 3;
-    constexpr index_t X = 3;
-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
-#elif 0
-    // 3x3, 34x34
-    constexpr index_t N = 64;
-    constexpr index_t C = 256;
-    constexpr index_t HI = 34;
-    constexpr index_t WI = 34;
-    constexpr index_t K = 64;
-    constexpr index_t Y = 3;
-    constexpr index_t X = 3;
-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
-#elif 0
-    // 3x3, 56x56
-    constexpr index_t N = 64;
-    constexpr index_t C = 64;
-    constexpr index_t HI = 56;
-    constexpr index_t WI = 56;
-    constexpr index_t K = 64;
-    constexpr index_t Y = 3;
-    constexpr index_t X = 3;
-#elif 0
-    // 3x3, 58x58
-    constexpr index_t N = 64;
-    constexpr index_t C = 64;
-    constexpr index_t HI = 58;
-    constexpr index_t WI = 58;
-    constexpr index_t K = 64;
-    constexpr index_t Y = 3;
-    constexpr index_t X = 3;
-#elif 0
-    // 5x5, 36x36
-    constexpr index_t N = 64;
-    constexpr index_t C = 256;
-    constexpr index_t HI = 36;
-    constexpr index_t WI = 36;
-    constexpr index_t K = 64;
-    constexpr index_t Y = 5;
-    constexpr index_t X = 5;
-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
-#elif 0
-    // 7x7, 38x38
-    constexpr index_t N = 64;
-    constexpr index_t C = 256;
-    constexpr index_t HI = 38;
-    constexpr index_t WI = 38;
-    constexpr index_t K = 64;
-    constexpr index_t Y = 7;
-    constexpr index_t X = 7;
-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
-#elif 0
-    // 3x3, 58x58
-    constexpr index_t N = 16;
-    constexpr index_t C = 128;
-    constexpr index_t HI = 58;
-    constexpr index_t WI = 58;
-    constexpr index_t K = 256;
-    constexpr index_t Y = 3;
-    constexpr index_t X = 3;
-#elif 0
-    // 3x3 filter, 58x58 image, 0x0 padding
-    constexpr index_t N = 16;
-    constexpr index_t C = 128;
-    constexpr index_t HI = 58;
-    constexpr index_t WI = 58;
-    constexpr index_t K = 256;
-    constexpr index_t Y = 3;
-    constexpr index_t X = 3;
-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
-#elif 0
-    // 3x3 filter, 56x56 image, 1x1 padding
-    constexpr index_t N = 16;
-    constexpr index_t C = 128;
-    constexpr index_t HI = 56;
-    constexpr index_t WI = 56;
-    constexpr index_t K = 256;
-    constexpr index_t Y = 3;
-    constexpr index_t X = 3;
-    constexpr index_t HPad = 1;
-    constexpr index_t WPad = 1;
-#elif 0
-    // 3x3 filter, 28x28 image, 1x1 padding
-    constexpr index_t N = 16;
-    constexpr index_t C = 256;
-    constexpr index_t HI = 28;
-    constexpr index_t WI = 28;
-    constexpr index_t K = 512;
-    constexpr index_t Y = 3;
-    constexpr index_t X = 3;
-    constexpr index_t HPad = 1;
-    constexpr index_t WPad = 1;
-#elif 0
-    // 1x1 filter, 28x28 image
-    constexpr index_t N = 16;
-    constexpr index_t C = 256;
-    constexpr index_t HI = 28;
-    constexpr index_t WI = 28;
-    constexpr index_t K = 512;
-    constexpr index_t Y = 1;
-    constexpr index_t X = 1;
-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
-#elif 0
-    // 3x3 filter, 20x84 image, 1x1 padding
-    constexpr index_t N = 16;
-    constexpr index_t C = 256;
-    constexpr index_t HI = 20;
-    constexpr index_t WI = 84;
-    constexpr index_t K = 256;
-    constexpr index_t Y = 3;
-    constexpr index_t X = 3;
-    constexpr index_t HPad = 1;
-    constexpr index_t WPad = 1;
-#elif 0
-    // 3x3 filter, 112x112 image, 1x1 padding
-    constexpr index_t N = 16;
-    constexpr index_t C = 64;
-    constexpr index_t HI = 112;
-    constexpr index_t WI = 112;
-    constexpr index_t K = 128;
-    constexpr index_t Y = 3;
-    constexpr index_t X = 3;
-    constexpr index_t HPad = 1;
-    constexpr index_t WPad = 1;
-#elif 0
-    // 5x5 filter, 20x86 image, 1x1 padding
-    constexpr index_t N = 16;
-    constexpr index_t C = 256;
-    constexpr index_t HI = 20;
-    constexpr index_t WI = 86;
-    constexpr index_t K = 512;
-    constexpr index_t Y = 5;
-    constexpr index_t X = 5;
-    constexpr index_t HPad = 1;
-    constexpr index_t WPad = 1;
-#elif 0
-    // 5x5 filter, 28x28 image, 2x2 padding
-    constexpr index_t N = 16;
-    constexpr index_t C = 192;
-    constexpr index_t HI = 28;
-    constexpr index_t WI = 28;
-    constexpr index_t K = 32;
-    constexpr index_t Y = 5;
-    constexpr index_t X = 5;
-    constexpr index_t HPad = 2;
-    constexpr index_t WPad = 2;
-#elif 0
-    // 1x1 filter, 32x32 image
-    constexpr index_t N = 64;
-    constexpr index_t C = 256;
-    constexpr index_t HI = 32;
-    constexpr index_t WI = 32;
-    constexpr index_t K = 512;
-    constexpr index_t Y = 1;
-    constexpr index_t X = 1;
-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
-#elif 0
-    // 1x1 filter, 14x14 image, C = 2048
-    constexpr index_t N = 128;
-    constexpr index_t C = 2048;
-    constexpr index_t HI = 14;
-    constexpr index_t WI = 14;
-    constexpr index_t K = 512;
-    constexpr index_t Y = 1;
-    constexpr index_t X = 1;
-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
-#elif 1
     // 1x1 filter, 14x14 image, C = 512
     constexpr index_t N = 128;
     constexpr index_t C = 512;
 ...

@@ -604,7 +380,6 @@
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
-#endif

     auto lower_pads = Sequence<HPad, WPad>{};
     auto upper_pads = Sequence<HPad, WPad>{};
 ...

@@ -638,47 +413,12 @@
     if(do_verification)
     {
-#if 0
-        in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-        wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-#elif 1
         in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
         wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-#elif 0
-        in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
-        auto gen_wei = [](auto... is) {
-            return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
-        };
-        wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
-#endif
     }

-#if 1
-#if 0
-    device_direct_convolution_1
-#elif 0
-    device_direct_convolution_2_nchw_kcyx_nkhw
-#elif 0
-    device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
-#elif 0
-    device_implicit_gemm_convolution_1_chwn_cyxk_khwn
-#elif 1
-    device_implicit_gemm_convolution_2_chwn_cyxk_khwn
-#endif
-    (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
-#elif 1
-    device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
-        in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device,
-        lower_pads, upper_pads, nrepeat);
-#endif
+    device_implicit_gemm_convolution_2_chwn_cyxk_khwn(
+        in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);

     if(do_verification)
     {
 ...
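Among the helpers removed from driver.hip.cpp is GeneratorTensor_Checkboard, which folds the index parities with std::accumulate and returns +1 or -1. A small sketch of the equivalent rule, +1 when the sum of the coordinates is even and -1 when it is odd; this reading of the removed lambda is an assumption stated for illustration only:

// Hypothetical restatement of the removed GeneratorTensor_Checkboard:
// value is +1 when the coordinate sum is even, -1 when it is odd,
// so neighbouring elements along any axis alternate sign.
#include <cstdio>

static int checkerboard(int i0, int i1, int i2, int i3)
{
    return ((i0 + i1 + i2 + i3) % 2 == 0) ? 1 : -1;
}

int main()
{
    std::printf("%d %d\n", checkerboard(0, 0, 0, 0), checkerboard(0, 0, 0, 1)); // 1 -1
    return 0;
}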
src/include/blockwise_4d_tensor_op.hip.hpp  (deleted, 100644 → 0)
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"
template
<
index_t
BlockSize
,
class
Float
,
class
DstDesc
,
class
F
>
__device__
void
blockwise_4d_tensor_pointwise_operation_unary
(
DstDesc
,
Float
*
__restrict__
p_dst
,
F
f
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
dst_desc
=
DstDesc
{};
constexpr
auto
desc
=
make_ConstantTensorDescriptor
(
dst_desc
.
GetLengths
());
#if 0
if(threadIdx.x == 0)
{
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
}
#endif
constexpr
index_t
NLoop
=
desc
.
GetElementSize
()
/
BlockSize
;
for
(
index_t
iloop
=
0
;
iloop
<
NLoop
;
++
iloop
)
{
index_t
is
=
threadIdx
.
x
+
iloop
*
BlockSize
;
const
index_t
did0
=
is
/
desc
.
GetStride
(
I0
);
is
-=
did0
*
desc
.
GetStride
(
I0
);
const
index_t
did1
=
is
/
desc
.
GetStride
(
I1
);
is
-=
did1
*
desc
.
GetStride
(
I1
);
const
index_t
did2
=
is
/
desc
.
GetStride
(
I2
);
is
-=
did2
*
desc
.
GetStride
(
I2
);
const
index_t
did3
=
is
/
desc
.
GetStride
(
I3
);
const
index_t
dindex
=
dst_desc
.
Get1dIndex
(
did0
,
did1
,
did2
,
did3
);
f
(
p_dst
[
dindex
]);
}
constexpr
bool
has_tail
=
(
desc
.
GetElementSize
()
>
NLoop
*
BlockSize
);
if
(
has_tail
)
{
index_t
is
=
threadIdx
.
x
+
NLoop
*
BlockSize
;
if
(
is
<
desc
.
GetElementSize
())
{
const
index_t
did0
=
is
/
desc
.
GetStride
(
I0
);
is
-=
did0
*
desc
.
GetStride
(
I0
);
const
index_t
did1
=
is
/
desc
.
GetStride
(
I1
);
is
-=
did1
*
desc
.
GetStride
(
I1
);
const
index_t
did2
=
is
/
desc
.
GetStride
(
I2
);
is
-=
did2
*
desc
.
GetStride
(
I2
);
const
index_t
did3
=
is
/
desc
.
GetStride
(
I3
);
const
index_t
dindex
=
dst_desc
.
Get1dIndex
(
did0
,
did1
,
did2
,
did3
);
f
(
p_dst
[
dindex
]);
}
}
}
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template
<
index_t
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
SrcOpLengths
,
class
DstFromSrcReorder
,
class
F
>
__device__
void
blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src
(
SrcDesc
,
const
Float
*
__restrict__
p_src
,
DstDesc
,
Float
*
__restrict__
p_dst
,
SrcOpLengths
,
DstFromSrcReorder
,
F
f
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
index_t
IR0
=
DstFromSrcReorder
{}.
Get
(
I0
);
constexpr
index_t
IR1
=
DstFromSrcReorder
{}.
Get
(
I1
);
constexpr
index_t
IR2
=
DstFromSrcReorder
{}.
Get
(
I2
);
constexpr
index_t
IR3
=
DstFromSrcReorder
{}.
Get
(
I3
);
constexpr
auto
src_desc
=
SrcDesc
{};
constexpr
auto
dst_desc
=
DstDesc
{};
constexpr
auto
ref_desc
=
make_ConstantTensorDescriptor
(
SrcOpLengths
{});
constexpr
index_t
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
for
(
index_t
iloop
=
0
;
iloop
<
NLoop
;
++
iloop
)
{
index_t
is
=
threadIdx
.
x
+
iloop
*
BlockSize
;
index_t
did
[
4
];
did
[
0
]
=
is
/
ref_desc
.
GetStride
(
I0
);
is
-=
did
[
0
]
*
ref_desc
.
GetStride
(
I0
);
did
[
1
]
=
is
/
ref_desc
.
GetStride
(
I1
);
is
-=
did
[
1
]
*
ref_desc
.
GetStride
(
I1
);
did
[
2
]
=
is
/
ref_desc
.
GetStride
(
I2
);
is
-=
did
[
2
]
*
ref_desc
.
GetStride
(
I2
);
did
[
3
]
=
is
/
ref_desc
.
GetStride
(
I3
);
const
index_t
src_index
=
src_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
]);
const
index_t
dst_index
=
dst_desc
.
Get1dIndex
(
did
[
IR0
],
did
[
IR1
],
did
[
IR2
],
did
[
IR3
]);
f
(
p_src
[
src_index
],
p_dst
[
dst_index
]);
}
constexpr
bool
has_tail
=
(
ref_desc
.
GetElementSize
()
>
NLoop
*
BlockSize
);
if
(
has_tail
)
{
index_t
is
=
threadIdx
.
x
+
NLoop
*
BlockSize
;
if
(
is
<
ref_desc
.
GetElementSize
())
{
index_t
did
[
4
];
did
[
0
]
=
is
/
ref_desc
.
GetStride
(
I0
);
is
-=
did
[
0
]
*
ref_desc
.
GetStride
(
I0
);
did
[
1
]
=
is
/
ref_desc
.
GetStride
(
I1
);
is
-=
did
[
1
]
*
ref_desc
.
GetStride
(
I1
);
did
[
2
]
=
is
/
ref_desc
.
GetStride
(
I2
);
is
-=
did
[
2
]
*
ref_desc
.
GetStride
(
I2
);
did
[
3
]
=
is
/
ref_desc
.
GetStride
(
I3
);
const
index_t
src_index
=
src_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
]);
const
index_t
dst_index
=
dst_desc
.
Get1dIndex
(
did
[
IR0
],
did
[
IR1
],
did
[
IR2
],
did
[
IR3
]);
f
(
p_src
[
src_index
],
p_dst
[
dst_index
]);
}
}
}
template
<
index_t
BlockSize
,
class
Float
,
class
DstDesc
>
__device__
void
blockwise_4d_tensor_set_zero
(
DstDesc
,
Float
*
__restrict__
p_dst
)
{
auto
f_set_zero
=
[](
Float
&
v
)
{
v
=
Float
(
0
);
};
blockwise_4d_tensor_pointwise_operation_unary
<
BlockSize
>
(
DstDesc
{},
p_dst
,
f_set_zero
);
}
template
<
index_t
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
SrcOpLengths
,
class
DstFromSrcReorder
>
__device__
void
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src
(
SrcDesc
,
const
Float
*
__restrict__
p_src
,
DstDesc
,
Float
*
__restrict__
p_dst
,
SrcOpLengths
,
DstFromSrcReorder
)
{
auto
f_copy
=
[](
const
Float
&
src
,
Float
&
dst
)
{
dst
=
src
;
};
blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src
<
BlockSize
>
(
SrcDesc
{},
p_src
,
DstDesc
{},
p_dst
,
SrcOpLengths
{},
DstFromSrcReorder
{},
f_copy
);
}
template
<
index_t
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
CopyLengths
,
index_t
DataPerRead
>
struct
Blockwise4dTensorCopy1
{
using
vector_t
=
typename
vector_type
<
Float
,
DataPerRead
>::
MemoryType
;
__device__
constexpr
Blockwise4dTensorCopy1
()
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
static_assert
(
DataPerRead
==
1
||
(
SrcDesc
{}.
GetStride
(
I3
)
==
1
&&
DstDesc
{}.
GetStride
(
I3
)
==
1
),
"wrong! only support stride3 == 1 if DataPerRead > 1!
\n
"
);
static_assert
(
DataPerRead
==
1
||
DataPerRead
==
2
||
DataPerRead
==
4
,
"wrong! only support DataPerRead == 1, 2 or 4!
\n
"
);
static_assert
(
SrcDesc
{}.
GetStride
(
I2
)
%
DataPerRead
==
0
&&
DstDesc
{}.
GetStride
(
I2
)
%
DataPerRead
==
0
,
"src and dst stride2 should be multiple of DataPerRead to keep alignment"
);
// we allow out-of-bound read from src in D3 dimension,
// but we need to make sure dst stride2 is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr
index_t
L3
=
CopyLengths
{}.
Get
(
I3
);
constexpr
index_t
read_per_d3
=
integer_divide_ceil
(
L3
,
DataPerRead
);
static_assert
(
read_per_d3
*
DataPerRead
<=
DstDesc
{}.
GetStride
(
I2
),
"wrong! out-of-bound write will contaminate next line!
\n
"
);
}
__device__
void
Run
(
const
Float
*
__restrict__
p_src
,
Float
*
__restrict__
p_dst
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
src_desc
=
SrcDesc
{};
constexpr
auto
dst_desc
=
DstDesc
{};
constexpr
index_t
L0
=
CopyLengths
{}.
Get
(
I0
);
constexpr
index_t
L1
=
CopyLengths
{}.
Get
(
I1
);
constexpr
index_t
L2
=
CopyLengths
{}.
Get
(
I2
);
constexpr
index_t
L3
=
CopyLengths
{}.
Get
(
I3
);
constexpr
index_t
read_per_d3
=
integer_divide_ceil
(
L3
,
DataPerRead
);
constexpr
auto
ref_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
L0
,
L1
,
L2
,
read_per_d3
>
{});
constexpr
index_t
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
auto
f_copy
=
[
&
](
index_t
is
)
{
index_t
did
[
4
];
did
[
0
]
=
is
/
ref_desc
.
GetStride
(
I0
);
is
-=
did
[
0
]
*
ref_desc
.
GetStride
(
I0
);
did
[
1
]
=
is
/
ref_desc
.
GetStride
(
I1
);
is
-=
did
[
1
]
*
ref_desc
.
GetStride
(
I1
);
did
[
2
]
=
is
/
ref_desc
.
GetStride
(
I2
);
is
-=
did
[
2
]
*
ref_desc
.
GetStride
(
I2
);
did
[
3
]
=
is
/
ref_desc
.
GetStride
(
I3
);
const
index_t
src_index
=
src_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
]
*
DataPerRead
);
const
index_t
dst_index
=
dst_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
]
*
DataPerRead
);
*
(
reinterpret_cast
<
vector_t
*>
(
p_dst
+
dst_index
))
=
*
(
reinterpret_cast
<
const
vector_t
*>
(
p_src
+
src_index
));
};
for
(
index_t
iloop
=
0
;
iloop
<
NLoop
;
++
iloop
)
{
index_t
is
=
threadIdx
.
x
+
iloop
*
BlockSize
;
f_copy
(
is
);
}
constexpr
bool
has_tail
=
(
ref_desc
.
GetElementSize
()
>
NLoop
*
BlockSize
);
if
(
has_tail
)
{
index_t
is
=
threadIdx
.
x
+
NLoop
*
BlockSize
;
if
(
is
<
ref_desc
.
GetElementSize
())
{
f_copy
(
is
);
}
}
}
};
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class DstOpLengths,
          class GlobalLowerPads>
struct BlockwiseChwnTensorCopyPadded
{
    __device__ void Run(const Float* __restrict__ p_src,
                        index_t c_block_data_begin,
                        index_t ho_block_data_begin,
                        index_t wo_block_data_begin,
                        index_t n_block_data_begin,
                        Float* __restrict__ p_dst,
                        index_t h_block_pad_low,
                        index_t w_block_pad_low,
                        index_t h_block_pad_up,
                        index_t w_block_pad_up) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};
        constexpr auto ref_desc = make_ConstantTensorDescriptor(DstOpLengths{});

        constexpr auto h_global_pad_low = GlobalLowerPads{}.Get(I0);
        constexpr auto w_global_pad_low = GlobalLowerPads{}.Get(I1);

        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

        const Float* p_src_tmp =
            p_src + src_desc.Get1dIndex(c_block_data_begin,
                                        (ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
                                        (wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
                                        n_block_data_begin);

#if 0
        if(get_thread_local_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(src_desc, "src_desc: ");
            print_ConstantTensorDescriptor(dst_desc, "dst_desc: ");
            print_ConstantTensorDescriptor(ref_desc, "ref_desc: ");

            printf("%u %u, \t"
                   "h_global_pad_low %u w_global_pad_low %u \t"
                   "h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u \t"
                   "\n",
                   get_block_1d_id(),
                   get_thread_local_1d_id(),
                   h_global_pad_low,
                   w_global_pad_low,
                   h_block_pad_low,
                   w_block_pad_low,
                   h_block_pad_up,
                   w_block_pad_up);
        }
#endif

        for(index_t iloop = 0; iloop < NLoop; ++iloop)
        {
            index_t is = threadIdx.x + iloop * BlockSize;

            index_t did[4];

            did[0] = is / ref_desc.GetStride(I0);
            is -= did[0] * ref_desc.GetStride(I0);

            did[1] = is / ref_desc.GetStride(I1);
            is -= did[1] * ref_desc.GetStride(I1);

            did[2] = is / ref_desc.GetStride(I2);
            is -= did[2] * ref_desc.GetStride(I2);

            did[3] = is / ref_desc.GetStride(I3);

            const index_t bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);

            p_dst[bindex] =
                (did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
                 did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2))
                    ? Float(0)
                    : p_src_tmp[src_desc.Get1dIndex(did[0], did[1], did[2], did[3])];
        }

        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            index_t is = threadIdx.x + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                index_t did[4];

                did[0] = is / ref_desc.GetStride(I0);
                is -= did[0] * ref_desc.GetStride(I0);

                did[1] = is / ref_desc.GetStride(I1);
                is -= did[1] * ref_desc.GetStride(I1);

                did[2] = is / ref_desc.GetStride(I2);
                is -= did[2] * ref_desc.GetStride(I2);

                did[3] = is / ref_desc.GetStride(I3);

                const index_t bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);

                p_dst[bindex] =
                    (did[1] < h_block_pad_low ||
                     did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
                     did[2] < w_block_pad_low ||
                     did[2] + w_block_pad_up >= ref_desc.GetLength(I2))
                        ? Float(0)
                        : p_src_tmp[src_desc.Get1dIndex(did[0], did[1], did[2], did[3])];
            }
        }
    }
};
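The predicate in BlockwiseChwnTensorCopyPadded zero-fills any destination element whose row falls in the first h_block_pad_low or last h_block_pad_up rows of the block (and likewise for columns), and copies from the shifted source pointer otherwise. A minimal host-side sketch of that predicate along one axis, with assumed length and pads:

// One-axis sketch of the zero-fill predicate; len and pads are assumptions.
#include <cstdio>
int main()
{
    const int len = 8, pad_low = 2, pad_up = 1;
    for(int i = 0; i < len; ++i)
    {
        const bool is_pad = (i < pad_low) || (i + pad_up >= len);
        std::printf("%d:%s ", i, is_pad ? "zero" : "copy");
    }
    std::printf("\n"); // positions 0-1 and 7 are zero-filled, 2-6 are copied
    return 0;
}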
// the starting point needs to be aligned to float4, float2 or float
// stride3 needs to be 1 for both source and destination
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          class ThreadPerDims,
          index_t DataPerRead>
struct Blockwise4dTensorCopy3
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;

    __device__ Blockwise4dTensorCopy3()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1),
                      "wrong! only support stride3 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I2) % DataPerRead == 0,
                      "wrong! src and dst stride2 should be multiple of DataPerRead to keep alignment");

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);
        constexpr index_t L3 = CopyLengths{}.Get(I3);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

        // we allow out-of-bound read from src in the D3 dimension,
        // but we need to make sure the dst stride is big enough,
        // so that the out-of-bound write won't contaminate the next line in dst
        constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

        static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
                      "wrong! out-of-bound write will contaminate next line!\n");

        static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0 &&
                          L2 % thread_per_d2 == 0,
                      "wrong! L0, L1, L2 should be divided evenly!\n");

        static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3,
                      "wrong! BlockSize is not big enough for ThreadPerDims!");

        constexpr index_t num_active_thread =
            thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        const index_t thread_id_d0 =
            get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
        index_t itmp = get_thread_local_1d_id() -
                       thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
        const index_t thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
        itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
        const index_t thread_id_d2 = itmp / thread_per_d3;
        const index_t thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;

        mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
            thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);

        mDstMyThreadOffset = DstDesc{}.Get1dIndex(
            thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);
        constexpr index_t L3 = CopyLengths{}.Get(I3);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

        constexpr index_t num_active_thread =
            thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = L2 / thread_per_d2;
        constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

#pragma unroll
        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
        {
#pragma unroll
            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
            {
#pragma unroll
                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
                {
#pragma unroll
                    for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
                    {
                        const index_t src_offset =
                            SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
                                                 iloop_d1 * thread_per_d1,
                                                 iloop_d2 * thread_per_d2,
                                                 iloop_d3 * thread_per_d3 * DataPerRead);

                        const index_t dst_offset =
                            DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
                                                 iloop_d1 * thread_per_d1,
                                                 iloop_d2 * thread_per_d2,
                                                 iloop_d3 * thread_per_d3 * DataPerRead);

                        *(reinterpret_cast<vector_t*>(p_dst + dst_offset + mDstMyThreadOffset)) =
                            *(reinterpret_cast<const vector_t*>(p_src + src_offset +
                                                                mSrcMyThreadOffset));
                    }
                }
            }
        }
    }
};
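The D3 handling above deliberately allows a partial last vector per line, so the static_assert must guarantee that the over-write stays inside the destination's padded line. A worked example of that guard with assumed numbers (not taken from the driver), using a local stand-in for the codebase's integer_divide_ceil:

// Worked example of the D3 guard; all values here are assumptions.
#include <cassert>
constexpr unsigned divide_ceil(unsigned a, unsigned b) { return (a + b - 1) / b; }
int main()
{
    constexpr unsigned L3 = 30, thread_per_d3 = 4, DataPerRead = 4;
    constexpr unsigned nloop_d3 = divide_ceil(L3, thread_per_d3 * DataPerRead); // = 2
    constexpr unsigned dst_stride2 = 32; // assumed padded line length of the destination
    static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= dst_stride2,
                  "an out-of-bound write would spill into the next line");
    assert(nloop_d3 == 2); // 32 elements touched per line, 2 of them beyond L3 = 30
    return 0;
}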
src/include/blockwise_batched_gemm.hip.hpp deleted 100644 → 0 View file @ 6166233e
#pragma once
#include "threadwise_gemm.hip.hpp"

template <index_t BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
          index_t BlockMatrixStrideA,
          index_t BlockMatrixStrideB,
          index_t ThreadMatrixStrideC,
          index_t BatchSize,
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t KPerThreadLoop,
          index_t BatchPerThread>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
    index_t mMyThreadOffsetA = 0;
    index_t mMyThreadOffsetB = 0;

    struct MatrixIndex
    {
        index_t batch;
        index_t row;
        index_t col;
    };

    __device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
    {
        static_assert(BatchSize % BatchPerThread == 0,
                      "wrong! BatchSize is not divisible by BatchPerThread");

        constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;

        constexpr index_t ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
                      "wrong! wrong blocksize\n");

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = a_block_mtx.NCol(); // A is transposed
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
                      "wrong! Cannot evenly divide thread work among repeat\n");

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
                      "wrong! Cannot evenly divide work among repeat\n");

        constexpr index_t MPerLevel1Cluster = M / MRepeat;
        constexpr index_t NPerLevel1Cluster = N / NRepeat;

        static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
                          (NPerLevel1Cluster % NLevel1Cluster == 0),
                      "wrong! Cannot evenly divide work among Level1Cluster\n");

        constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
        constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;

        static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
                          (NPerLevel0Cluster % NLevel0Cluster == 0),
                      "wrong! Cannot evenly divide work among Level0Cluster\n");

        static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
                          (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
                      "wrong! thread work size is wrong\n");

        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
                           a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);

        mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
                           b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
            print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
            print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");

            printf("%u %u, %u %u %u, %u %u\n",
                   get_block_1d_id(),
                   get_thread_local_1d_id(),
                   c_thread_mtx_index.batch,
                   c_thread_mtx_index.row,
                   c_thread_mtx_index.col,
                   mMyThreadOffsetA,
                   mMyThreadOffsetB);
        }
#endif
    }

    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
    {
        constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;

        constexpr index_t ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;

        index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
        index_t cluster_id    = thread_id - batch_work_id * ThreadPerLevel1Cluster;

        index_t level1_id   = cluster_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1Cluster;
        index_t level1_n_id = level1_id % NLevel1Cluster;

        index_t level0_id   = cluster_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0Cluster;
        index_t level0_n_id = level0_id % NLevel0Cluster;

        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

        return MatrixIndex{batch_work_id * BatchPerThread,
                           level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

    // this should be optimized away if input is known
    __device__ static MatrixIndex
    GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
    {
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        index_t m_repeat = m_in_c / MPerThreadSubC;
        index_t n_repeat = n_in_c / NPerThreadSubC;

        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
        index_t n_in_sub_c = n_in_c % NPerThreadSubC;

        return MatrixIndex{batch_in_c,
                           m_repeat * MPerLevel1Cluster + m_in_sub_c,
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }

    template <class FloatA, class FloatB, class FloatC, class Accumulator>
    __device__ void Run(const FloatA* __restrict__ p_a_block,
                        const FloatB* __restrict__ p_b_block,
                        FloatC* __restrict__ p_c_thread,
                        Accumulator f_accum) const
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        // A is transposed, B is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        // loop over k
#pragma unroll
        for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
        {
            // read first batch of A, B
            // copy A-sub to form A
#pragma unroll
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
            {
                threadwise_matrix_copy(
                    a_block_mtx,
                    p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
                        mMyThreadOffsetA,
                    a_thread_mtx,
                    p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
                    a_thread_sub_mtx.GetLengths());
            }

            // copy B-sub to form B
#pragma unroll
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
                threadwise_matrix_copy(
                    b_block_mtx,
                    p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
                        mMyThreadOffsetB,
                    b_thread_mtx,
                    p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
                    b_thread_sub_mtx.GetLengths());
            }

            // loop over batch
#pragma unroll
            for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
            {
                // do current batch of gemm
                threadwise_gemm(a_thread_mtx,
                                True,
                                p_a_thread,
                                b_thread_mtx,
                                False,
                                p_b_thread,
                                c_thread_mtx,
                                False,
                                p_c_thread + ib * ThreadMatrixStrideC,
                                f_accum);

                // read next batch of A, B
                if(BlockMatrixStrideA != 0)
                {
#pragma unroll
                    for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
                    {
                        threadwise_matrix_copy(
                            a_block_mtx,
                            p_a_block +
                                a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
                                (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA,
                            a_thread_mtx,
                            p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
                            a_thread_sub_mtx.GetLengths());
                    }
                }

                if(BlockMatrixStrideB != 0)
                {
#pragma unroll
                    for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
                    {
                        threadwise_matrix_copy(
                            b_block_mtx,
                            p_b_block +
                                b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
                                (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB,
                            b_thread_mtx,
                            p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
                            b_thread_sub_mtx.GetLengths());
                    }
                }
            }

            // do last batch of gemm
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread,
                            b_thread_mtx,
                            False,
                            p_b_thread,
                            c_thread_mtx,
                            False,
                            p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC,
                            f_accum);
        }
    }

    template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
    __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
                                                    FloatC* __restrict__ p_c_block) const
    {
        constexpr auto c_block_mtx  = BlockMatrixC{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t c_thread_offset =
            c_thread_mtx_begin.batch * BlockMatrixStrideC +
            c_block_mtx.Get1dIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);

        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
        {
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
                threadwise_matrix_copy(
                    c_thread_sub_mtx,
                    p_c_thread + c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
                                                             n_repeat * NPerLevel1Cluster),
                    c_block_mtx,
                    p_c_block + c_block_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
                                                       n_repeat * NPerLevel1Cluster) +
                        c_thread_offset,
                    c_thread_sub_mtx.GetLengths());
            }
        }
    }
};
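GetBeginOfThreadMatrixC splits the 1-d thread id into a batch group, a level-1 cluster position and a level-0 cluster position, and turns those into the (batch, row, col) origin of the thread's C sub-tile. Below is a host-side sketch of that mapping with small assumed cluster sizes (2x2 clusters at both levels, 2x2 sub-tiles, one batch per thread); none of these numbers come from the driver, they are only for illustration:

// Host-side sketch of the thread-id -> C-tile origin mapping (all sizes assumed).
#include <cstdio>
struct MatrixIndex { unsigned batch, row, col; };
int main()
{
    const unsigned MPerThreadSubC = 2, NPerThreadSubC = 2;
    const unsigned MLevel0Cluster = 2, NLevel0Cluster = 2, MLevel1Cluster = 2, NLevel1Cluster = 2;
    const unsigned BatchPerThread = 1;
    const unsigned ThreadPerLevel1Cluster =
        MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster; // 16
    const unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster; // 4
    const unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;      // 4
    const unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;      // 4

    for(unsigned tid = 0; tid < 2 * ThreadPerLevel1Cluster; ++tid)
    {
        unsigned batch_work_id = tid / ThreadPerLevel1Cluster;
        unsigned cluster_id    = tid - batch_work_id * ThreadPerLevel1Cluster;
        unsigned level1_id     = cluster_id / ThreadPerLevel0Cluster;
        unsigned level0_id     = cluster_id % ThreadPerLevel0Cluster;
        MatrixIndex c{batch_work_id * BatchPerThread,
                      (level1_id / NLevel1Cluster) * MPerLevel0Cluster +
                          (level0_id / NLevel0Cluster) * MPerThreadSubC,
                      (level1_id % NLevel1Cluster) * NPerLevel0Cluster +
                          (level0_id % NLevel0Cluster) * NPerThreadSubC};
        std::printf("tid %2u -> batch %u, row %2u, col %2u\n", tid, c.batch, c.row, c.col);
    }
    return 0;
}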
src/include/blockwise_direct_convolution.hip.hpp deleted 100644 → 0 View file @ 6166233e
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "threadwise_direct_convolution.hip.hpp"

template <index_t BlockSize,
          class Float,
          class InBlockDesc,
          class WeiBlockDesc,
          class OutBlockDesc,
          index_t NPerThread,
          index_t KPerThread,
          index_t CPerThread,
          index_t HoPerThread,
          index_t WoPerThread>
__device__ void blockwise_direct_convolution(InBlockDesc,
                                             Float* const __restrict__ p_in_block,
                                             WeiBlockDesc,
                                             Float* const __restrict__ p_wei_block,
                                             OutBlockDesc,
                                             Float* __restrict__ p_out_block)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_block_desc  = InBlockDesc{};
    constexpr auto wei_block_desc = WeiBlockDesc{};
    constexpr auto out_block_desc = OutBlockDesc{};

    constexpr index_t Y = wei_block_desc.GetLength(I2);
    constexpr index_t X = wei_block_desc.GetLength(I3);

    constexpr index_t InTileSizeH = HoPerThread + Y - 1;
    constexpr index_t InTileSizeW = WoPerThread + X - 1;

    // divide thread work
    constexpr index_t NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread;
    constexpr index_t KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread;
    constexpr index_t YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread;
    constexpr index_t XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread;

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(in_block_desc);
        print_ConstantTensorDescriptor(wei_block_desc);
        print_ConstantTensorDescriptor(out_block_desc);
    }
#endif

    constexpr auto in_thread_desc = make_ConstantTensorDescriptor(
        Sequence<NPerThread, CPerThread, InTileSizeH, InTileSizeW>{});

    constexpr auto wei_thread_desc =
        make_ConstantTensorDescriptor(Sequence<KPerThread, CPerThread, Y, X>{});

    constexpr auto out_thread_desc =
        get_convolution_output_default_4d_tensor_descriptor(in_thread_desc, wei_thread_desc);

    constexpr auto in_thread_block_desc =
        make_ConstantTensorDescriptor(in_thread_desc.GetLengths(), in_block_desc.GetStrides());

    constexpr auto wei_thread_block_desc =
        make_ConstantTensorDescriptor(wei_thread_desc.GetLengths(), wei_block_desc.GetStrides());

    constexpr auto out_thread_block_desc =
        make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides());

    const index_t thread_id = threadIdx.x;

    for(index_t thread_work_id = thread_id;
        thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork;
        thread_work_id += BlockSize)
    {
        index_t itmp             = thread_work_id;
        index_t n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
        itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork);
        index_t k_thread_work_id = itmp / (YThreadWork * XThreadWork);
        itmp -= k_thread_work_id * (YThreadWork * XThreadWork);
        index_t y_thread_work_id = itmp / XThreadWork;
        index_t x_thread_work_id = itmp - y_thread_work_id * XThreadWork;

        index_t n_thread_data_begin  = n_thread_work_id * NPerThread;
        index_t k_thread_data_begin  = k_thread_work_id * KPerThread;
        index_t ho_thread_data_begin = y_thread_work_id * HoPerThread;
        index_t wo_thread_data_begin = x_thread_work_id * WoPerThread;

        index_t hi_thread_data_begin = ho_thread_data_begin; // minus padding
        index_t wi_thread_data_begin = wo_thread_data_begin; // minus padding

        Float p_out_thread[out_thread_desc.GetElementSpace()];

        threadwise_4d_tensor_copy(out_block_desc,
                                  p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin,
                                                                          k_thread_data_begin,
                                                                          ho_thread_data_begin,
                                                                          wo_thread_data_begin),
                                  out_thread_desc,
                                  p_out_thread,
                                  out_thread_desc.GetLengths());

        for(index_t c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1);
            c_thread_data_begin += CPerThread)
        {
            // threadwise convolution
            threadwise_direct_convolution_2(
                in_thread_block_desc,
                p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
                                                      c_thread_data_begin,
                                                      hi_thread_data_begin,
                                                      wi_thread_data_begin),
                wei_thread_block_desc,
                p_wei_block +
                    wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data_begin, 0, 0),
                out_thread_desc,
                p_out_thread);
        }

        // copy output into LDS
        threadwise_4d_tensor_copy(out_thread_desc,
                                  p_out_thread,
                                  out_block_desc,
                                  p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin,
                                                                          k_thread_data_begin,
                                                                          ho_thread_data_begin,
                                                                          wo_thread_data_begin),
                                  out_thread_desc.GetLengths());
    }
}
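The per-thread input tile above is sized so that a valid (unpadded) Y x X convolution over it yields exactly HoPerThread x WoPerThread outputs, hence InTileSizeH = HoPerThread + Y - 1 and likewise for the width. A tiny self-check with assumed sizes, not taken from the driver:

// Tile-size relation for a valid convolution; the sizes are assumptions.
int main()
{
    const int HoPerThread = 4, WoPerThread = 4, Y = 3, X = 3;
    const int InTileSizeH = HoPerThread + Y - 1; // 6 input rows per thread tile
    const int InTileSizeW = WoPerThread + X - 1; // 6 input columns per thread tile
    return (InTileSizeH - Y + 1 == HoPerThread && InTileSizeW - X + 1 == WoPerThread) ? 0 : 1;
}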
src/include/blockwise_gemm.hip.hpp View file @ 569941a7
@@ -123,599 +123,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }

    template <class FloatA, class FloatB, class FloatC, class Accumulator>
    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
                            const FloatB* __restrict__ p_b_block,
                            FloatC* __restrict__ p_c_thread,
                            Accumulator f_accum) const
    {
#if DEVICE_BACKEND_HIP
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()];

        FloatA* p_a_thread = p_thread;
        FloatB* p_b_thread = p_thread + a_thread_mtx.GetElementSpace();

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

#pragma unroll
        // loop over k
        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
        {
#if 1
            auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
            auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;

            const float4* a_loc = (const float4*)(p_a_block + a_src_index);
            const float4* b_loc = (const float4*)(p_b_block + b_src_index);
            float4* reg         = (float4*)(p_thread);

            reg[0] = a_loc[0];
            reg[1] = a_loc[16];
            reg[2] = b_loc[0];
            reg[3] = b_loc[8];

            //asm volatile("\n \
            //ds_read2_b64 %0, %1 offset1:1 \n \
            //s_waitcnt lgkmcnt(0)"
            //: "=v"(reg[0])
            //: "v"(__to_local((void *)(a_loc)))
            //);
            //asm volatile("\n \
            //ds_read2_b64 %0, %1 offset1:1 \n \
            //s_waitcnt lgkmcnt(0)"
            //: "=v"(reg[1])
            //: "v"(__to_local((void *)(a_loc + 16)))
            //);
            //asm volatile("\n \
            //ds_read2_b64 %0, %1 offset1:1 \n \
            //s_waitcnt lgkmcnt(0)"
            //: "=v"(reg[2])
            //: "v"(__to_local((void *)(b_loc)))
            //);
            //asm volatile("\n \
            //ds_read2_b64 %0, %1 offset1:1 \n \
            //s_waitcnt lgkmcnt(0)"
            //: "=v"(reg[3])
            //: "v"(__to_local((void *)(b_loc + 8)))
            //);
            //asm volatile("\n \
            //ds_read2_b64 %0, %4 offset1:1 \n \
            //ds_read2_b64 %1, %4 offset0:32 offset1:33 \n \
            //ds_read2_b64 %2, %5 offset1:1 \n \
            //ds_read2_b64 %3, %5 offset0:16 offset1:17 \n \
            //s_waitcnt lgkmcnt(0)"
            //: "=v"(reg[0]), "=v"(reg[1]), "=v"(reg[2]), "=v"(reg[3])
            //: "v"(__to_local((void *)(a_loc))), "v"(__to_local((void *)(b_loc)))
            //);
            //asm volatile("\n \
            //ds_read_b32 %0, %16 \n  ds_read_b32 %1, %16 offset:1\n  ds_read_b32 %2, %16 offset:2\n  ds_read_b32 %3, %16 offset:3\n \
            //ds_read_b32 %4, %17 \n  ds_read_b32 %5, %17 offset:1\n  ds_read_b32 %6, %17 offset:2\n  ds_read_b32 %7, %17 offset:3\n \
            //ds_read_b32 %8, %18 \n  ds_read_b32 %9, %18 offset:1\n  ds_read_b32 %10, %18 offset:2\n  ds_read_b32 %11, %18 offset:3\n \
            //ds_read_b32 %12, %19 \n  ds_read_b32 %13, %19 offset:1\n  ds_read_b32 %14, %19 offset:2\n  ds_read_b32 %15, %19 offset:3\n \
            //s_waitcnt lgkmcnt(0)"
            //: "=v"(p_a_thread[0]), "=v"(p_a_thread[1]), "=v"(p_a_thread[2]), "=v"(p_a_thread[3]),
            //  "=v"(p_a_thread[4]), "=v"(p_a_thread[5]), "=v"(p_a_thread[6]), "=v"(p_a_thread[7]),
            //  "=v"(p_b_thread[0]), "=v"(p_b_thread[1]), "=v"(p_b_thread[2]), "=v"(p_b_thread[3]),
            //  "=v"(p_b_thread[4]), "=v"(p_b_thread[5]), "=v"(p_b_thread[6]), "=v"(p_b_thread[7])
            //: "v"(__to_local((void *)(&p_a_block[0]))), "v"(__to_local((void *)(&p_a_block[64]))),
            //  "v"(__to_local((void *)(&p_b_block[0]))), "v"(__to_local((void *)(&p_b_block[32])))
            //);

            // C = A * B
            asm volatile("\n \
                v_mac_f32 %0, %64, %72 \n  v_mac_f32 %1, %64, %73 \n  v_mac_f32 %2, %64, %74 \n  v_mac_f32 %3, %64, %75 \n \
                v_mac_f32 %4, %64, %76 \n  v_mac_f32 %5, %64, %77 \n  v_mac_f32 %6, %64, %78 \n  v_mac_f32 %7, %64, %79 \n \
                v_mac_f32 %8, %65, %72 \n  v_mac_f32 %9, %65, %73 \n  v_mac_f32 %10, %65, %74 \n  v_mac_f32 %11, %65, %75 \n \
                v_mac_f32 %12, %65, %76 \n  v_mac_f32 %13, %65, %77 \n  v_mac_f32 %14, %65, %78 \n  v_mac_f32 %15, %65, %79 \n \
                v_mac_f32 %16, %66, %72 \n  v_mac_f32 %17, %66, %73 \n  v_mac_f32 %18, %66, %74 \n  v_mac_f32 %19, %66, %75 \n \
                v_mac_f32 %20, %66, %76 \n  v_mac_f32 %21, %66, %77 \n  v_mac_f32 %22, %66, %78 \n  v_mac_f32 %23, %66, %79 \n \
                v_mac_f32 %24, %67, %72 \n  v_mac_f32 %25, %67, %73 \n  v_mac_f32 %26, %67, %74 \n  v_mac_f32 %27, %67, %75 \n \
                v_mac_f32 %28, %67, %76 \n  v_mac_f32 %29, %67, %77 \n  v_mac_f32 %30, %67, %78 \n  v_mac_f32 %31, %67, %79 \n \
                v_mac_f32 %32, %68, %72 \n  v_mac_f32 %33, %68, %73 \n  v_mac_f32 %34, %68, %74 \n  v_mac_f32 %35, %68, %75 \n \
                v_mac_f32 %36, %68, %76 \n  v_mac_f32 %37, %68, %77 \n  v_mac_f32 %38, %68, %78 \n  v_mac_f32 %39, %68, %79 \n \
                v_mac_f32 %40, %69, %72 \n  v_mac_f32 %41, %69, %73 \n  v_mac_f32 %42, %69, %74 \n  v_mac_f32 %43, %69, %75 \n \
                v_mac_f32 %44, %69, %76 \n  v_mac_f32 %45, %69, %77 \n  v_mac_f32 %46, %69, %78 \n  v_mac_f32 %47, %69, %79 \n \
                v_mac_f32 %48, %70, %72 \n  v_mac_f32 %49, %70, %73 \n  v_mac_f32 %50, %70, %74 \n  v_mac_f32 %51, %70, %75 \n \
                v_mac_f32 %52, %70, %76 \n  v_mac_f32 %53, %70, %77 \n  v_mac_f32 %54, %70, %78 \n  v_mac_f32 %55, %70, %79 \n \
                v_mac_f32 %56, %71, %72 \n  v_mac_f32 %57, %71, %73 \n  v_mac_f32 %58, %71, %74 \n  v_mac_f32 %59, %71, %75 \n \
                v_mac_f32 %60, %71, %76 \n  v_mac_f32 %61, %71, %77 \n  v_mac_f32 %62, %71, %78 \n  v_mac_f32 %63, %71, %79 \n \
                "
                : "=v"(p_c_thread[0]), "=v"(p_c_thread[1]), "=v"(p_c_thread[2]), "=v"(p_c_thread[3]),
                  "=v"(p_c_thread[4]), "=v"(p_c_thread[5]), "=v"(p_c_thread[6]), "=v"(p_c_thread[7]),
                  "=v"(p_c_thread[8]), "=v"(p_c_thread[9]), "=v"(p_c_thread[10]), "=v"(p_c_thread[11]),
                  "=v"(p_c_thread[12]), "=v"(p_c_thread[13]), "=v"(p_c_thread[14]), "=v"(p_c_thread[15]),
                  "=v"(p_c_thread[16]), "=v"(p_c_thread[17]), "=v"(p_c_thread[18]), "=v"(p_c_thread[19]),
                  "=v"(p_c_thread[20]), "=v"(p_c_thread[21]), "=v"(p_c_thread[22]), "=v"(p_c_thread[23]),
                  "=v"(p_c_thread[24]), "=v"(p_c_thread[25]), "=v"(p_c_thread[26]), "=v"(p_c_thread[27]),
                  "=v"(p_c_thread[28]), "=v"(p_c_thread[29]), "=v"(p_c_thread[30]), "=v"(p_c_thread[31]),
                  "=v"(p_c_thread[32]), "=v"(p_c_thread[33]), "=v"(p_c_thread[34]), "=v"(p_c_thread[35]),
                  "=v"(p_c_thread[36]), "=v"(p_c_thread[37]), "=v"(p_c_thread[38]), "=v"(p_c_thread[39]),
                  "=v"(p_c_thread[40]), "=v"(p_c_thread[41]), "=v"(p_c_thread[42]), "=v"(p_c_thread[43]),
                  "=v"(p_c_thread[44]), "=v"(p_c_thread[45]), "=v"(p_c_thread[46]), "=v"(p_c_thread[47]),
                  "=v"(p_c_thread[48]), "=v"(p_c_thread[49]), "=v"(p_c_thread[50]), "=v"(p_c_thread[51]),
                  "=v"(p_c_thread[52]), "=v"(p_c_thread[53]), "=v"(p_c_thread[54]), "=v"(p_c_thread[55]),
                  "=v"(p_c_thread[56]), "=v"(p_c_thread[57]), "=v"(p_c_thread[58]), "=v"(p_c_thread[59]),
                  "=v"(p_c_thread[60]), "=v"(p_c_thread[61]), "=v"(p_c_thread[62]), "=v"(p_c_thread[63])
                : "v"(p_a_thread[0]), "v"(p_a_thread[1]), "v"(p_a_thread[2]), "v"(p_a_thread[3]),
                  "v"(p_a_thread[4]), "v"(p_a_thread[5]), "v"(p_a_thread[6]), "v"(p_a_thread[7]),
                  "v"(p_b_thread[0]), "v"(p_b_thread[1]), "v"(p_b_thread[2]), "v"(p_b_thread[3]),
                  "v"(p_b_thread[4]), "v"(p_b_thread[5]), "v"(p_b_thread[6]), "v"(p_b_thread[7]),
                  "0"(p_c_thread[0]), "1"(p_c_thread[1]), "2"(p_c_thread[2]), "3"(p_c_thread[3]),
                  "4"(p_c_thread[4]), "5"(p_c_thread[5]), "6"(p_c_thread[6]), "7"(p_c_thread[7]),
                  "8"(p_c_thread[8]), "9"(p_c_thread[9]), "10"(p_c_thread[10]), "11"(p_c_thread[11]),
                  "12"(p_c_thread[12]), "13"(p_c_thread[13]), "14"(p_c_thread[14]), "15"(p_c_thread[15]),
                  "16"(p_c_thread[16]), "17"(p_c_thread[17]), "18"(p_c_thread[18]), "19"(p_c_thread[19]),
                  "20"(p_c_thread[20]), "21"(p_c_thread[21]), "22"(p_c_thread[22]), "23"(p_c_thread[23]),
                  "24"(p_c_thread[24]), "25"(p_c_thread[25]), "26"(p_c_thread[26]), "27"(p_c_thread[27]),
                  "28"(p_c_thread[28]), "29"(p_c_thread[29]), "30"(p_c_thread[30]), "31"(p_c_thread[31]),
                  "32"(p_c_thread[32]), "33"(p_c_thread[33]), "34"(p_c_thread[34]), "35"(p_c_thread[35]),
                  "36"(p_c_thread[36]), "37"(p_c_thread[37]), "38"(p_c_thread[38]), "39"(p_c_thread[39]),
                  "40"(p_c_thread[40]), "41"(p_c_thread[41]), "42"(p_c_thread[42]), "43"(p_c_thread[43]),
                  "44"(p_c_thread[44]), "45"(p_c_thread[45]), "46"(p_c_thread[46]), "47"(p_c_thread[47]),
                  "48"(p_c_thread[48]), "49"(p_c_thread[49]), "50"(p_c_thread[50]), "51"(p_c_thread[51]),
                  "52"(p_c_thread[52]), "53"(p_c_thread[53]), "54"(p_c_thread[54]), "55"(p_c_thread[55]),
                  "56"(p_c_thread[56]), "57"(p_c_thread[57]), "58"(p_c_thread[58]), "59"(p_c_thread[59]),
                  "60"(p_c_thread[60]), "61"(p_c_thread[61]), "62"(p_c_thread[62]), "63"(p_c_thread[63]));
#else
            auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
            auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
            auto dst_index   = a_thread_sub_mtx.Get1dIndex(0, 0);

            const float4* a_loc = (const float4*)(p_a_block + a_src_index);
            const float4* b_loc = (const float4*)(p_b_block + b_src_index);
            float4* reg         = (float4*)(p_a_thread + dst_index);

            asm volatile("\n \
                ds_read2_b64 %0, %84 offset1:1 \n \
                ds_read2_b64 %1, %84 offset0:32 offset1:33 \n \
                ds_read2_b64 %2, %85 offset1:1 \n \
                ds_read2_b64 %3, %85 offset0:16 offset1:17 \n \
                s_waitcnt lgkmcnt(0) \n \
                v_mac_f32 %4, %68, %76 \n  v_mac_f32 %5, %68, %77 \n  v_mac_f32 %6, %68, %78 \n  v_mac_f32 %7, %68, %79 \n \
                v_mac_f32 %8, %68, %80 \n  v_mac_f32 %9, %68, %81 \n  v_mac_f32 %10, %68, %82 \n  v_mac_f32 %11, %68, %83 \n \
                v_mac_f32 %12, %69, %76 \n  v_mac_f32 %13, %69, %77 \n  v_mac_f32 %14, %69, %78 \n  v_mac_f32 %15, %69, %79 \n \
                v_mac_f32 %16, %69, %80 \n  v_mac_f32 %17, %69, %81 \n  v_mac_f32 %18, %69, %82 \n  v_mac_f32 %19, %69, %83 \n \
                v_mac_f32 %20, %70, %76 \n  v_mac_f32 %21, %70, %77 \n  v_mac_f32 %22, %70, %78 \n  v_mac_f32 %23, %70, %79 \n \
                v_mac_f32 %24, %70, %80 \n  v_mac_f32 %25, %70, %81 \n  v_mac_f32 %26, %70, %82 \n  v_mac_f32 %27, %70, %83 \n \
                v_mac_f32 %28, %71, %76 \n  v_mac_f32 %29, %71, %77 \n  v_mac_f32 %30, %71, %78 \n  v_mac_f32 %31, %71, %79 \n \
                v_mac_f32 %32, %71, %80 \n  v_mac_f32 %33, %71, %81 \n  v_mac_f32 %34, %71, %82 \n  v_mac_f32 %35, %71, %83 \n \
                v_mac_f32 %36, %72, %76 \n  v_mac_f32 %37, %72, %77 \n  v_mac_f32 %38, %72, %78 \n  v_mac_f32 %39, %72, %79 \n \
                v_mac_f32 %40, %72, %80 \n  v_mac_f32 %41, %72, %81 \n  v_mac_f32 %42, %72, %82 \n  v_mac_f32 %43, %72, %83 \n \
                v_mac_f32 %44, %73, %76 \n  v_mac_f32 %45, %73, %77 \n  v_mac_f32 %46, %73, %78 \n  v_mac_f32 %47, %73, %79 \n \
                v_mac_f32 %48, %73, %80 \n  v_mac_f32 %49, %73, %81 \n  v_mac_f32 %50, %73, %82 \n  v_mac_f32 %51, %73, %83 \n \
                v_mac_f32 %52, %74, %76 \n  v_mac_f32 %53, %74, %77 \n  v_mac_f32 %54, %74, %78 \n  v_mac_f32 %55, %74, %79 \n \
                v_mac_f32 %56, %74, %80 \n  v_mac_f32 %57, %74, %81 \n  v_mac_f32 %58, %74, %82 \n  v_mac_f32 %59, %74, %83 \n \
                v_mac_f32 %60, %75, %76 \n  v_mac_f32 %61, %75, %77 \n  v_mac_f32 %62, %75, %78 \n  v_mac_f32 %63, %75, %79 \n \
                v_mac_f32 %64, %75, %80 \n  v_mac_f32 %65, %75, %81 \n  v_mac_f32 %66, %75, %82 \n  v_mac_f32 %67, %75, %83 \n \
                "
                : "=v"(reg[0]), "=v"(reg[1]), "=v"(reg[2]), "=v"(reg[3]),
                  "=v"(p_c_thread[0]), "=v"(p_c_thread[1]), "=v"(p_c_thread[2]), "=v"(p_c_thread[3]),
                  "=v"(p_c_thread[4]), "=v"(p_c_thread[5]), "=v"(p_c_thread[6]), "=v"(p_c_thread[7]),
                  "=v"(p_c_thread[8]), "=v"(p_c_thread[9]), "=v"(p_c_thread[10]), "=v"(p_c_thread[11]),
                  "=v"(p_c_thread[12]), "=v"(p_c_thread[13]), "=v"(p_c_thread[14]), "=v"(p_c_thread[15]),
                  "=v"(p_c_thread[16]), "=v"(p_c_thread[17]), "=v"(p_c_thread[18]), "=v"(p_c_thread[19]),
                  "=v"(p_c_thread[20]), "=v"(p_c_thread[21]), "=v"(p_c_thread[22]), "=v"(p_c_thread[23]),
                  "=v"(p_c_thread[24]), "=v"(p_c_thread[25]), "=v"(p_c_thread[26]), "=v"(p_c_thread[27]),
                  "=v"(p_c_thread[28]), "=v"(p_c_thread[29]), "=v"(p_c_thread[30]), "=v"(p_c_thread[31]),
                  "=v"(p_c_thread[32]), "=v"(p_c_thread[33]), "=v"(p_c_thread[34]), "=v"(p_c_thread[35]),
                  "=v"(p_c_thread[36]), "=v"(p_c_thread[37]), "=v"(p_c_thread[38]), "=v"(p_c_thread[39]),
                  "=v"(p_c_thread[40]), "=v"(p_c_thread[41]), "=v"(p_c_thread[42]), "=v"(p_c_thread[43]),
                  "=v"(p_c_thread[44]), "=v"(p_c_thread[45]), "=v"(p_c_thread[46]), "=v"(p_c_thread[47]),
                  "=v"(p_c_thread[48]), "=v"(p_c_thread[49]), "=v"(p_c_thread[50]), "=v"(p_c_thread[51]),
                  "=v"(p_c_thread[52]), "=v"(p_c_thread[53]), "=v"(p_c_thread[54]), "=v"(p_c_thread[55]),
                  "=v"(p_c_thread[56]), "=v"(p_c_thread[57]), "=v"(p_c_thread[58]), "=v"(p_c_thread[59]),
                  "=v"(p_c_thread[60]), "=v"(p_c_thread[61]), "=v"(p_c_thread[62]), "=v"(p_c_thread[63])
                : "v"(p_a_thread[0]), "v"(p_a_thread[1]), "v"(p_a_thread[2]), "v"(p_a_thread[3]),
                  "v"(p_a_thread[4]), "v"(p_a_thread[5]), "v"(p_a_thread[6]), "v"(p_a_thread[7]),
                  "v"(p_b_thread[0]), "v"(p_b_thread[1]), "v"(p_b_thread[2]), "v"(p_b_thread[3]),
                  "v"(p_b_thread[4]), "v"(p_b_thread[5]), "v"(p_b_thread[6]), "v"(p_b_thread[7]),
                  "v"(__to_local((void*)(a_loc))), "v"(__to_local((void*)(b_loc))),
                  "4"(p_c_thread[0]), "5"(p_c_thread[1]), "6"(p_c_thread[2]), "7"(p_c_thread[3]),
                  "8"(p_c_thread[4]), "9"(p_c_thread[5]), "10"(p_c_thread[6]), "11"(p_c_thread[7]),
                  "12"(p_c_thread[8]), "13"(p_c_thread[9]), "14"(p_c_thread[10]), "15"(p_c_thread[11]),
                  "16"(p_c_thread[12]), "17"(p_c_thread[13]), "18"(p_c_thread[14]), "19"(p_c_thread[15]),
                  "20"(p_c_thread[16]), "21"(p_c_thread[17]), "22"(p_c_thread[18]), "23"(p_c_thread[19]),
                  "24"(p_c_thread[20]), "25"(p_c_thread[21]), "26"(p_c_thread[22]), "27"(p_c_thread[23]),
                  "28"(p_c_thread[24]), "29"(p_c_thread[25]), "30"(p_c_thread[26]), "31"(p_c_thread[27]),
                  "32"(p_c_thread[28]), "33"(p_c_thread[29]), "34"(p_c_thread[30]), "35"(p_c_thread[31]),
                  "36"(p_c_thread[32]), "37"(p_c_thread[33]), "38"(p_c_thread[34]), "39"(p_c_thread[35]),
                  "40"(p_c_thread[36]), "41"(p_c_thread[37]), "42"(p_c_thread[38]), "43"(p_c_thread[39]),
                  "44"(p_c_thread[40]), "45"(p_c_thread[41]), "46"(p_c_thread[42]), "47"(p_c_thread[43]),
                  "48"(p_c_thread[44]), "49"(p_c_thread[45]), "50"(p_c_thread[46]), "51"(p_c_thread[47]),
                  "52"(p_c_thread[48]), "53"(p_c_thread[49]), "54"(p_c_thread[50]), "55"(p_c_thread[51]),
                  "56"(p_c_thread[52]), "57"(p_c_thread[53]), "58"(p_c_thread[54]), "59"(p_c_thread[55]),
                  "60"(p_c_thread[56]), "61"(p_c_thread[57]), "62"(p_c_thread[58]), "63"(p_c_thread[59]),
                  "64"(p_c_thread[60]), "65"(p_c_thread[61]), "66"(p_c_thread[62]), "67"(p_c_thread[63]));
#endif
        }
#else
        printf("asm only support on HIP backend\n");
        assert(false);
#endif
    }

    template <class FloatA, class FloatB, class FloatC, class Accumulator>
    __device__ void Run(const FloatA* const __restrict__ p_a_block,
                        const FloatB* const __restrict__ p_b_block,
...
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp View file @ 569941a7
@@ -2,7 +2,6 @@
 #include "common.hip.hpp"
 #include "ConstantTensorDescriptor.hip.hpp"
 #include "ConstantMatrixDescriptor.hip.hpp"
-#include "blockwise_4d_tensor_op.hip.hpp"
 #include "blockwise_2d_tensor_op.hip.hpp"
 #include "threadwise_2d_tensor_op.hip.hpp"
 #include "blockwise_gemm.hip.hpp"
@@ -284,8 +283,6 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
         blockwise_gemm.Run
 #elif 0
         blockwise_gemm.Run_RegisterDoubleBuffer
-#elif 0
-        blockwise_gemm.Run_asm
 #endif
             (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
              p_in_block + y * Wi + x,
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp deleted 100644 → 0 View file @ 6166233e
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_2d_tensor_op.hip.hpp"
#include "blockwise_gemm.hip.hpp"

// define B = flatten(N, Hi, Wi)
template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t BPerThread,
          index_t KPerThread,
          index_t GemmThreadPerColumnPerCluster,
          index_t GemmThreadPerRowPerCluster,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t InBlockCopyThreadPerDim0,
          index_t InBlockCopyThreadPerDim1,
          index_t WeiBlockCopyThreadPerDim0,
          index_t WeiBlockCopyThreadPerDim1,
          index_t InBlockCopyDataPerRead,
          index_t WeiBlockCopyDataPerRead>
struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
{
    __host__ __device__ constexpr GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer()
    {
    }

    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_chwn_global_desc  = InGlobalDesc{};
        constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
        constexpr auto out_khwn_global_desc = OutGlobalDesc{};

        constexpr index_t C  = in_chwn_global_desc.GetLength(I0);
        constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
        constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
        constexpr index_t N  = in_chwn_global_desc.GetLength(I3);

        constexpr index_t K  = out_khwn_global_desc.GetLength(I0);
        constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
        constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);

        constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
        constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);

        constexpr index_t B = N * Hi * Wi;

        constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);

        // divide block work by 2d: [K, B]
        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
        constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;

        const index_t k_block_work_id = get_block_1d_id() / BBlockWork;
        const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;

        const index_t k_block_data_begin = k_block_work_id * KPerBlock;
        const index_t b_block_data_begin = b_block_work_id * BPerBlock;

        // flattened (2d) tensor view of gridwise input
        constexpr auto in_cb_global_desc  = make_ConstantTensorDescriptor(Sequence<C, B>{});
        constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});

        // tensor view of blockwise input and weight
        // be careful of alignment
        constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});

        constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});

        constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});

        // tensor view of threadwise output in register
        constexpr auto out_kb_thread_desc =
            make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc");
            print_ConstantTensorDescriptor(wei_cyxk_global_desc, "wei_cyxk_global_desc");
            print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");

            print_ConstantTensorDescriptor(in_cb_global_desc, "in_cb_global_desc");
            print_ConstantTensorDescriptor(wei_ek_global_desc, "wei_ek_global_desc");

            print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc");
            print_ConstantTensorDescriptor(wei_cyxk_block_desc, "wei_cyxk_block_desc");
            print_ConstantTensorDescriptor(wei_ek_block_desc, "wei_ek_block_desc");

            print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc");

            printf("KPerBlock %u\n", KPerBlock);
        }
#endif

        // blockwise in copy
        // format is [CPerBlock, BPerBlock + BGhostRead]
#if 0
        const auto blockwise_in_copy =
            Blockwise2dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(in_cb_global_desc),
                                   decltype(in_cb_block_desc),
                                   decltype(in_cb_block_desc.GetLengths())>{};
#elif 0
        const auto blockwise_in_copy =
            Blockwise2dTensorCopy2<BlockSize,
                                   Float,
                                   decltype(in_cb_global_desc),
                                   decltype(in_cb_block_desc),
                                   decltype(in_cb_block_desc.GetLengths()),
                                   InBlockCopyThreadPerDim0,
                                   InBlockCopyThreadPerDim1>{};
#elif 1
        const auto blockwise_in_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(in_cb_global_desc),
                                   decltype(in_cb_block_desc),
                                   decltype(in_cb_block_desc.GetLengths()),
                                   InBlockCopyDataPerRead>{};
#endif

        // blockwise wei copy
        // format is [CPerBlock*Y*X, KPerBlock]
#if 0
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(wei_ek_global_desc),
                                   decltype(wei_ek_block_desc),
                                   decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy2<BlockSize,
                                   Float,
                                   decltype(wei_ek_global_desc),
                                   decltype(wei_ek_block_desc),
                                   decltype(wei_ek_block_desc.GetLengths()),
                                   WeiBlockCopyThreadPerDim0,
                                   WeiBlockCopyThreadPerDim1>{};
#elif 1
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_ek_global_desc),
                                   decltype(wei_ek_block_desc),
                                   decltype(wei_ek_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead>{};
#endif

        // a series of blockwise GEMM
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx and b_mtx saved in LDS, c_mtx saved in register
        // a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K]
        // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
        // c_mtx[K,B] is out_block[K,B]
        constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});

        constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<BPerBlock>{}, Number<in_cb_block_desc.GetStride(I0)>{});

        constexpr auto c_kxb_thread_mtx_desc =
            make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<BPerThread>{});

#if 0
        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC<BlockSize,
                                                                     decltype(a_cxk_block_mtx_desc),
                                                                     decltype(b_cxb_block_mtx_desc),
                                                                     decltype(c_kxb_thread_mtx_desc),
                                                                     true,
                                                                     false,
                                                                     false,
                                                                     GemmKPerThreadLoop,
                                                                     GemmThreadPerColumnPerCluster,
                                                                     GemmThreadPerRowPerCluster,
                                                                     true>{};
#else
        const auto blockwise_gemm =
            BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
                                                                    decltype(a_cxk_block_mtx_desc),
                                                                    decltype(b_cxb_block_mtx_desc),
                                                                    decltype(c_kxb_thread_mtx_desc),
                                                                    GemmMPerThreadSubC,
                                                                    GemmNPerThreadSubC,
                                                                    GemmMLevel0Cluster,
                                                                    GemmNLevel0Cluster,
                                                                    GemmMLevel1Cluster,
                                                                    GemmNLevel1Cluster,
                                                                    GemmKPerThreadLoop>{};
#endif

        // LDS: be careful of alignment
        constexpr index_t in_block_element_size =
            in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
        constexpr index_t wei_block_element_size =
            wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

        constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                          ? InBlockCopyDataPerRead
                                          : WeiBlockCopyDataPerRead;

        // LDS double buffer
        __shared__ Float p_in_block_0[max_align * ((in_block_element_size + max_align - 1) / max_align)];
        __shared__ Float p_wei_block_0[max_align * ((wei_block_element_size + max_align - 1) / max_align)];
        __shared__ Float p_in_block_1[max_align * ((in_block_element_size + max_align - 1) / max_align)];
        __shared__ Float p_wei_block_1[max_align * ((wei_block_element_size + max_align - 1) / max_align)];

        const Float* p_in_global_block_offset =
            p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);

        const Float* p_wei_global_block_offset =
            p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);

        // preload data into LDS
        blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0);
        blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0);

        p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
        p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);

        // register
        Float p_out_thread[out_kb_thread_desc.GetElementSpace()];

        // set threadwise output tensor to 0
        threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);

        bool even_loop = true;

        for(index_t c_block_data_begin = 0; c_block_data_begin + CPerBlock < C;
            c_block_data_begin += CPerBlock,
                    p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
                    p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
                    even_loop = !even_loop)
        {
            Float* p_in_block_now  = even_loop ? p_in_block_0 : p_in_block_1;
            Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1;

            Float* p_in_block_next  = even_loop ? p_in_block_1 : p_in_block_0;
            Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0;

            __syncthreads();

            // load next data
#if 0
            blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
            blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next);
#elif 1
            Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
            Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

            blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                       p_in_register_clipboard);
            blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                        p_wei_register_clipboard);
#endif

            // compute on current data
            // a series of GEMM
            for(index_t y = 0; y < Y; ++y)
            {
                for(index_t x = 0; x < X; ++x)
                {
                    auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 1
                    blockwise_gemm.Run
#else
                    blockwise_gemm.Run_RegisterDoubleBuffer
#endif
                        (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                         p_in_block_now + y * Wi + x,
                         p_out_thread,
                         f_accum);
                }
            }

#if 1
            blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next);
            blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                         p_wei_block_next);
#endif
        }

        // last computation
        {
            Float* p_in_block_now  = even_loop ? p_in_block_0 : p_in_block_1;
            Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1;

            __syncthreads();

            for(index_t y = 0; y < Y; ++y)
            {
                for(index_t x = 0; x < X; ++x)
                {
                    auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 1
                    blockwise_gemm.Run
#else
                    blockwise_gemm.Run_RegisterDoubleBuffer
#endif
                        (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                         p_in_block_now + y * Wi + x,
                         p_out_thread,
                         f_accum);
                }
            }
        }

        // output: register to global mem
        const auto c_thread_mtx_begin =
            blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
        const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;

#if 0
        if(get_block_1d_id() == 0)
        {
            printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
                   get_block_1d_id(),
                   get_thread_local_1d_id(),
                   matrix_c_index.row,
                   matrix_c_index.col,
                   k_data_begin,
                   b_data_begin,
                   p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
        }
#endif

        for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
        {
            for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
            {
                const auto c_thread_mtx_distance =
                    blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);

                index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
                index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;

                index_t h_data = b_data / (Wi * N);
                index_t itmp   = b_data - h_data * (Wi * N);
                index_t w_data = itmp / N;
                index_t n_data = itmp - w_data * N;

                if(n_data < N && h_data < Ho && w_data < Wo)
                {
                    p_out_global[out_khwn_global_desc.Get1dIndex(k_data, h_data, w_data, n_data)] =
                        p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)];
                }
            }
        }
    }
};
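The main loop above ping-pongs between two LDS buffers: while the GEMMs consume the "now" buffer, the next CPerBlock slice is staged through the register clipboard into the "next" buffer, and even_loop flips the roles each iteration before a final compute pass drains the last preloaded slice. A schematic host-side sketch of the same control flow, using placeholder prints instead of the real copy/GEMM calls and an assumed channel count:

// Ping-pong double-buffer control flow; C and CPerBlock are assumptions.
#include <cstdio>
int main()
{
    const int C = 8, CPerBlock = 2;
    int now = 0; // index of the LDS buffer currently being computed on
    std::printf("preload slice 0 into buffer 0\n");
    for(int c = 0; c + CPerBlock < C; c += CPerBlock, now ^= 1)
    {
        std::printf("compute on buffer %d while staging slice %d into buffer %d\n",
                    now, c / CPerBlock + 1, now ^ 1);
    }
    std::printf("final compute on buffer %d\n", now);
    return 0;
}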
src/include/gridwise_direct_convolution_1.hip.hpp deleted 100644 → 0 View file @ 6166233e
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_direct_convolution.hip.hpp"

template <class Float, class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          index_t NPerBlock, index_t KPerBlock, index_t CPerBlock, index_t HoPerBlock, index_t WoPerBlock,
          index_t NPerThread, index_t KPerThread, index_t CPerThread, index_t HoPerThread, index_t WoPerThread,
          index_t BlockSize, index_t GridSize>
__global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_in_global,
                                              const Float* const __restrict__ p_wei_global,
                                              Float* const __restrict__ p_out_global)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_global_desc  = InGlobalDesc{};
    constexpr auto wei_global_desc = WeiGlobalDesc{};
    constexpr auto out_global_desc = OutGlobalDesc{};

    constexpr index_t Y = wei_global_desc.GetLength(I2);
    constexpr index_t X = wei_global_desc.GetLength(I3);

    constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
    constexpr index_t WiPerBlock = WoPerBlock + X - 1;

    constexpr index_t NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
    constexpr index_t KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
    constexpr index_t HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
    constexpr index_t WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;

    constexpr auto in_block_global_desc = make_ConstantTensorDescriptor(
        Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, in_global_desc.GetStrides());
    constexpr auto wei_block_global_desc = make_ConstantTensorDescriptor(
        Sequence<KPerBlock, CPerBlock, Y, X>{}, wei_global_desc.GetStrides());
    constexpr auto out_block_global_desc = make_ConstantTensorDescriptor(
        Sequence<NPerBlock, KPerBlock, HoPerBlock, WoPerBlock>{}, out_global_desc.GetStrides());

    constexpr auto in_block_desc  = make_ConstantTensorDescriptor(in_block_global_desc.GetLengths());
    constexpr auto wei_block_desc = make_ConstantTensorDescriptor(wei_block_global_desc.GetLengths());
    constexpr auto out_block_desc = make_ConstantTensorDescriptor(out_block_global_desc.GetLengths());

    constexpr index_t in_block_element_size  = in_block_desc.GetElementSpace();
    constexpr index_t wei_block_element_size = wei_block_desc.GetElementSpace();
    constexpr index_t out_block_size         = out_block_desc.GetElementSpace();

    __shared__ Float p_in_block[in_block_element_size];
    __shared__ Float p_wei_block[wei_block_element_size];
    __shared__ Float p_out_block[out_block_size];

    const index_t block_id = blockIdx.x;

    index_t itmp            = block_id;
    index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
    itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
    index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
    itmp -= k_block_work_id * (HBlockWork * WBlockWork);
    index_t h_block_work_id = itmp / WBlockWork;
    index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;

    index_t n_block_work_begin  = n_block_work_id * NPerBlock;
    index_t k_block_work_begin  = k_block_work_id * KPerBlock;
    index_t ho_block_work_begin = h_block_work_id * HoPerBlock;
    index_t wo_block_work_begin = w_block_work_id * WoPerBlock;

    index_t hi_block_work_begin = ho_block_work_begin; // minus padding
    index_t wi_block_work_begin = wo_block_work_begin; // minus padding

    constexpr auto blockwise_in_copy =
        Blockwise4dTensorCopy1<BlockSize, Float, decltype(in_block_global_desc),
                               decltype(in_block_desc), decltype(in_block_desc.GetLengths())>{};

    constexpr auto blockwise_wei_copy =
        Blockwise4dTensorCopy1<BlockSize, Float, decltype(wei_block_global_desc),
                               decltype(wei_block_desc), decltype(wei_block_desc.GetLengths())>{};

    constexpr auto blockwise_out_copy =
        Blockwise4dTensorCopy1<BlockSize, Float, decltype(out_block_desc),
                               decltype(out_block_global_desc), decltype(out_block_desc.GetLengths())>{};

    // set output tensor in LDS to 0
    blockwise_4d_tensor_set_zero<BlockSize>(out_block_desc, p_out_block);

    for(index_t c_block_work_begin = 0; c_block_work_begin < in_global_desc.GetLength(I1);
        c_block_work_begin += CPerBlock)
    {
        // copy input tensor to LDS
        blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_work_begin,
                                                                      c_block_work_begin,
                                                                      hi_block_work_begin,
                                                                      wi_block_work_begin),
                              p_in_block);

        // copy weight tensor to LDS
        blockwise_wei_copy.Run(p_wei_global + wei_global_desc.Get1dIndex(k_block_work_begin,
                                                                         c_block_work_begin, 0, 0),
                               p_wei_block);

        __syncthreads();

        // blockwise convolution
        blockwise_direct_convolution<BlockSize, Float, decltype(in_block_desc), decltype(wei_block_desc),
                                     decltype(out_block_desc), NPerThread, KPerThread, CPerThread,
                                     HoPerThread, WoPerThread>(
            in_block_desc, p_in_block, wei_block_desc, p_wei_block, out_block_desc, p_out_block);

        __syncthreads();
    }

    // copy output tensor from LDS to device mem
    blockwise_out_copy.Run(p_out_block,
                           p_out_global + out_global_desc.Get1dIndex(n_block_work_begin,
                                                                     k_block_work_begin,
                                                                     ho_block_work_begin,
                                                                     wo_block_work_begin));
}
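The kernel above flattens its four tile coordinates into the single 1-D HIP block index and then recovers them with div/mod arithmetic. A minimal host-side sketch of that decomposition, assuming the same work-division order (N outermost, then K, Ho, Wo) and hypothetical tile counts that are not taken from the driver, is:

#include <cstdio>

// Illustration only: mirrors the kernel's block_id -> (n, k, ho, wo) work decomposition.
// The tile counts below are made-up example values.
int main()
{
    const int NBlockWork = 2, KBlockWork = 4, HBlockWork = 7, WBlockWork = 7;

    for(int block_id = 0; block_id < NBlockWork * KBlockWork * HBlockWork * WBlockWork; ++block_id)
    {
        int itmp = block_id;
        const int n_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
        itmp -= n_id * (KBlockWork * HBlockWork * WBlockWork);
        const int k_id = itmp / (HBlockWork * WBlockWork);
        itmp -= k_id * (HBlockWork * WBlockWork);
        const int h_id = itmp / WBlockWork;
        const int w_id = itmp - h_id * WBlockWork;

        std::printf("block %d -> tile (n=%d, k=%d, ho=%d, wo=%d)\n", block_id, n_id, k_id, h_id, w_id);
    }
    return 0;
}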
src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp (deleted, 100644 → 0)
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_direct_convolution.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "threadwise_direct_convolution.hip.hpp"

template <class Float, class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          index_t NPerBlock, index_t KPerBlock, index_t CPerBlock, index_t HoPerBlock, index_t WoPerBlock,
          index_t NPerThread, index_t KPerThread, index_t CPerThread, index_t HoPerThread, index_t WoPerThread,
          index_t InBlockCopyDataPerRead, index_t WeiBlockCopyDataPerRead,
          index_t BlockSize, index_t GridSize>
__global__ void gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_in_global,
                                                             const Float* const __restrict__ p_wei_global,
                                                             Float* const __restrict__ p_out_global)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_global_desc  = InGlobalDesc{};
    constexpr auto wei_kcyx_global_desc = WeiGlobalDesc{};
    constexpr auto out_nkhw_global_desc = OutGlobalDesc{};

    constexpr index_t N = in_nchw_global_desc.GetLength(I0);
    constexpr index_t K = wei_kcyx_global_desc.GetLength(I0);
    constexpr index_t C = wei_kcyx_global_desc.GetLength(I1);
    constexpr index_t Y = wei_kcyx_global_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_global_desc.GetLength(I3);

    // 2d view of wei for blockwise copy
    constexpr auto wei_ke_global_desc = make_ConstantTensorDescriptor(Sequence<K, C * Y * X>{});

    constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
    constexpr index_t WiPerBlock = WoPerBlock + X - 1;

    constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Number<InBlockCopyDataPerRead>{});

    // 2d view of wei for blockwise copy
    constexpr auto wei_ke_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<KPerBlock, CPerBlock * Y * X>{}, Number<WeiBlockCopyDataPerRead>{});

    constexpr auto wei_kcyx_block_desc = make_ConstantTensorDescriptor(
        Sequence<KPerBlock, CPerBlock, Y, X>{},
        Sequence<wei_ke_block_desc.GetStride(I0), Y * X, X, 1>{});

    // shared mem
    constexpr index_t in_block_element_size =
        in_nchw_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
    constexpr index_t wei_block_element_size =
        wei_kcyx_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

    constexpr index_t max_align =
        InBlockCopyDataPerRead > WeiBlockCopyDataPerRead ? InBlockCopyDataPerRead : WeiBlockCopyDataPerRead;

    __shared__ Float p_in_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
    __shared__ Float p_wei_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];

    // threadwise tensors
    constexpr index_t HiPerThread = HoPerThread + Y - 1;
    constexpr index_t WiPerThread = WoPerThread + X - 1;

    constexpr auto in_nchw_thread_block_desc = make_ConstantTensorDescriptor(
        Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{}, in_nchw_block_desc.GetStrides());

    constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor(
        Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_block_desc.GetStrides());

    constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
        in_nchw_thread_block_desc, wei_kcyx_thread_block_desc);

    // register
    Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()];

    // divide block work
    constexpr index_t NBlockWork = (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
    constexpr index_t KBlockWork = (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
    constexpr index_t HBlockWork = (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
    constexpr index_t WBlockWork = (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;

    const index_t block_id = blockIdx.x;

    index_t itmp                  = block_id;
    const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
    itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
    const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
    itmp -= k_block_work_id * (HBlockWork * WBlockWork);
    const index_t h_block_work_id = itmp / WBlockWork;
    const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;

    const index_t n_block_data_begin  = n_block_work_id * NPerBlock;
    const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
    const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
    const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;

    const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
    const index_t wi_block_data_begin = wo_block_data_begin; // minus padding

    // divide thread work
    constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
    constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
    constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
    constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;

    const index_t thread_id = threadIdx.x;

    itmp                           = thread_id;
    const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
    itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
    const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
    itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
    const index_t h_thread_work_id = itmp / WThreadWork;
    const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;

    const index_t n_thread_data_begin  = n_thread_work_id * NPerThread;
    const index_t k_thread_data_begin  = k_thread_work_id * KPerThread;
    const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
    const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;

    const index_t hi_thread_data_begin = ho_thread_data_begin;
    const index_t wi_thread_data_begin = wo_thread_data_begin;

    constexpr auto blockwise_in_copy =
        Blockwise4dTensorCopy1<BlockSize, Float, decltype(in_nchw_global_desc), decltype(in_nchw_block_desc),
                               decltype(in_nchw_block_desc.GetLengths()), InBlockCopyDataPerRead>{};

#if 0
    constexpr auto blockwise_wei_copy =
        Blockwise4dTensorCopy1<BlockSize, Float, decltype(wei_kcyx_global_desc), decltype(wei_kcyx_block_desc),
                               decltype(wei_kcyx_block_desc.GetLengths()), 1>{};
#elif 1
    const auto blockwise_wei_copy =
        Blockwise2dTensorCopy3<BlockSize, Float, decltype(wei_ke_global_desc), decltype(wei_ke_block_desc),
                               decltype(wei_ke_block_desc.GetLengths()), WeiBlockCopyDataPerRead>{};
#endif

    // set threadwise output tensor to 0
    threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);

    for(index_t c_block_data_begin = 0; c_block_data_begin < C;
        c_block_data_begin += CPerBlock, __syncthreads())
    {
        // copy input tensor to LDS
        blockwise_in_copy.Run(p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin,
                                                                           c_block_data_begin,
                                                                           hi_block_data_begin,
                                                                           wi_block_data_begin),
                              p_in_block);

        // copy weight tensor to LDS
        blockwise_wei_copy.Run(p_wei_global + wei_kcyx_global_desc.Get1dIndex(k_block_data_begin,
                                                                              c_block_data_begin, 0, 0),
                               p_wei_block);

        __syncthreads();

        for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
        {
            // threadwise convolution
#if 1
            threadwise_direct_convolution_2(
                in_nchw_thread_block_desc,
                p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin, c_thread_data,
                                                           hi_thread_data_begin, wi_thread_data_begin),
                wei_kcyx_thread_block_desc,
                p_wei_block + wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
                out_nkhw_thread_desc,
                p_out_thread);
#elif 0
            threadwise_direct_convolution_3(
                in_nchw_thread_block_desc,
                p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin, c_thread_data,
                                                           hi_thread_data_begin, wi_thread_data_begin),
                wei_kcyx_thread_block_desc,
                p_wei_block + wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
                out_nkhw_thread_desc,
                p_out_thread);
#endif
        }
    }

    // copy output tensor from register to global mem
    threadwise_4d_tensor_copy(out_nkhw_thread_desc,
                              p_out_thread,
                              out_nkhw_global_desc,
                              p_out_global +
                                  out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
                                                                  k_block_data_begin + k_thread_data_begin,
                                                                  ho_block_data_begin + ho_thread_data_begin,
                                                                  wo_block_data_begin + wo_thread_data_begin),
                              out_nkhw_thread_desc.GetLengths());
}
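Both direct-convolution kernels assume the launcher sizes the grid as the product of the per-dimension block-work counts computed inside the kernel. A small host-side sketch, assuming the usual valid-convolution output size Ho = Hi - Y + 1 and Wo = Wi - X + 1 and hypothetical problem and tile sizes (not the driver's actual values):

#include <cstdio>

// Illustration only: how GridSize would follow from the block-work counts in the kernels above.
int main()
{
    const int N = 64, K = 256, Hi = 16, Wi = 16, Y = 3, X = 3;       // hypothetical problem
    const int NPerBlock = 16, KPerBlock = 64, HoPerBlock = 2, WoPerBlock = 14; // hypothetical tiles

    const int Ho = Hi - Y + 1; // valid convolution, no padding
    const int Wo = Wi - X + 1;

    const int NBlockWork = (N + NPerBlock - 1) / NPerBlock;
    const int KBlockWork = (K + KPerBlock - 1) / KPerBlock;
    const int HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
    const int WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;

    const int GridSize = NBlockWork * KBlockWork * HBlockWork * WBlockWork;
    std::printf("Ho=%d Wo=%d GridSize=%d\n", Ho, Wo, GridSize);
    return 0;
}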
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp (deleted, 100644 → 0)
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_direct_convolution.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "threadwise_direct_convolution.hip.hpp"

template <class TInWei, class TOut, class TAccum,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          index_t ScalarPerVector,
          index_t NPerBlock, index_t KPerBlock, index_t CPerBlock, index_t HoPerBlock, index_t WoPerBlock,
          index_t NPerThread, index_t KPerThread, index_t CPerThread, index_t HoPerThread, index_t WoPerThread,
          index_t InBlockCopyDataPerRead, index_t WeiBlockCopyDataPerRead,
          index_t BlockSize, index_t GridSize>
__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
    const typename vector_type<TInWei, ScalarPerVector>::MemoryType* const __restrict__ p_in_vec_global,
    const typename vector_type<TInWei, ScalarPerVector>::MemoryType* const __restrict__ p_wei_vec_global,
    TOut* const __restrict__ p_out_global)
{
    using in_scalar_t     = TInWei;
    using in_vector_mem_t = typename vector_type<in_scalar_t, ScalarPerVector>::MemoryType;
    using out_scalar_t    = TOut;
    using accum_t         = TAccum;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_vec_global_desc  = InGlobalDesc{};
    constexpr auto wei_kcyx_vec_global_desc = WeiGlobalDesc{};
    constexpr auto out_nkhw_global_desc     = OutGlobalDesc{};

    constexpr index_t N = in_nchw_vec_global_desc.GetLength(I0);
    constexpr index_t K = wei_kcyx_vec_global_desc.GetLength(I0);
    constexpr index_t C = wei_kcyx_vec_global_desc.GetLength(I1);
    constexpr index_t Y = wei_kcyx_vec_global_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_vec_global_desc.GetLength(I3);

    // 2d view of wei for blockwise copy
    constexpr auto wei_ke_vec_global_desc = make_ConstantTensorDescriptor(Sequence<K, C * Y * X>{});

    constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
    constexpr index_t WiPerBlock = WoPerBlock + X - 1;

    constexpr auto in_nchw_vec_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Number<InBlockCopyDataPerRead>{});

    // 2d view of wei for blockwise copy
    constexpr auto wei_ke_vec_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<KPerBlock, CPerBlock * Y * X>{}, Number<WeiBlockCopyDataPerRead>{});

    constexpr auto wei_kcyx_vec_block_desc = make_ConstantTensorDescriptor(
        Sequence<KPerBlock, CPerBlock, Y, X>{},
        Sequence<wei_ke_vec_block_desc.GetStride(I0), Y * X, X, 1>{});

    // shared mem
    constexpr index_t in_block_element_size =
        in_nchw_vec_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
    constexpr index_t wei_block_element_size =
        wei_kcyx_vec_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

    constexpr index_t max_align =
        InBlockCopyDataPerRead > WeiBlockCopyDataPerRead ? InBlockCopyDataPerRead : WeiBlockCopyDataPerRead;

    __shared__ in_vector_mem_t p_in_vec_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
    __shared__ in_vector_mem_t p_wei_vec_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];

    // threadwise tensors
    constexpr index_t HiPerThread = HoPerThread + Y - 1;
    constexpr index_t WiPerThread = WoPerThread + X - 1;

    constexpr auto in_nchw_vec_thread_block_desc = make_ConstantTensorDescriptor(
        Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{}, in_nchw_vec_block_desc.GetStrides());

    constexpr auto wei_kcyx_vec_thread_block_desc = make_ConstantTensorDescriptor(
        Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_vec_block_desc.GetStrides());

    constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
        in_nchw_vec_thread_block_desc, wei_kcyx_vec_thread_block_desc);

    // register
    out_scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()];

    // divide block work
    constexpr index_t NBlockWork = (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
    constexpr index_t KBlockWork = (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
    constexpr index_t HBlockWork = (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
    constexpr index_t WBlockWork = (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;

    const index_t block_id = blockIdx.x;

    index_t itmp                  = block_id;
    const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
    itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
    const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
    itmp -= k_block_work_id * (HBlockWork * WBlockWork);
    const index_t h_block_work_id = itmp / WBlockWork;
    const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;

    const index_t n_block_data_begin  = n_block_work_id * NPerBlock;
    const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
    const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
    const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;

    const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
    const index_t wi_block_data_begin = wo_block_data_begin; // minus padding

    // divide thread work
    constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
    constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
    constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
    constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;

    const index_t thread_id = threadIdx.x;

    itmp                           = thread_id;
    const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
    itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
    const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
    itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
    const index_t h_thread_work_id = itmp / WThreadWork;
    const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;

    const index_t n_thread_data_begin  = n_thread_work_id * NPerThread;
    const index_t k_thread_data_begin  = k_thread_work_id * KPerThread;
    const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
    const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;

    const index_t hi_thread_data_begin = ho_thread_data_begin;
    const index_t wi_thread_data_begin = wo_thread_data_begin;

    constexpr auto blockwise_in_copy =
        Blockwise4dTensorCopy1<BlockSize, in_vector_mem_t, decltype(in_nchw_vec_global_desc),
                               decltype(in_nchw_vec_block_desc),
                               decltype(in_nchw_vec_block_desc.GetLengths()), InBlockCopyDataPerRead>{};

#if 0
    constexpr auto blockwise_wei_copy =
        Blockwise4dTensorCopy1<BlockSize, in_vector_mem_t, decltype(wei_kcyx_vec_global_desc),
                               decltype(wei_kcyx_vec_block_desc),
                               decltype(wei_kcyx_vec_block_desc.GetLengths()), 1>{};
#elif 1
    const auto blockwise_wei_copy =
        Blockwise2dTensorCopy3<BlockSize, in_vector_mem_t, decltype(wei_ke_vec_global_desc),
                               decltype(wei_ke_vec_block_desc),
                               decltype(wei_ke_vec_block_desc.GetLengths()), WeiBlockCopyDataPerRead>{};
#endif

#if 1 // debug
    // set threadwise output tensor to 0
    threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
#endif

    for(index_t c_block_data_begin = 0; c_block_data_begin < C;
        c_block_data_begin += CPerBlock, __syncthreads())
    {
        // copy input tensor to LDS
        blockwise_in_copy.Run(p_in_vec_global + in_nchw_vec_global_desc.Get1dIndex(n_block_data_begin,
                                                                                   c_block_data_begin,
                                                                                   hi_block_data_begin,
                                                                                   wi_block_data_begin),
                              p_in_vec_block);

        // copy weight tensor to LDS
        blockwise_wei_copy.Run(p_wei_vec_global + wei_kcyx_vec_global_desc.Get1dIndex(k_block_data_begin,
                                                                                      c_block_data_begin, 0, 0),
                               p_wei_vec_block);

        __syncthreads();

        for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
        {
            // threadwise convolution
#if 1
            threadwise_direct_convolution_2(
                in_nchw_vec_thread_block_desc,
                p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin, c_thread_data,
                                                                   hi_thread_data_begin, wi_thread_data_begin),
                wei_kcyx_vec_thread_block_desc,
                p_wei_vec_block + wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
                out_nkhw_thread_desc,
                p_out_thread);
#elif 0
            threadwise_direct_convolution_3(
                in_nchw_vec_thread_block_desc,
                p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin, c_thread_data,
                                                                   hi_thread_data_begin, wi_thread_data_begin),
                wei_kcyx_vec_thread_block_desc,
                p_wei_vec_block + wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
                out_nkhw_thread_desc,
                p_out_thread);
#endif
        }
    }

    // copy output tensor from register to global mem
    threadwise_4d_tensor_copy(out_nkhw_thread_desc,
                              p_out_thread,
                              out_nkhw_global_desc,
                              p_out_global +
                                  out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
                                                                  k_block_data_begin + k_thread_data_begin,
                                                                  ho_block_data_begin + ho_thread_data_begin,
                                                                  wo_block_data_begin + wo_thread_data_begin),
                              out_nkhw_thread_desc.GetLengths());
}
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp (deleted, 100644 → 0)
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"

template <index_t GridSize, index_t BlockSize, class Float,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          index_t NPerBlock, index_t KPerBlock, index_t CPerBlock, index_t HoPerBlock, index_t WoPerBlock,
          index_t NPerThread, index_t KPerThread, index_t HoPerThread, index_t WoPerThread,
          class InBlockCopyThreadPerDims, index_t InBlockCopyDataPerRead, index_t WeiBlockCopyDataPerRead,
          index_t GemmMPerThreadSubC, index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop, index_t OutThreadCopyDataPerWrite>
__global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restrict__ p_in_global,
                                                                    const Float* const __restrict__ p_wei_global,
                                                                    Float* const __restrict__ p_out_global)
{
    // NPerThread == NPerBlock, because the format of input in LDS is [C,Hi,Wi,N]:
    // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N";
    // if we used [C,Hi,Wi/W1,W1,N] in LDS, then NPerThread could be different from NPerBlock
    static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0");
    static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock, "wrong!");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_chwn_global_desc  = InGlobalDesc{};
    constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
    constexpr auto out_khwn_global_desc = OutGlobalDesc{};

    constexpr index_t C  = in_chwn_global_desc.GetLength(I0);
    constexpr index_t K  = out_khwn_global_desc.GetLength(I0);
    constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
    constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
    constexpr index_t N  = out_khwn_global_desc.GetLength(I3);
    constexpr index_t Y  = wei_cyxk_global_desc.GetLength(I1);
    constexpr index_t X  = wei_cyxk_global_desc.GetLength(I2);

    constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
    constexpr index_t WiPerBlock = WoPerBlock + X - 1;

    // divide block work: [K, Ho, Wo, N]
    constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
    constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
    constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
    constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;

    const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
    index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
    const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
    itmp -= h_block_work_id * (WBlockWork * NBlockWork);
    const index_t w_block_work_id = itmp / NBlockWork;
    const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;

    const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
    const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
    const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
    const index_t n_block_data_begin  = n_block_work_id * NPerBlock;

    const index_t hi_block_data_begin = ho_block_data_begin;
    const index_t wi_block_data_begin = wo_block_data_begin;

    // flattened (2d) tensor view of gridwise weight
    constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});

    // tensor view of blockwise input and weight in LDS
    // be careful of alignment
    constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}, Number<InBlockCopyDataPerRead>{});

    constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});

    constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});

    // tensor view of threadwise output in register
    constexpr auto out_khwn_thread_desc =
        make_ConstantTensorDescriptor(Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

    // blockwise copy
    // input: format is [C, Hi, Wi, N]
    const auto blockwise_in_copy =
        Blockwise4dTensorCopy3<BlockSize, Float, decltype(in_chwn_global_desc), decltype(in_chwn_block_desc),
                               decltype(in_chwn_block_desc.GetLengths()), InBlockCopyThreadPerDims,
                               InBlockCopyDataPerRead>{};

    // blockwise wei copy
    // format is [CPerBlock*Y*X, KPerBlock]
    const auto blockwise_wei_copy =
        Blockwise2dTensorCopy3<BlockSize, Float, decltype(wei_ek_global_desc), decltype(wei_ek_block_desc),
                               decltype(wei_ek_block_desc.GetLengths()), WeiBlockCopyDataPerRead>{};

    // a series of blockwise batched GEMM
    // C_matrix += transpose(A_matrix) * B_matrix
    // A_matrix and B_matrix saved in LDS, C_matrix saved in register
    //   A_matrix[C,K]    is a sub-matrix of wei_block[C,Y,X,K]
    //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
    //   C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
    constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});

    constexpr auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{}, Number<WoPerBlock * NPerBlock>{}, Number<in_chwn_block_desc.GetStride(I0)>{});

    constexpr auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor(
        Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}, Number<out_khwn_thread_desc.GetStride(I1)>{});

    const auto blockwise_batch_gemm =
        BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<BlockSize,
                                                                     decltype(a_cxk_block_mtx_desc),
                                                                     decltype(b_cxwn_block_mtx_desc),
                                                                     decltype(c_kxwn_thread_mtx_desc),
                                                                     0,
                                                                     in_chwn_block_desc.GetStride(I1),
                                                                     out_khwn_thread_desc.GetStride(I1),
                                                                     HoPerBlock,
                                                                     GemmMPerThreadSubC, GemmNPerThreadSubC,
                                                                     GemmMLevel0Cluster, GemmNLevel0Cluster,
                                                                     GemmMLevel1Cluster, GemmNLevel1Cluster,
                                                                     GemmKPerThreadLoop, HoPerThread>{};

    // LDS: be careful of alignment
    constexpr index_t in_block_element_size =
        in_chwn_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
    constexpr index_t wei_block_element_size =
        wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

    constexpr index_t max_align =
        InBlockCopyDataPerRead > WeiBlockCopyDataPerRead ? InBlockCopyDataPerRead : WeiBlockCopyDataPerRead;

    __shared__ Float p_in_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
    __shared__ Float p_wei_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];

    // register
    Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];

    // set threadwise output tensor to 0
    threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread);

    const Float* p_in_global_block_begin =
        p_in_global + in_chwn_global_desc.Get1dIndex(0, hi_block_data_begin, wi_block_data_begin,
                                                     n_block_data_begin);

    const Float* p_wei_global_block_begin =
        p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);

    for(index_t c_block_data_begin = 0; c_block_data_begin < C;
        c_block_data_begin += CPerBlock,
        p_in_global_block_begin += CPerBlock * in_chwn_global_desc.GetStride(I0),
        p_wei_global_block_begin += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
        __syncthreads())
    {
        // input: global mem to LDS
        blockwise_in_copy.Run(p_in_global_block_begin, p_in_block);

        // weight: global mem to LDS
        blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block);

        __syncthreads();

        // a series of batched GEMM
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
#if 0
                blockwise_batch_gemm.Run
#elif 1
                blockwise_batch_gemm.Run_v3
#endif
                (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                 p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
                 p_out_thread,
                 [](auto& acc, const auto&& v) { acc += v; });
            }
        }
    }

    // output: register to global mem,
#if 0
    const auto c_thread_mtx_begin =
        blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

    for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
    {
        for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
        {
            for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
            {
                for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
                {
                    const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);

                    const auto c_thread_mtx_distance =
                        blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);

                    const index_t ho_thread = c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
                    const index_t k_thread  = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
                    const index_t b_thread  = c_thread_mtx_begin.col + c_thread_mtx_distance.col;

                    const index_t wo_thread = b_thread / NPerBlock;
                    const index_t n_thread  = b_thread % NPerBlock;

                    p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
                                                                 ho_block_data_begin + ho_thread,
                                                                 wo_block_data_begin + wo_thread,
                                                                 n_block_data_begin + n_thread)] =
                        p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
                }
            }
        }
    }
#elif 1
    const auto c_thread_mtx_begin =
        blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

    const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
    const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
    const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
    const index_t n_thread_data_begin  = c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;

    // this is for v2 GEMM
    // output is a 8d tensor
    if(NPerThread < NPerBlock && WoPerThread == 1)
    {
        constexpr index_t N1_ = GemmNPerThreadSubC;
        constexpr index_t W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC);
        constexpr index_t K2_ = GemmMPerThreadSubC;
        constexpr index_t K1_ = KPerBlock / KPerThread;

        constexpr auto out_8d_global_desc = make_ConstantTensorDescriptor(
            Sequence<K / (K1_ * K2_), K1_, K2_, Ho, Wo / W1_, W1_, N / N1_, N1_>{});

        constexpr auto out_8d_thread_desc = make_ConstantTensorDescriptor(
            Sequence<KPerBlock / (K1_ * K2_), 1, K2_, HoPerThread, WoPerBlock / W1_, 1, 1, N1_>{});

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
            print_ConstantTensorDescriptor(out_8d_thread_desc, "out_8d_thread_desc");
            print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
            print_ConstantTensorDescriptor(out_8d_global_desc, "out_8d_global_desc");
        }
#endif

        threadwise_8d_tensor_copy(out_8d_thread_desc,
                                  p_out_thread,
                                  out_8d_global_desc,
                                  p_out_global +
                                      out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
                                                                      ho_block_data_begin + ho_thread_data_begin,
                                                                      wo_block_data_begin + wo_thread_data_begin,
                                                                      n_block_data_begin + n_thread_data_begin),
                                  out_8d_thread_desc.GetLengths(),
                                  Number<OutThreadCopyDataPerWrite>{});
    }
    else if(NPerThread == NPerBlock)
    {
        // not implemented yet
        assert(false);
    }
    else
    {
        assert(false);
    }
#endif
}
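The comment block inside this kernel describes each (y, x) step as a GEMM of the form C_matrix += transpose(A_matrix) * B_matrix with A = wei[C, K], B = in[C, Wo*N], C = out[K, Wo*N]. A scalar reference loop for that transposed-A product, using hypothetical small dimensions, illustrates the math only (not the blocked LDS/register implementation):

#include <cstdio>
#include <vector>

// Reference for C[K x B] += A^T * B, i.e. the per-(y,x) GEMM the kernel accumulates.
// A is stored as [C][K], B as [C][B]; B here plays the role of Wo * N. Sizes are hypothetical.
int main()
{
    const int Cdim = 8, Kdim = 4, Bdim = 6;

    std::vector<float> a(Cdim * Kdim, 1.0f); // wei tile, layout [C][K]
    std::vector<float> b(Cdim * Bdim, 2.0f); // in tile,  layout [C][B]
    std::vector<float> c(Kdim * Bdim, 0.0f); // out tile, layout [K][B]

    for(int ci = 0; ci < Cdim; ++ci)
        for(int ki = 0; ki < Kdim; ++ki)
            for(int bi = 0; bi < Bdim; ++bi)
                c[ki * Bdim + bi] += a[ci * Kdim + ki] * b[ci * Bdim + bi];

    std::printf("c[0] = %f (expected %f)\n", c[0], 2.0f * Cdim);
    return 0;
}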
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp (deleted, 100644 → 0)
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_gemm.hip.hpp"

template <index_t GridSize, index_t BlockSize, class Float,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          class LowerPads, class UpperPads,
          index_t NPerBlock, index_t KPerBlock, index_t CPerBlock, index_t HoPerBlock, index_t WoPerBlock,
          index_t NPerThread, index_t KPerThread, index_t CPerThread, index_t HoPerThread, index_t WoPerThread,
          index_t WeiBlockCopyThreadPerDim0, index_t WeiBlockCopyThreadPerDim1>
__global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
    const Float* const __restrict__ p_in_global,
    const Float* const __restrict__ p_wei_global,
    Float* const __restrict__ p_out_global)
{
    // NPerThread == NPerBlock, because the format of input in LDS is [C,Hi,Wi,N]:
    // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
    static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0");
    static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock, "wrong!");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_chwn_global_desc  = InGlobalDesc{};
    constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
    constexpr auto out_khwn_global_desc = OutGlobalDesc{};

    constexpr index_t C  = in_chwn_global_desc.GetLength(I0);
    constexpr index_t K  = out_khwn_global_desc.GetLength(I0);
    constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
    constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
    constexpr index_t N  = out_khwn_global_desc.GetLength(I3);
    constexpr index_t Y  = wei_cyxk_global_desc.GetLength(I1);
    constexpr index_t X  = wei_cyxk_global_desc.GetLength(I2);

    constexpr index_t HPadLow = LowerPads{}.Get(I0);
    constexpr index_t WPadLow = LowerPads{}.Get(I1);
    constexpr index_t HPadUp  = UpperPads{}.Get(I0);
    constexpr index_t WPadUp  = UpperPads{}.Get(I1);

    constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
    constexpr index_t WiPerBlock = WoPerBlock + X - 1;

    // divide block work: [K, Ho, Wo, N]
    constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
    constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
    constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
    constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;

    const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
    index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
    const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
    itmp -= h_block_work_id * (WBlockWork * NBlockWork);
    const index_t w_block_work_id = itmp / NBlockWork;
    const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;

    const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
    const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
    const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
    const index_t n_block_data_begin  = n_block_work_id * NPerBlock;

    // flattened (2d) tensor view of wei in global mem
    constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});

    // tensor view of blockwise input and weight in LDS
    constexpr auto in_chwn_block_desc =
        make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});

    constexpr auto wei_cyxk_block_desc =
        make_ConstantTensorDescriptor(Sequence<CPerBlock, Y, X, KPerBlock>{});

    // flattened (2d) tensor view of wei in LDS
    constexpr auto wei_ek_block_desc =
        make_ConstantTensorDescriptor(Sequence<CPerBlock * Y * X, KPerBlock>{});

    // tensor view of threadwise output in register
    constexpr auto out_hkwn_thread_desc =
        make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});

#if 0
    if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
        print_ConstantTensorDescriptor(wei_cyxk_block_desc, "wei_cyxk_block_desc");
        print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
    }
#endif

    // blockwise copy
    // input: format is [C, Hi, Wi, N]
    const index_t h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0;
    const index_t w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0;

    const index_t h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0;
    const index_t w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0;

#if 0
    if(get_thread_local_1d_id() == 0)
        ;
    {
        printf("%u %u, h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u\n",
               get_block_1d_id(), get_thread_local_1d_id(),
               h_block_pad_low, w_block_pad_low, h_block_pad_up, w_block_pad_up);
    }
#endif

    constexpr auto blockwise_in_copy =
        BlockwiseChwnTensorCopyPadded<BlockSize, Float, decltype(in_chwn_global_desc),
                                      decltype(in_chwn_block_desc),
                                      decltype(in_chwn_block_desc.GetLengths()), LowerPads>{};

#if 0
    // weight: format is [C,Y,X,K]
    constexpr auto blockwise_wei_copy =
        Blockwise4dTensorCopy1<BlockSize, Float, decltype(wei_cyxk_global_desc),
                               decltype(wei_cyxk_block_desc), decltype(wei_cyxk_block_desc.GetLengths())>{};
#elif 0
    // weight: format is [C*Y*X,K]
    constexpr auto blockwise_wei_copy =
        Blockwise2dTensorCopy1<BlockSize, Float, decltype(wei_ek_global_desc),
                               decltype(wei_ek_block_desc), decltype(wei_ek_block_desc.GetLengths())>{};
#elif 1
    // weight: format is [C*Y*X,K]
    const auto blockwise_wei_copy =
        Blockwise2dTensorCopy2<BlockSize, Float, decltype(wei_ek_global_desc), decltype(wei_ek_block_desc),
                               decltype(wei_ek_block_desc.GetLengths()),
                               WeiBlockCopyThreadPerDim0, WeiBlockCopyThreadPerDim1>{};
#endif

    // a series of blockwise batched GEMM
    // C_matrix += transpose(A_matrix) * B_matrix
    // A_matrix and B_matrix saved in LDS, C_matrix saved in register
    //   A_matrix[C,K]    is a sub-matrix of wei_block[C,Y,X,K]
    //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
    //   C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N]
    constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});

    constexpr auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{}, Number<WoPerBlock * NPerBlock>{}, Number<in_chwn_block_desc.GetStride(I0)>{});

    constexpr auto c_kxwn_thread_mtx_desc =
        make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<WoPerThread * NPerThread>{});

    const auto blockwise_batch_gemm =
        Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
                                                         decltype(a_cxk_block_mtx_desc),
                                                         decltype(b_cxwn_block_mtx_desc),
                                                         decltype(c_kxwn_thread_mtx_desc),
                                                         true, false, false,
                                                         0,
                                                         in_chwn_block_desc.GetStride(I1),
                                                         out_hkwn_thread_desc.GetStride(I0),
                                                         HoPerBlock, HoPerThread, CPerThread,
                                                         true>{};

    // LDS
    constexpr index_t in_block_element_size  = in_chwn_block_desc.GetElementSpace();
    constexpr index_t wei_block_element_size = wei_cyxk_block_desc.GetElementSpace();

    __shared__ Float p_in_block[in_block_element_size];
    __shared__ Float p_wei_block[wei_block_element_size];

    // register
    Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];

    // set threadwise output tensor to 0
    threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);

    const Float* p_wei_global_block_begin =
        p_wei_global + wei_ek_global_desc.Get1dIndex(0, k_block_data_begin);

    for(index_t c_block_data_begin = 0; c_block_data_begin < C;
        c_block_data_begin += CPerBlock,
        p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0),
        __syncthreads())
    {
#if 1
        // input: global mem to LDS,
        blockwise_in_copy.Run(p_in_global,
                              c_block_data_begin,
                              ho_block_data_begin,
                              wo_block_data_begin,
                              n_block_data_begin,
                              p_in_block,
                              h_block_pad_low,
                              w_block_pad_low,
                              h_block_pad_up,
                              w_block_pad_up);
#endif

#if 1
        // weight: global mem to LDS,
        blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block);
#endif

        __syncthreads();

        // a series of batched GEMM
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                auto f_accum = [](auto& acc, const auto&& v) { acc += v; };

                blockwise_batch_gemm.Run(p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                                         p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
                                         p_out_thread,
                                         f_accum);
            }
        }
    }

    const auto matrix_c_index = blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

    const index_t ho_thread_data_begin = matrix_c_index.batch;
    const index_t k_thread_data_begin  = matrix_c_index.row;
    const index_t wo_thread_data_begin = matrix_c_index.col / NPerBlock;
    const index_t n_thread_data_begin  = matrix_c_index.col - wo_thread_data_begin * NPerBlock;

#if 0
    printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
           get_block_1d_id(), get_thread_local_1d_id(),
           ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin,
           ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin,
           p_out_thread[0]);
#endif

    // output: register to global mem,
    // convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
    constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{};

    threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
        out_hkwn_thread_desc,
        p_out_thread,
        out_khwn_global_desc,
        p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
                                                       ho_block_data_begin + ho_thread_data_begin,
                                                       wo_block_data_begin + wo_thread_data_begin,
                                                       n_block_data_begin + n_thread_data_begin),
        out_hkwn_thread_desc.GetLengths(),
        reorder_khwn_from_hkwn);
}
src/include/threadwise_4d_tensor_op.hip.hpp (deleted, 100644 → 0)
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"

template <class Float, class Desc, class F>
__device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto desc = Desc{};

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_op_unary: ");
    }
#endif

    for(index_t did0 = 0; did0 < desc.GetLength(I0); ++did0)
        for(index_t did1 = 0; did1 < desc.GetLength(I1); ++did1)
            for(index_t did2 = 0; did2 < desc.GetLength(I2); ++did2)
                for(index_t did3 = 0; did3 < desc.GetLength(I3); ++did3)
                {
                    const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3);
                    f(p[dindex]);
                }
}

// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template <class SrcData, class DstData, class SrcDesc, class DstDesc, class SrcOpLengths,
          class DstFromSrcReorder, class F>
__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
    SrcDesc, const SrcData* __restrict__ p_src, DstDesc, DstData* __restrict__ p_dst,
    SrcOpLengths, DstFromSrcReorder, F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
    constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
    constexpr index_t IR2 = DstFromSrcReorder{}.Get(I2);
    constexpr index_t IR3 = DstFromSrcReorder{}.Get(I3);

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
                {
                    const index_t aindex = src_desc.Get1dIndex(did0, did1, did2, did3);

                    const index_t did[4] = {did0, did1, did2, did3};

                    const index_t bindex =
                        dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);

                    f(p_src[aindex], p_dst[bindex]);
                }
}

template <class Data, class Desc>
__device__ void threadwise_4d_tensor_set_zero(Desc, Data* __restrict__ p)
{
    auto f_set_zero = [](Data& v) { v = Data(0); };

    threadwise_4d_tensor_pointwise_operation_unary<Data, Desc, decltype(f_set_zero)>(
        Desc{}, p, f_set_zero);
}

template <class SrcData, class DstData, class SrcDesc, class DstDesc, class SrcOpLengths,
          class DstFromSrcReorder>
__device__ void threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
    SrcDesc, const SrcData* __restrict__ p_src, DstDesc, DstData* __restrict__ p_dst,
    SrcOpLengths, DstFromSrcReorder)
{
    auto f_copy = [](const SrcData& src, DstData& dst) { dst = static_cast<DstData>(src); };

    threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
}

template <class SrcData, class DstData, class SrcDesc, class DstDesc, class SrcOpLengths>
__device__ void threadwise_4d_tensor_copy(
    SrcDesc, const SrcData* __restrict__ p_src, DstDesc, DstData* __restrict__ p_dst, SrcOpLengths)
{
    auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};

    threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
}

// need to assume src and dst is aligned
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_4d_tensor_copy_v2(SrcDesc, const Float* __restrict__ p_src,
                                             DstDesc, Float* __restrict__ p_dst,
                                             SrcOpLengths, Number<DataPerRead>)
{
    using Float2 = float2;
    using Float4 = float4;

    static_assert(SrcDesc{}.GetDimension() == 4 && DstDesc{}.GetDimension() == 4 &&
                      SrcOpLengths::nDim == 4,
                  "wrong! should be 4 dimension");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    static_assert(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1,
                  "wrong! only support stride3 == 1!\n");

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

    static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
                      DstDesc{}.GetStride(I2) % DataPerRead == 0,
                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");

    constexpr index_t L3 = SrcOpLengths{}.Get(I3);

    static_assert(L3 % DataPerRead == 0, "wrong! L3 should be evenly divided by DataPerRead");

    constexpr index_t nloop_d3 = L3 / DataPerRead;

    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
                for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
                {
                    const index_t src_index =
                        src_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);
                    const index_t dst_index =
                        dst_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);

                    if(DataPerRead == 1)
                    {
                        p_dst[dst_index] = p_src[src_index];
                    }
                    else if(DataPerRead == 2)
                    {
                        *(reinterpret_cast<Float2*>(p_dst + dst_index)) =
                            *(reinterpret_cast<const Float2*>(p_src + src_index));
                    }
                    else if(DataPerRead == 4)
                    {
                        *(reinterpret_cast<Float4*>(p_dst + dst_index)) =
                            *(reinterpret_cast<const Float4*>(p_src + src_index));
                    }
                    else
                    {
                        assert(false);
                    }
                }
}

template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto desc = Desc{};

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
    }
#endif

    constexpr index_t nshift = NShift::mValue;

    constexpr index_t did0_end =
        is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
    constexpr index_t did1_end =
        is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
    constexpr index_t did2_end =
        is_same<decltype(I2), IDim>::value ? desc.GetLength(I2) - nshift : desc.GetLength(I2);
    constexpr index_t did3_end =
        is_same<decltype(I3), IDim>::value ? desc.GetLength(I3) - nshift : desc.GetLength(I3);

    for(index_t did0 = 0; did0 < did0_end; ++did0)
        for(index_t did1 = 0; did1 < did1_end; ++did1)
            for(index_t did2 = 0; did2 < did2_end; ++did2)
                for(index_t did3 = 0; did3 < did3_end; ++did3)
                {
                    const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3);
                    const index_t sindex = dindex + nshift * desc.GetStride(IDim{});

                    p[dindex] = p[sindex];
                }
}
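The helpers above all address elements through Get1dIndex, which for these descriptors amounts to a stride-weighted sum of the four coordinates. A standalone sketch of that offset arithmetic for a packed NCHW layout is below; the lengths are hypothetical, and the packed row-major stride rule is an assumption about what make_ConstantTensorDescriptor does when no strides are supplied:

#include <array>
#include <cstdio>

// Illustration only: the stride arithmetic behind a 4-d Get1dIndex for a packed tensor.
int main()
{
    const std::array<int, 4> lengths{2, 3, 4, 5}; // hypothetical N, C, H, W

    // packed (row-major) strides: stride[3] = 1, stride[i] = stride[i+1] * lengths[i+1]
    std::array<int, 4> strides{};
    strides[3] = 1;
    for(int i = 2; i >= 0; --i)
        strides[i] = strides[i + 1] * lengths[i + 1];

    auto get_1d_index = [&](int d0, int d1, int d2, int d3) {
        return d0 * strides[0] + d1 * strides[1] + d2 * strides[2] + d3 * strides[3];
    };

    // (1,2,3,4) is the last element of a 2x3x4x5 tensor, so the offset is 119
    std::printf("offset of (1,2,3,4) = %d\n", get_1d_index(1, 2, 3, 4));
    return 0;
}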
src/include/threadwise_direct_convolution.hip.hpp (deleted, 100644 → 0)
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"

// optimized for scenario if p_in, p_wei, p_out are in register
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_1(InDesc,
                                                TInWei* const __restrict__ p_in,
                                                WeiDesc,
                                                TInWei* const __restrict__ p_wei,
                                                OutDesc,
                                                TOut* __restrict__ p_out)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

#if 0
    if(blockIdx.x == 0 && threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(in_desc, "threadwise_direct_convolution: in_desc: ");
        print_ConstantTensorDescriptor(wei_desc, "threadwise_direct_convolution: wei_desc: ");
        print_ConstantTensorDescriptor(out_desc, "threadwise_direct_convolution: out_desc: ");
    }
#endif

    for(index_t n = 0; n < out_desc.GetLength(I0); ++n)
        for(index_t k = 0; k < out_desc.GetLength(I1); ++k)
            for(index_t ho = 0; ho < out_desc.GetLength(I2); ++ho)
                for(index_t wo = 0; wo < out_desc.GetLength(I3); ++wo)
                    for(index_t c = 0; c < wei_desc.GetLength(I1); ++c)
                        for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
                            for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
                            {
                                const index_t hi = ho + y;
                                const index_t wi = wo + x;

                                const index_t in_index  = in_desc.Get1dIndex(n, c, hi, wi);
                                const index_t wei_index = wei_desc.Get1dIndex(k, c, y, x);
                                const index_t out_index = out_desc.Get1dIndex(n, k, ho, wo);

                                fused_multiply_accumulate(
                                    p_out[out_index], p_wei[wei_index], p_in[in_index]);
                            }
}

// Optimized for scenario if p_in and p_wei are in LDS, p_out are in register
// Copy in and wei into register before doing convolution
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_2(InDesc,
                                                TInWei* const __restrict__ p_in,
                                                WeiDesc,
                                                TInWei* const __restrict__ p_wei,
                                                OutDesc,
                                                TOut* __restrict__ p_out)
{
    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

    constexpr auto in_reg_desc  = make_ConstantTensorDescriptor(in_desc.GetLengths());
    constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(wei_desc.GetLengths());

    // register
    TInWei p_in_reg[in_reg_desc.GetElementSpace()];
    TInWei p_wei_reg[wei_reg_desc.GetElementSpace()];

    // copy input tensor into register
    threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths());

    // copy weight tensor into register
    threadwise_4d_tensor_copy(wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths());

    // do convolution
    threadwise_direct_convolution_1(in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
}

// optimized for scenario where p_in and p_wei are in LDS, p_out is in register
// break down a non-1x1 convolution into a sequence of 1x1 convolutions,
// load 1x1 weight into register, and do 1x1 convolution in register.
template <class Data, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_3(InDesc,
                                                Data* const __restrict__ p_in,
                                                WeiDesc,
                                                Data* const __restrict__ p_wei,
                                                OutDesc,
                                                Data* __restrict__ p_out)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

    constexpr auto in_reg_desc = make_ConstantTensorDescriptor(
        Sequence<in_desc.GetLength(I0), in_desc.GetLength(I1),
                 out_desc.GetLength(I2), out_desc.GetLength(I3)>{});

    constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(
        Sequence<wei_desc.GetLength(I0), wei_desc.GetLength(I1), 1, 1>{});

    Data p_in_reg[in_reg_desc.GetElementSpace()];
    Data p_wei_reg[wei_reg_desc.GetElementSpace()];

    constexpr index_t in_w_new_read = 1;

    constexpr auto in_desc_reg_new_read = make_ConstantTensorDescriptor(
        Sequence<in_reg_desc.GetLength(I0), in_reg_desc.GetLength(I1),
                 in_reg_desc.GetLength(I2), in_w_new_read>{});

#if 0
    // this version reuses old input data in register, and reads new data from LDS
    // loop over vertical direction
    for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
    {
        // read first input
        threadwise_4d_tensor_copy(in_desc,
                                  p_in + in_desc.Get1dIndex(0, 0, y, 0),
                                  in_reg_desc,
                                  p_in_reg,
                                  in_reg_desc.GetLengths());

        // read first 1x1 weight
        threadwise_4d_tensor_copy(wei_desc,
                                  p_wei + wei_desc.Get1dIndex(0, 0, y, 0),
                                  wei_reg_desc,
                                  p_wei_reg,
                                  wei_reg_desc.GetLengths());

        // do first 1x1 conv
        threadwise_direct_convolution_1(
            in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);

        // loop over horizontal direction
        for(index_t x = 1; x < wei_desc.GetLength(I3); ++x)
        {
            // read new weight
            threadwise_4d_tensor_copy(wei_desc,
                                      p_wei + wei_desc.Get1dIndex(0, 0, y, x),
                                      wei_reg_desc,
                                      p_wei_reg,
                                      wei_reg_desc.GetLengths());

            // shift old input to the left
            threadwise_4d_tensor_shift_down(in_reg_desc, p_in_reg, I3, Number<in_w_new_read>{});

            // read new input
            threadwise_4d_tensor_copy(
                in_desc,
                p_in + in_desc.Get1dIndex(0, 0, y, x + in_reg_desc.GetLength(I3) - 1),
                in_reg_desc,
                p_in_reg +
                    in_reg_desc.Get1dIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read),
                in_desc_reg_new_read.GetLengths());

            // do 1x1 conv
            threadwise_direct_convolution_1(
                in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
        }
    }
#elif 1
    // this version reads all input from LDS when the filter moves
    // loop over vertical direction
    for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
    {
        // loop over horizontal direction
        for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
        {
            // read new weight
            threadwise_4d_tensor_copy(wei_desc,
                                      p_wei + wei_desc.Get1dIndex(0, 0, y, x),
                                      wei_reg_desc,
                                      p_wei_reg,
                                      wei_reg_desc.GetLengths());

            // read new input
            threadwise_4d_tensor_copy(in_desc,
                                      p_in + in_desc.Get1dIndex(0, 0, y, x),
                                      in_reg_desc,
                                      p_in_reg,
                                      in_reg_desc.GetLengths());

            // do 1x1 conv
            threadwise_direct_convolution_1(
                in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
        }
    }
#endif
}
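threadwise_direct_convolution_3 relies on the fact that a YxX filter is the sum of Y*X shifted 1x1 contributions. A scalar sketch of that decomposition, using hypothetical sizes and a single image with one input and one output channel, shows that accumulating per-(y, x) 1x1 products reproduces the direct convolution:

#include <cmath>
#include <cstdio>
#include <vector>

// Illustration only: a YxX convolution as an accumulation of Y*X shifted 1x1 convolutions.
// Unit stride, no padding; sizes are hypothetical.
int main()
{
    const int Hi = 6, Wi = 6, Y = 3, X = 3;
    const int Ho = Hi - Y + 1, Wo = Wi - X + 1;

    std::vector<float> in(Hi * Wi), wei(Y * X), direct(Ho * Wo, 0.0f), by_1x1(Ho * Wo, 0.0f);
    for(int i = 0; i < Hi * Wi; ++i) in[i] = 0.1f * i;
    for(int i = 0; i < Y * X; ++i) wei[i] = 1.0f + i;

    // direct form
    for(int ho = 0; ho < Ho; ++ho)
        for(int wo = 0; wo < Wo; ++wo)
            for(int y = 0; y < Y; ++y)
                for(int x = 0; x < X; ++x)
                    direct[ho * Wo + wo] += wei[y * X + x] * in[(ho + y) * Wi + (wo + x)];

    // same result as Y*X accumulated 1x1 convolutions over shifted input windows
    for(int y = 0; y < Y; ++y)
        for(int x = 0; x < X; ++x)
            for(int ho = 0; ho < Ho; ++ho)
                for(int wo = 0; wo < Wo; ++wo)
                    by_1x1[ho * Wo + wo] += wei[y * X + x] * in[(ho + y) * Wi + (wo + x)];

    float max_diff = 0.0f;
    for(int i = 0; i < Ho * Wo; ++i)
        max_diff = std::fmax(max_diff, std::fabs(direct[i] - by_1x1[i]));

    std::printf("max difference: %f\n", max_diff);
    return 0;
}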