gaoqiong / composable_kernel / Commits / 67c6f73f

Commit 67c6f73f, authored Feb 15, 2019 by Chao Liu

    hip build

Parent: 121693b3

Changes: 34 files in total; this page shows 20 changed files with 303 additions and 210 deletions (+303, -210).
Changed files on this page:

    CMakeLists.txt                                                       +15   -5
    build/cmake-cuda.sh                                                   +0   -0
    build/cmake-hip.sh                                                   +16   -0
    driver/CMakeLists.txt                                                 +4   -2
    driver/device_direct_convolution_1.cuh                               +21  -29
    driver/device_direct_convolution_2.cuh                               +21  -29
    driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh          +7  -15
    driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh   +8  -15
    driver/device_implicit_gemm_convolution_1_nchw_kcsr_nkhw.cuh         +20  -29
    driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh          +7  -15
    driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh          +7  -15
    driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh          +7  -14
    driver/driver.cpp                                                     +9   -8
    src/CMakeLists.txt                                                    +7  -11
    src/device.cpp                                                      +114   -0
    src/include/blockwise_4d_tensor_op.cuh                                +5   -4
    src/include/blockwise_direct_convolution.cuh                         +15  -12
    src/include/blockwise_gemm.cuh                                        +8   -6
    src/include/common.cuh                                                +1   -1
    src/include/config.h.in                                              +11   -0
CMakeLists.txt

@@ -5,6 +5,7 @@ project(modular_convolution)
 enable_language(CXX)
 set(CMAKE_CXX_STANDARD 14)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_CXX_EXTENSIONS OFF)
 message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
 
 #boost
@@ -16,7 +17,7 @@ message("Boost_LIBRARY_DIRS: ${Boost_LIBRARY_DIRS}")
 include_directories(BEFORE ${Boost_INCLUDE_DIRS})
 link_directories(${Boost_LIBRARY_DIRS})
 
-#openMP
+#OpenMP
 if( NOT( ${CMAKE_CXX_COMPILER_ID} STREQUAL "AppleClang") )
     find_package(OpenMP REQUIRED)
@@ -30,11 +31,20 @@ if( NOT( ${CMAKE_CXX_COMPILER_ID} STREQUAL "AppleClang") )
     link_libraries(${OpenMP_pthread_LIBRARY})
 endif( NOT( ${CMAKE_CXX_COMPILER_ID} STREQUAL "AppleClang") )
 
-#cuda
-enable_language(CUDA)
-include_directories(BEFORE ${CUDA_COMMON_INCLUDE_DIR})
+#GPU backend
+if(DEVICE_BACKEND STREQUAL "HIP")
+    set(DEVICE_BACKEND_HIP 1)
+    set(CMAKE_MODULE_PATH "/opt/rocm/hip/cmake" ${CMAKE_MODULE_PATH})
+    find_package(HIP REQUIRED)
+elseif(DEVICE_BACKEND STREQUAL "CUDA")
+    set(DEVICE_BACKEND_CUDA 1)
+    enable_language(CUDA)
+    include_directories(BEFORE ${CUDA_COMMON_INCLUDE_DIR})
+endif()
 
-include_directories(BEFORE src/include)
+include_directories(BEFORE src/include ${PROJECT_BINARY_DIR}/src/include)
 
 add_subdirectory(src)
 add_subdirectory(driver)
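
Worth spelling out: `set(DEVICE_BACKEND_HIP 1)` only defines a configure-time CMake variable. It reaches the C++ sources through the `configure_file` step added to src/CMakeLists.txt together with the new src/include/config.h.in (both later in this diff), which is also why `${PROJECT_BINARY_DIR}/src/include` joins the include path above. A minimal sketch of the mechanism, with names taken from this commit:

    # Configure with the backend selected on the command line:
    #   cmake -D DEVICE_BACKEND="HIP" ...
    # Then src/CMakeLists.txt stamps the flags into a generated header:
    configure_file("${PROJECT_SOURCE_DIR}/src/include/config.h.in"
                   "${PROJECT_BINARY_DIR}/src/include/config.h")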
build/cmake.sh → build/cmake-cuda.sh

File moved without changes.
build/cmake-hip.sh (new file, mode 100755)

#!/bin/bash
rm -f CMakeCache.txt
rm -f *.cmake
rm -rf CMakeFiles

MY_PROJECT_SOURCE=/home/chao/code/modular_convolution
MY_PROJECT_INSTALL=../install.dir

cmake                                              \
    -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL}  \
    -D CMAKE_BUILD_TYPE=Release                    \
    -D DEVICE_BACKEND="HIP"                        \
    -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc      \
    -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON              \
    ${MY_PROJECT_SOURCE}
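
A plausible way to run the script, assuming the usual in-tree build/ workflow (MY_PROJECT_SOURCE is hard-coded to the author's checkout and would need adjusting first; the `driver` target name comes from driver/CMakeLists.txt below):

    cd build
    ./cmake-hip.sh      # configure with hipcc as CMAKE_CXX_COMPILER
    make -j driver      # build the test driver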
driver/CMakeLists.txt

-add_executable(conv conv.cu)
-target_link_libraries(conv tensor device)
+set(DRIVER_SOURCE driver.cpp)
+
+add_executable(driver ${DRIVER_SOURCE})
+target_link_libraries(driver PRIVATE tensor)
driver/device_direct_convolution_1.cuh

@@ -54,39 +54,31 @@ void device_direct_convolution_1(InDesc,
         (out_desc.GetLength(I2) / (OutTileSizeH * YPerBlock)) *
         (out_desc.GetLength(I3) / (OutTileSizeW * XPerBlock));
 
-    dim3 block_dim(BlockSize);
-    dim3 grid_dim(GridSize);
-
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
 
     for(unsigned i = 0; i < nrepeat; ++i)
     {
-        const void* f = reinterpret_cast<const void*>(
-            gridwise_direct_convolution_1<T, InDesc, WeiDesc, OutDesc,
-                OutTileSizeH, OutTileSizeW, NPerBlock, KPerBlock, CPerBlock,
-                YPerBlock, XPerBlock, NPerThread, KPerThread, CPerThread,
-                BlockSize, GridSize>);
-
-        T* in_dev_ptr  = static_cast<T*>(in_device_buf.GetDeviceBuffer());
-        T* wei_dev_ptr = static_cast<T*>(wei_device_buf.GetDeviceBuffer());
-        T* out_dev_ptr = static_cast<T*>(out_device_buf.GetDeviceBuffer());
-
-        void* args[] = {&in_dev_ptr, &wei_dev_ptr, &out_dev_ptr};
-
-        float time = 0;
-
-        launch_kernel(f, grid_dim, block_dim, args, time);
+        float time = launch_kernel(
+            gridwise_direct_convolution_1<T, InDesc, WeiDesc, OutDesc,
+                OutTileSizeH, OutTileSizeW, NPerBlock, KPerBlock, CPerBlock,
+                YPerBlock, XPerBlock, NPerThread, KPerThread, CPerThread,
+                BlockSize, GridSize>,
+            dim3(GridSize),
+            dim3(BlockSize),
+            static_cast<T*>(in_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(out_device_buf.GetDeviceBuffer()));
 
         printf("Elapsed time : %f ms \n", time);
 
         usleep(std::min(time * 1000, float(10000)));
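
The old call sites type-erase the kernel through `reinterpret_cast<const void*>` plus a `void* args[]` array, which is the shape `cudaLaunchKernel` requires; the new ones pass the kernel template and its arguments directly, a form both nvcc and hipcc can compile. The commit deletes the old `launch_kernel(const void*, ...)` from src/device.cpp (below), but the new definition sits outside the visible hunks; a minimal sketch consistent with the new call sites, assuming a header-only variadic template:

    // Sketch only: signature inferred from the call sites in this commit,
    // not taken from its hunks.
    template <typename Kernel, typename... Args>
    float launch_kernel(Kernel kernel, dim3 grid_dim, dim3 block_dim, Args... args)
    {
        KernelTimer timer;
        timer.Start();
    #if DEVICE_BACKEND_HIP
        hipLaunchKernelGGL(kernel, grid_dim, block_dim, 0, 0, args...);
    #elif DEVICE_BACKEND_CUDA
        kernel<<<grid_dim, block_dim>>>(args...);
    #endif
        timer.End();
        return timer.GetElapsedTime(); // milliseconds, as printed by the callers
    }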
driver/device_direct_convolution_2.cuh

@@ -69,39 +69,31 @@ void device_direct_convolution_2(InDesc,
         (out_desc.GetLength(I2) / (OutTileSizeH * YPerBlock)) *
         (out_desc.GetLength(I3) / (OutTileSizeW * XPerBlock));
 
-    dim3 block_dim(BlockSize);
-    dim3 grid_dim(GridSize);
-
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
 
     for(unsigned i = 0; i < nrepeat; ++i)
     {
-        const void* f = reinterpret_cast<const void*>(
-            gridwise_direct_convolution_2<T, InDesc, WeiDesc, OutDesc,
-                OutTileSizeH, OutTileSizeW, NPerBlock, KPerBlock, CPerBlock,
-                YPerBlock, XPerBlock, NPerThread, KPerThread, CPerThread,
-                BlockSize, GridSize>);
-
-        T* in_dev_ptr  = static_cast<T*>(in_device_buf.GetDeviceBuffer());
-        T* wei_dev_ptr = static_cast<T*>(wei_device_buf.GetDeviceBuffer());
-        T* out_dev_ptr = static_cast<T*>(out_device_buf.GetDeviceBuffer());
-
-        void* args[] = {&in_dev_ptr, &wei_dev_ptr, &out_dev_ptr};
-
-        float time = 0;
-
-        launch_kernel(f, grid_dim, block_dim, args, time);
+        float time = launch_kernel(
+            gridwise_direct_convolution_2<T, InDesc, WeiDesc, OutDesc,
+                OutTileSizeH, OutTileSizeW, NPerBlock, KPerBlock, CPerBlock,
+                YPerBlock, XPerBlock, NPerThread, KPerThread, CPerThread,
+                BlockSize, GridSize>,
+            dim3(GridSize),
+            dim3(BlockSize),
+            static_cast<T*>(in_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(out_device_buf.GetDeviceBuffer()));
 
         printf("Elapsed time : %f ms \n", time);
 
         usleep(std::min(time * 1000, float(10000)));
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh

@@ -194,14 +194,11 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
         ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
         ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
 
-    dim3 block_dim(BlockSize);
-    dim3 grid_dim(GridSize);
-
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
 
     for(unsigned i = 0; i < nrepeat; ++i)
     {
-        const void* f = reinterpret_cast<const void*>(
-            gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn<GridSize, BlockSize, T,
+        float time = launch_kernel(
+            gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn<GridSize, BlockSize, T,
@@ -221,17 +218,12 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
                 WeiBlockCopyThreadPerDim0,
                 WeiBlockCopyThreadPerDim1,
                 InBlockCopyDataPerRead,
-                WeiBlockCopyDataPerRead>);
-
-        T* in_dev_ptr  = static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer());
-        T* wei_dev_ptr = static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer());
-        T* out_dev_ptr = static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer());
-
-        void* args[] = {&in_dev_ptr, &wei_dev_ptr, &out_dev_ptr};
-
-        float time = 0;
-
-        launch_kernel(f, grid_dim, block_dim, args, time);
+                WeiBlockCopyDataPerRead>,
+            dim3(GridSize),
+            dim3(BlockSize),
+            static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
 
         printf("Elapsed time : %f ms \n", time);
 
         usleep(std::min(time * 1000, float(10000)));
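
The GridSize expression is round-up division per dimension, so problem sizes that are not multiples of the block tile still get a covering workgroup: (N + NPerBlock - 1) / NPerBlock yields 5 blocks for N = 65 with NPerBlock = 16, where plain integer division would give 4 and drop the partial tile. A tiny self-contained illustration of the same arithmetic:

    // Ceiling division, as used in the GridSize computations in these drivers.
    constexpr unsigned ceil_div(unsigned len, unsigned tile) { return (len + tile - 1) / tile; }

    static_assert(ceil_div(64, 16) == 4, "exact multiple of the tile");
    static_assert(ceil_div(65, 16) == 5, "partial tile still gets a block");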
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh

@@ -94,7 +94,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc,
     constexpr unsigned WeiBlockCopyThreadPerDim1 = 1;
 
     constexpr unsigned BlockSize = 8;
-#elif 0
+#elif 1
     // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
     constexpr unsigned NPerBlock = 16;
     constexpr unsigned KPerBlock = 64;
@@ -246,14 +246,11 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc,
         ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
         ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
 
-    dim3 block_dim(BlockSize);
-    dim3 grid_dim(GridSize);
-
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
 
     for(unsigned i = 0; i < nrepeat; ++i)
     {
-        const void* f = reinterpret_cast<const void*>(
+        float time = launch_kernel(
 #if 0
             gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded
 #elif 1
@@ -278,17 +275,13 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc,
                 HoPerThread,
                 WoPerThread,
                 WeiBlockCopyThreadPerDim0,
-                WeiBlockCopyThreadPerDim1>);
-
-        T* in_dev_ptr  = static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer());
-        T* wei_dev_ptr = static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer());
-        T* out_dev_ptr = static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer());
-
-        void* args[] = {&in_dev_ptr, &wei_dev_ptr, &out_dev_ptr};
-
-        float time = 0;
-
-        launch_kernel(f, grid_dim, block_dim, args, time);
+                WeiBlockCopyThreadPerDim1>,
+            dim3(GridSize),
+            dim3(BlockSize),
+            static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
 
         printf("Elapsed time : %f ms \n", time);
 
         usleep(std::min(time * 1000, float(10000)));
driver/device_implicit_gemm_convolution_1_nchw_kcsr_nkhw.cuh

@@ -52,39 +52,30 @@ void device_implicit_gemm_convolution_1_nchw_kcsr_nkhw(InDesc,
         (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
         (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
 
-    dim3 block_dim(BlockSize);
-    dim3 grid_dim(GridSize);
-
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
 
     for(unsigned i = 0; i < nrepeat; ++i)
     {
-        const void* f = reinterpret_cast<const void*>(
-            gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw<GridSize, BlockSize, T,
-                InDesc, WeiDesc, OutDesc, NPerBlock, KPerBlock, CPerBlock,
-                HoPerBlock, WoPerBlock, KPerThread, CPerThread, HoPerThread,
-                WoPerThread>);
-
-        T* in_dev_ptr  = static_cast<T*>(in_device_buf.GetDeviceBuffer());
-        T* wei_dev_ptr = static_cast<T*>(wei_device_buf.GetDeviceBuffer());
-        T* out_dev_ptr = static_cast<T*>(out_device_buf.GetDeviceBuffer());
-
-        void* args[] = {&in_dev_ptr, &wei_dev_ptr, &out_dev_ptr};
-
-        float time = 0;
-
-        launch_kernel(f, grid_dim, block_dim, args, time);
+        float time = launch_kernel(
+            gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw<GridSize, BlockSize, T,
+                InDesc, WeiDesc, OutDesc, NPerBlock, KPerBlock, CPerBlock,
+                HoPerBlock, WoPerBlock, KPerThread, CPerThread, HoPerThread,
+                WoPerThread>,
+            dim3(GridSize),
+            dim3(BlockSize),
+            static_cast<T*>(in_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(out_device_buf.GetDeviceBuffer()));
 
         printf("Elapsed time : %f ms \n", time);
 
         usleep(std::min(time * 1000, float(10000)));
driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh

@@ -104,14 +104,11 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
         ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
         ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
 
-    dim3 block_dim(BlockSize);
-    dim3 grid_dim(GridSize);
-
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
 
     for(unsigned i = 0; i < nrepeat; ++i)
    {
-        const void* f = reinterpret_cast<const void*>(
-            gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw<GridSize, BlockSize, T,
+        float time = launch_kernel(
+            gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw<GridSize, BlockSize, T,
@@ -127,17 +124,12 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
                 KPerThread,
                 CPerThread,
                 HoPerThread,
-                WoPerThread>);
-
-        T* in_dev_ptr  = static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer());
-        T* wei_dev_ptr = static_cast<T*>(wei_srck_device_buf.GetDeviceBuffer());
-        T* out_dev_ptr = static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer());
-
-        void* args[] = {&in_dev_ptr, &wei_dev_ptr, &out_dev_ptr};
-
-        float time = 0;
-
-        launch_kernel(f, grid_dim, block_dim, args, time);
+                WoPerThread>,
+            dim3(GridSize),
+            dim3(BlockSize),
+            static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(wei_srck_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
 
         printf("Elapsed time : %f ms \n", time);
 
         usleep(std::min(time * 1000, float(10000)));
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh

@@ -195,9 +195,6 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
     constexpr unsigned GridSize =
         ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
 
-    dim3 block_dim(BlockSize);
-    dim3 grid_dim(GridSize);
-
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
 
     // mem
@@ -213,7 +210,7 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
     for(unsigned i = 0; i < nrepeat; ++i)
     {
-        const void* f = reinterpret_cast<const void*>(
+        float time = launch_kernel(
 #if 0
             gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw
 #else
@@ -244,17 +241,12 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
                 WeiBlockCopyThreadPerDim0,
                 WeiBlockCopyThreadPerDim1,
                 InBlockCopyDataPerRead,
-                WeiBlockCopyDataPerRead>);
-
-        T* in_dev_ptr  = static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer());
-        T* wei_dev_ptr = static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer());
-        T* out_dev_ptr = static_cast<T*>(out_knhw_device_buf.GetDeviceBuffer());
-
-        void* args[] = {&in_dev_ptr, &wei_dev_ptr, &out_dev_ptr};
-
-        float time;
-
-        launch_kernel(f, grid_dim, block_dim, args, time);
+                WeiBlockCopyDataPerRead>,
+            dim3(GridSize),
+            dim3(BlockSize),
+            static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(out_knhw_device_buf.GetDeviceBuffer()));
 
         printf("Elapsed time : %f ms \n", time);
 
         usleep(std::min(time * 1000, float(10000)));
driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh

@@ -123,9 +123,6 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
     constexpr unsigned GridSize =
         ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
 
-    dim3 block_dim(BlockSize);
-    dim3 grid_dim(GridSize);
-
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
 
     // mem
@@ -141,7 +138,7 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
     for(unsigned i = 0; i < nrepeat; ++i)
     {
-        const void* f = reinterpret_cast<const void*>(
+        float time = launch_kernel(
 #if 1
             gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw
 #else
@@ -162,17 +159,13 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
                 GemmThreadPerColumnPerCluster,
                 GemmThreadPerRowPerCluster,
                 InBlockCopyThreadPerDim0,
-                InBlockCopyThreadPerDim1>);
-
-        T* in_dev_ptr  = static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer());
-        T* wei_dev_ptr = static_cast<T*>(wei_srck_device_buf.GetDeviceBuffer());
-        T* out_dev_ptr = static_cast<T*>(out_knhw_device_buf.GetDeviceBuffer());
-
-        void* args[] = {&in_dev_ptr, &wei_dev_ptr, &out_dev_ptr};
-
-        float time = 0;
-
-        launch_kernel(f, grid_dim, block_dim, args, time);
+                InBlockCopyThreadPerDim1>,
+            dim3(GridSize),
+            dim3(BlockSize),
+            static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(wei_srck_device_buf.GetDeviceBuffer()),
+            static_cast<T*>(out_knhw_device_buf.GetDeviceBuffer()));
 
         printf("Elapsed time : %f ms \n", time);
 
         usleep(std::min(time * 1000, float(10000)));
driver/conv.cu → driver/driver.cpp

@@ -2,6 +2,7 @@
 #include <numeric>
 #include <initializer_list>
 #include <cstdlib>
+#include "config.h"
 #include "tensor.hpp"
 #include "ConstantTensorDescriptor.cuh"
 #include "conv_common.cuh"
@@ -49,7 +50,7 @@ struct GeneratorTensor_3
         std::initializer_list<std::size_t> ids = {static_cast<std::size_t>(is)...};
         std::vector<std::size_t> lens(sizeof...(Is), 100);
         std::vector<std::size_t> strides(sizeof...(Is), 1);
 
         std::partial_sum(
             lens.rbegin(), lens.rbegin() + (sizeof...(Is) - 1), strides.rbegin() + 1);
 
         return std::inner_product(ids.begin(), ids.end(), strides.begin(), std::size_t(0)) + 1;
 #endif
     }
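
GeneratorTensor_3 builds a small stride-like table with std::partial_sum (running sums, since no binary operation is passed) and then folds the index tuple against it with std::inner_product. A worked sketch for the 4-index case, following the code above:

    #include <cstddef>
    #include <numeric>
    #include <vector>

    int main()
    {
        // As in GeneratorTensor_3 with sizeof...(Is) == 4:
        std::vector<std::size_t> lens(4, 100), strides(4, 1);

        // Default partial_sum is a running *sum*, filled in from the right:
        // strides == {300, 200, 100, 1}.
        std::partial_sum(lens.rbegin(), lens.rbegin() + 3, strides.rbegin() + 1);

        // operator()(n, c, h, w) therefore returns 300*n + 200*c + 100*h + w + 1:
        // a deterministic (though not collision-free) test value per element.
        return 0;
    }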
@@ -339,7 +340,7 @@ void host_winograd_3x3_convolution(
                     std::size_t ho = OutTileSizeH * y + j;
 
                     for(int i = 0; i < OutTileSizeW; ++i)
                     {
                         std::size_t wo = OutTileSizeW * x + i;
 
                         out(n, k, ho, wo) = out_hold(n, k, y, x, j, i);
                     }
                 }
@@ -392,13 +393,13 @@ int main()
     constexpr unsigned WPad = 0;
 #elif 0
     // 3x3, 34x34
     constexpr unsigned N  = 64;
     constexpr unsigned C  = 256;
     constexpr unsigned HI = 34;
     constexpr unsigned WI = 34;
     constexpr unsigned K  = 64;
     constexpr unsigned S  = 3;
     constexpr unsigned R  = 3;
 
     constexpr unsigned HPad = 0;
     constexpr unsigned WPad = 0;
@@ -601,7 +602,7 @@ int main()
 #endif
     (in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat);
-#elif 0
+#elif 1
     device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(
         in_nchw_desc, in_nchw, wei_kcsr_desc,
src/CMakeLists.txt

+configure_file("${PROJECT_SOURCE_DIR}/src/include/config.h.in"
+               "${PROJECT_BINARY_DIR}/src/include/config.h")
+
 set(TENSOR_SOURCE
     tensor.cpp;
+    device.cpp;
 )
 
 add_library(tensor SHARED ${TENSOR_SOURCE})
 set_target_properties(tensor PROPERTIES PREFIX "")
 target_compile_features(tensor PUBLIC)
 set_target_properties(tensor PROPERTIES POSITION_INDEPENDENT_CODE ON)
 
+if(DEVICE_BACKEND STREQUAL "CUDA")
+    target_link_libraries(device nvToolsExt cudart)
+endif()
+
 install(TARGETS tensor LIBRARY DESTINATION lib)
-
-set(DEVICE_SOURCE device.cu;)
-
-add_library(device SHARED ${DEVICE_SOURCE})
-set_target_properties(device PROPERTIES PREFIX "")
-target_compile_features(device PUBLIC)
-set_target_properties(device PROPERTIES POSITION_INDEPENDENT_CODE ON)
-target_link_libraries(device nvToolsExt cudart)
-install(TARGETS device LIBRARY DESTINATION lib)
src/device.cu → src/device.cpp

+#include "config.h"
 #include "device.hpp"
-#include "cuda_runtime.h"
-#include "nvToolsExt.h"
-#include "helper_cuda.h"
 
 DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
 {
+#if DEVICE_BACKEND_HIP
+    hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+#elif DEVICE_BACKEND_CUDA
     checkCudaErrors(cudaMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+#endif
 }
 
 void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }
 
 void DeviceMem::ToDevice(const void* p)
 {
+#if DEVICE_BACKEND_HIP
+    hipGetErrorString(
+        hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
+#elif DEVICE_BACKEND_CUDA
     checkCudaErrors(
         cudaMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, cudaMemcpyHostToDevice));
+#endif
 }
 
 void DeviceMem::FromDevice(void* p)
 {
+#if DEVICE_BACKEND_HIP
+    hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
+#elif DEVICE_BACKEND_CUDA
     checkCudaErrors(cudaMemcpy(p, mpDeviceBuf, mMemSize, cudaMemcpyDeviceToHost));
+#endif
 }
 
-DeviceMem::~DeviceMem() { checkCudaErrors(cudaFree(mpDeviceBuf)); }
+DeviceMem::~DeviceMem()
+{
+#if DEVICE_BACKEND_HIP
+    hipGetErrorString(hipFree(mpDeviceBuf));
+#elif DEVICE_BACKEND_CUDA
+    checkCudaErrors(cudaFree(mpDeviceBuf));
+#endif
+}
 
 struct KernelTimerImpl
 {
     KernelTimerImpl()
     {
+#if DEVICE_BACKEND_HIP
+        hipEventCreate(&mStart);
+        hipEventCreate(&mEnd);
+#elif DEVICE_BACKEND_CUDA
         cudaEventCreate(&mStart);
         cudaEventCreate(&mEnd);
+#endif
     }
 
     ~KernelTimerImpl()
     {
+#if DEVICE_BACKEND_HIP
+        hipEventDestroy(mStart);
+        hipEventDestroy(mEnd);
+#elif DEVICE_BACKEND_CUDA
         cudaEventDestroy(mStart);
         cudaEventDestroy(mEnd);
+#endif
     }
 
-    void Start() { cudaEventRecord(mStart, 0); }
+    void Start()
+    {
+#if DEVICE_BACKEND_HIP
+        hipEventRecord(mStart, 0);
+#elif DEVICE_BACKEND_CUDA
+        cudaEventRecord(mStart, 0);
+#endif
+    }
 
     void End()
     {
+#if DEVICE_BACKEND_HIP
+        hipEventRecord(mEnd, 0);
+        hipEventSynchronize(mEnd);
+#elif DEVICE_BACKEND_CUDA
         cudaEventRecord(mEnd, 0);
         cudaEventSynchronize(mEnd);
+#endif
     }
 
     float GetElapsedTime() const
     {
         float time;
+#if DEVICE_BACKEND_HIP
+        hipEventElapsedTime(&time, mStart, mEnd);
+#elif DEVICE_BACKEND_CUDA
         cudaEventElapsedTime(&time, mStart, mEnd);
+#endif
         return time;
     }
 
+#if DEVICE_BACKEND_HIP
+    hipEvent_t mStart, mEnd;
+#elif DEVICE_BACKEND_CUDA
     cudaEvent_t mStart, mEnd;
+#endif
 };
 
 KernelTimer::KernelTimer() : impl(new KernelTimerImpl()) {}
 
@@ -64,16 +112,3 @@ void KernelTimer::Start() { impl->Start(); }
 void KernelTimer::End() { impl->End(); }
 
 float KernelTimer::GetElapsedTime() const { return impl->GetElapsedTime(); }
 
-void launch_kernel(const void* func, dim3 grid_dim, dim3 block_dim, void** args, float& time)
-{
-    KernelTimer timer;
-
-    timer.Start();
-    cudaError_t error = cudaLaunchKernel(func, grid_dim, block_dim, args, 0, 0);
-    timer.End();
-
-    time = timer.GetElapsedTime();
-
-    checkCudaErrors(error);
-}
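
KernelTimerImpl stays private to this translation unit, so device.hpp can be included without pulling in any backend headers. A declaration consistent with the definitions above (a sketch only; the actual device.hpp is not part of this diff, and `impl` could equally be a raw pointer):

    #include <memory>

    struct KernelTimerImpl; // defined in device.cpp; holds the hip/cuda events

    struct KernelTimer
    {
        KernelTimer();
        ~KernelTimer();
        void Start();
        void End();                   // records and synchronizes the stop event
        float GetElapsedTime() const; // milliseconds between Start() and End()

        std::unique_ptr<KernelTimerImpl> impl;
    };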
src/include/blockwise_4d_tensor_op.cuh

@@ -245,10 +245,11 @@ struct BlockwiseChwnTensorCopyPadded
         constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
 
         const Float* p_src_tmp =
             p_src + src_desc.Get1dIndex(c_block_data_begin,
                                         (ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
                                         (wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
                                         n_block_data_begin);
 
 #if 0
         if(get_thread_local_1d_id() == 0)
src/include/blockwise_direct_convolution.cuh

@@ -95,10 +95,11 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
     Float p_out_thread[out_thread_desc.GetElementSpace()];
 
     threadwise_4d_tensor_copy(out_block_desc,
                               p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin,
                                                                       k_thread_data_begin,
                                                                       ho_thread_data_begin,
                                                                       wo_thread_data_begin),
                               out_thread_desc,
                               p_out_thread,
                               out_thread_desc.GetLengths());
@@ -109,10 +110,11 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
         // threadwise convolution
         threadwise_direct_convolution_2(in_thread_block_desc,
                                         p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
                                                                               c_thread_data_begin,
                                                                               hi_thread_data_begin,
                                                                               wi_thread_data_begin),
                                         wei_thread_block_desc,
                                         p_wei_block + wei_block_desc.Get1dIndex(
                                                           k_thread_data_begin,
                                                           c_thread_data_begin, 0, 0),
@@ -124,10 +126,11 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
     threadwise_4d_tensor_copy(out_thread_desc,
                               p_out_thread,
                               out_block_desc,
                               p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin,
                                                                       k_thread_data_begin,
                                                                       ho_thread_data_begin,
                                                                       wo_thread_data_begin),
                               out_thread_desc.GetLengths());
 }
src/include/blockwise_gemm.cuh

@@ -305,8 +305,9 @@ struct BlockwiseGemmBlockABlockBThreadC
         constexpr unsigned NClusterWork =
             (NPerBlock + NPerThread * NThreadPerCluster - 1) / (NPerThread * NThreadPerCluster);
 
         static_assert(BlockSize ==
                           (MClusterWork * MThreadPerCluster) * (NClusterWork * NThreadPerCluster),
                       "wrong! wrong BlockSize");
 
         if(DistributeThreadAlongColumnFirst)
@@ -685,7 +686,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
         constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
 
         // preload A, B
 #pragma unroll
         for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
         {
             // copy A-sub to form A
@@ -718,7 +719,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0;
         FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0;
 
         // preload next A, B
 #pragma unroll
         for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
         {
             // copy A-sub to form A
@@ -906,8 +907,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
                     c_thread_sub_mtx,
                     False,
                     p_c_thread + c_thread_mtx.Get1dIndex(m_repeat * MPerThreadSubC,
                                                          n_repeat * NPerThreadSubC),
                     f_accum);
             }
         }
src/include/common.cuh

@@ -62,4 +62,4 @@ struct Sequence
         printf("Sequence::ReorderByPutOldToNew not implemented");
         assert(false);
     }
-};
\ No newline at end of file
+};
src/include/config.h.in (new file, mode 100644)

#pragma once

#cmakedefine01 DEVICE_BACKEND_HIP
#cmakedefine01 DEVICE_BACKEND_CUDA

#if DEVICE_BACKEND_HIP
#include "hip/hip_runtime.h"
#elif DEVICE_BACKEND_CUDA
#include "cuda_runtime.h"
#include "nvToolsExt.h"
#include "helper_cuda.h"
#endif
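
`#cmakedefine01` always emits a definition, 1 when the CMake variable is true and 0 otherwise, which is why the sources can test the flags with plain `#if` rather than `#ifdef`. For a configure run with -D DEVICE_BACKEND="HIP", the generated config.h would begin (illustrative):

    #pragma once

    #define DEVICE_BACKEND_HIP 1
    #define DEVICE_BACKEND_CUDA 0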