Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
b2888adf
Commit
b2888adf
authored
Feb 15, 2019
by
Chao Liu
Browse files
change file extension to hip.hpp and hip.cpp
parent
a414e3fd
Changes
35
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
88 additions
and
99 deletions
+88
-99
driver/device_direct_convolution_1.hpp
driver/device_direct_convolution_1.hpp
+1
-1
driver/device_direct_convolution_2.hpp
driver/device_direct_convolution_2.hpp
+1
-1
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp
+1
-1
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp
...ice_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp
+2
-2
driver/device_implicit_gemm_convolution_1_nchw_kcsr_nkhw.hpp
driver/device_implicit_gemm_convolution_1_nchw_kcsr_nkhw.hpp
+1
-1
driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.hpp
driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.hpp
+1
-1
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.hpp
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.hpp
+2
-2
driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.hpp
driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.hpp
+2
-2
driver/driver.hip.cpp
driver/driver.hip.cpp
+11
-11
src/include/ConstantMatrixDescriptor.hip.hpp
src/include/ConstantMatrixDescriptor.hip.hpp
+1
-1
src/include/ConstantTensorDescriptor.hip.hpp
src/include/ConstantTensorDescriptor.hip.hpp
+1
-1
src/include/blockwise_2d_tensor_op.hip.hpp
src/include/blockwise_2d_tensor_op.hip.hpp
+1
-1
src/include/blockwise_4d_tensor_op.hip.hpp
src/include/blockwise_4d_tensor_op.hip.hpp
+5
-6
src/include/blockwise_direct_convolution.hip.hpp
src/include/blockwise_direct_convolution.hip.hpp
+15
-18
src/include/blockwise_gemm.hip.hpp
src/include/blockwise_gemm.hip.hpp
+5
-7
src/include/common.hip.hpp
src/include/common.hip.hpp
+0
-0
src/include/conv_common.hip.hpp
src/include/conv_common.hip.hpp
+1
-1
src/include/device.hpp
src/include/device.hpp
+2
-2
src/include/gridwise_direct_convolution_1.hip.hpp
src/include/gridwise_direct_convolution_1.hip.hpp
+13
-14
src/include/gridwise_direct_convolution_2.hip.hpp
src/include/gridwise_direct_convolution_2.hip.hpp
+22
-26
No files found.
driver/device_direct_convolution_1.
cuh
→
driver/device_direct_convolution_1.
hpp
View file @
b2888adf
#pragma once
#pragma once
#include <unistd.h>
#include <unistd.h>
#include "device.hpp"
#include "device.hpp"
#include "gridwise_direct_convolution_1.
cuh
"
#include "gridwise_direct_convolution_1.
hip.hpp
"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
void
device_direct_convolution_1
(
InDesc
,
void
device_direct_convolution_1
(
InDesc
,
...
...
driver/device_direct_convolution_2.
cuh
→
driver/device_direct_convolution_2.
hpp
View file @
b2888adf
#pragma once
#pragma once
#include <unistd.h>
#include <unistd.h>
#include "device.hpp"
#include "device.hpp"
#include "gridwise_direct_convolution_2.
cuh
"
#include "gridwise_direct_convolution_2.
hip.hpp
"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
void
device_direct_convolution_2
(
InDesc
,
void
device_direct_convolution_2
(
InDesc
,
...
...
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.
cuh
→
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.
hpp
View file @
b2888adf
#pragma once
#pragma once
#include <unistd.h>
#include <unistd.h>
#include "device.hpp"
#include "device.hpp"
#include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.
cuh
"
#include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.
hip.hpp
"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
void
device_implicit_gemm_convolution_1_chwn_csrk_khwn
(
InDesc
,
void
device_implicit_gemm_convolution_1_chwn_csrk_khwn
(
InDesc
,
...
...
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.
cuh
→
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.
hpp
View file @
b2888adf
#pragma once
#pragma once
#include <unistd.h>
#include <unistd.h>
#include "device.hpp"
#include "device.hpp"
#include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.
cuh
"
#include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.
hip.hpp
"
#include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_pipeline.
cuh
"
#include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_pipeline.
hip.hpp
"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
,
class
LowerPads
,
class
UpperPads
>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
,
class
LowerPads
,
class
UpperPads
>
void
device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded
(
InDesc
,
void
device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded
(
InDesc
,
...
...
driver/device_implicit_gemm_convolution_1_nchw_kcsr_nkhw.
cuh
→
driver/device_implicit_gemm_convolution_1_nchw_kcsr_nkhw.
hpp
View file @
b2888adf
#pragma once
#pragma once
#include <unistd.h>
#include <unistd.h>
#include "device.hpp"
#include "device.hpp"
#include "gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw.
cuh
"
#include "gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw.
hip.hpp
"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
void
device_implicit_gemm_convolution_1_nchw_kcsr_nkhw
(
InDesc
,
void
device_implicit_gemm_convolution_1_nchw_kcsr_nkhw
(
InDesc
,
...
...
driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.
cuh
→
driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.
hpp
View file @
b2888adf
#pragma once
#pragma once
#include <unistd.h>
#include <unistd.h>
#include "device.hpp"
#include "device.hpp"
#include "gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.
cuh
"
#include "gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.
hip.hpp
"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
void
device_implicit_gemm_convolution_1_nchw_srck_nkhw
(
InDesc
,
void
device_implicit_gemm_convolution_1_nchw_srck_nkhw
(
InDesc
,
...
...
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.
cuh
→
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.
hpp
View file @
b2888adf
#pragma once
#pragma once
#include <unistd.h>
#include <unistd.h>
#include "device.hpp"
#include "device.hpp"
#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.
cuh
"
#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.
hip.hpp
"
#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.
cuh
"
#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.
hip.hpp
"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
void
device_implicit_gemm_convolution_2_cnhw_csrk_knhw
(
InDesc
,
void
device_implicit_gemm_convolution_2_cnhw_csrk_knhw
(
InDesc
,
...
...
driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.
cuh
→
driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.
hpp
View file @
b2888adf
#pragma once
#pragma once
#include <unistd.h>
#include <unistd.h>
#include "device.hpp"
#include "device.hpp"
#include "gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.
cuh
"
#include "gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.
hip.hpp
"
#include "gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.
cuh
"
#include "gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.
hip.hpp
"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
void
device_implicit_gemm_convolution_2_cnhw_srck_knhw
(
InDesc
,
void
device_implicit_gemm_convolution_2_cnhw_srck_knhw
(
InDesc
,
...
...
driver/driver.hip.cpp
View file @
b2888adf
...
@@ -4,17 +4,17 @@
...
@@ -4,17 +4,17 @@
#include <cstdlib>
#include <cstdlib>
#include "config.h"
#include "config.h"
#include "tensor.hpp"
#include "tensor.hpp"
#include "ConstantTensorDescriptor.
cuh
"
#include "ConstantTensorDescriptor.
hip.hpp
"
#include "conv_common.
cuh
"
#include "conv_common.
hip.hpp
"
#include "device_direct_convolution_1.
cuh
"
#include "device_direct_convolution_1.
hpp
"
#include "device_direct_convolution_2.
cuh
"
#include "device_direct_convolution_2.
hpp
"
#include "device_implicit_gemm_convolution_1_nchw_kcsr_nkhw.
cuh
"
#include "device_implicit_gemm_convolution_1_nchw_kcsr_nkhw.
hpp
"
#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.
cuh
"
#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.
hpp
"
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn.
cuh
"
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn.
hpp
"
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.
cuh
"
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.
hpp
"
#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.
cuh
"
#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.
hpp
"
#include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw.
cuh
"
#include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw.
hpp
"
//#include "device_winograd_convolution.
cuh
"
//#include "device_winograd_convolution.
hip.hpp
"
struct
GeneratorTensor_1
struct
GeneratorTensor_1
{
{
...
...
src/include/ConstantMatrixDescriptor.
cuh
→
src/include/ConstantMatrixDescriptor.
hip.hpp
View file @
b2888adf
#pragma once
#pragma once
#include "common.
cuh
"
#include "common.
hip.hpp
"
template
<
unsigned
NRow_
,
unsigned
NCol_
,
unsigned
RowStride_
>
template
<
unsigned
NRow_
,
unsigned
NCol_
,
unsigned
RowStride_
>
struct
ConstantMatrixDescriptor
struct
ConstantMatrixDescriptor
...
...
src/include/ConstantTensorDescriptor.
cuh
→
src/include/ConstantTensorDescriptor.
hip.hpp
View file @
b2888adf
#pragma once
#pragma once
#include "common.
cuh
"
#include "common.
hip.hpp
"
// this is ugly, only for 2d
// this is ugly, only for 2d
template
<
unsigned
L0
,
unsigned
L1
>
template
<
unsigned
L0
,
unsigned
L1
>
...
...
src/include/blockwise_2d_tensor_op.
cuh
→
src/include/blockwise_2d_tensor_op.
hip.hpp
View file @
b2888adf
#pragma once
#pragma once
#include "ConstantTensorDescriptor.
cuh
"
#include "ConstantTensorDescriptor.
hip.hpp
"
template
<
unsigned
BlockSize
,
class
Float
,
class
DstDesc
,
class
F
>
template
<
unsigned
BlockSize
,
class
Float
,
class
DstDesc
,
class
F
>
__device__
void
__device__
void
...
...
src/include/blockwise_4d_tensor_op.
cuh
→
src/include/blockwise_4d_tensor_op.
hip.hpp
View file @
b2888adf
#pragma once
#pragma once
#include "ConstantTensorDescriptor.
cuh
"
#include "ConstantTensorDescriptor.
hip.hpp
"
template
<
unsigned
BlockSize
,
class
Float
,
class
DstDesc
,
class
F
>
template
<
unsigned
BlockSize
,
class
Float
,
class
DstDesc
,
class
F
>
__device__
void
__device__
void
...
@@ -245,8 +245,7 @@ struct BlockwiseChwnTensorCopyPadded
...
@@ -245,8 +245,7 @@ struct BlockwiseChwnTensorCopyPadded
constexpr
unsigned
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
constexpr
unsigned
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
const
Float
*
p_src_tmp
=
const
Float
*
p_src_tmp
=
p_src
+
p_src
+
src_desc
.
Get1dIndex
(
c_block_data_begin
,
src_desc
.
Get1dIndex
(
c_block_data_begin
,
(
ho_block_data_begin
+
h_block_pad_low
)
-
h_global_pad_low
,
(
ho_block_data_begin
+
h_block_pad_low
)
-
h_global_pad_low
,
(
wo_block_data_begin
+
w_block_pad_low
)
-
w_global_pad_low
,
(
wo_block_data_begin
+
w_block_pad_low
)
-
w_global_pad_low
,
n_block_data_begin
);
n_block_data_begin
);
...
...
src/include/blockwise_direct_convolution.
cuh
→
src/include/blockwise_direct_convolution.
hip.hpp
View file @
b2888adf
#pragma once
#pragma once
#include "ConstantTensorDescriptor.
cuh
"
#include "ConstantTensorDescriptor.
hip.hpp
"
#include "threadwise_4d_tensor_op.
cuh
"
#include "threadwise_4d_tensor_op.
hip.hpp
"
#include "threadwise_direct_convolution.
cuh
"
#include "threadwise_direct_convolution.
hip.hpp
"
template
<
unsigned
BlockSize
,
template
<
unsigned
BlockSize
,
class
Float
,
class
Float
,
...
@@ -95,8 +95,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
...
@@ -95,8 +95,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
Float
p_out_thread
[
out_thread_desc
.
GetElementSpace
()];
Float
p_out_thread
[
out_thread_desc
.
GetElementSpace
()];
threadwise_4d_tensor_copy
(
out_block_desc
,
threadwise_4d_tensor_copy
(
out_block_desc
,
p_out_block
+
p_out_block
+
out_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
out_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
k_thread_data_begin
,
k_thread_data_begin
,
ho_thread_data_begin
,
ho_thread_data_begin
,
wo_thread_data_begin
),
wo_thread_data_begin
),
...
@@ -110,8 +109,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
...
@@ -110,8 +109,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
// threadwise convolution
// threadwise convolution
threadwise_direct_convolution_2
(
threadwise_direct_convolution_2
(
in_thread_block_desc
,
in_thread_block_desc
,
p_in_block
+
p_in_block
+
in_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
in_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
c_thread_data_begin
,
c_thread_data_begin
,
hi_thread_data_begin
,
hi_thread_data_begin
,
wi_thread_data_begin
),
wi_thread_data_begin
),
...
@@ -126,8 +124,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
...
@@ -126,8 +124,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
threadwise_4d_tensor_copy
(
out_thread_desc
,
threadwise_4d_tensor_copy
(
out_thread_desc
,
p_out_thread
,
p_out_thread
,
out_block_desc
,
out_block_desc
,
p_out_block
+
p_out_block
+
out_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
out_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
k_thread_data_begin
,
k_thread_data_begin
,
ho_thread_data_begin
,
ho_thread_data_begin
,
wo_thread_data_begin
),
wo_thread_data_begin
),
...
...
src/include/blockwise_gemm.
cuh
→
src/include/blockwise_gemm.
hip.hpp
View file @
b2888adf
#pragma once
#pragma once
#include "threadwise_gemm.
cuh
"
#include "threadwise_gemm.
hip.hpp
"
template
<
unsigned
BlockSize
,
template
<
unsigned
BlockSize
,
class
BlockMatrixA
,
class
BlockMatrixA
,
...
@@ -305,8 +305,7 @@ struct BlockwiseGemmBlockABlockBThreadC
...
@@ -305,8 +305,7 @@ struct BlockwiseGemmBlockABlockBThreadC
constexpr
unsigned
NClusterWork
=
constexpr
unsigned
NClusterWork
=
(
NPerBlock
+
NPerThread
*
NThreadPerCluster
-
1
)
/
(
NPerThread
*
NThreadPerCluster
);
(
NPerBlock
+
NPerThread
*
NThreadPerCluster
-
1
)
/
(
NPerThread
*
NThreadPerCluster
);
static_assert
(
BlockSize
==
static_assert
(
BlockSize
==
(
MClusterWork
*
MThreadPerCluster
)
*
(
MClusterWork
*
MThreadPerCluster
)
*
(
NClusterWork
*
NThreadPerCluster
),
(
NClusterWork
*
NThreadPerCluster
),
"wrong! wrong BlockSize"
);
"wrong! wrong BlockSize"
);
...
@@ -907,8 +906,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
...
@@ -907,8 +906,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
p_b_thread
+
b_thread_mtx
.
Get1dIndex
(
0
,
n_repeat
*
NPerThreadSubC
),
p_b_thread
+
b_thread_mtx
.
Get1dIndex
(
0
,
n_repeat
*
NPerThreadSubC
),
c_thread_sub_mtx
,
c_thread_sub_mtx
,
False
,
False
,
p_c_thread
+
p_c_thread
+
c_thread_mtx
.
Get1dIndex
(
m_repeat
*
MPerThreadSubC
,
c_thread_mtx
.
Get1dIndex
(
m_repeat
*
MPerThreadSubC
,
n_repeat
*
NPerThreadSubC
),
n_repeat
*
NPerThreadSubC
),
f_accum
);
f_accum
);
}
}
...
...
src/include/common.
cuh
→
src/include/common.
hip.hpp
View file @
b2888adf
File moved
src/include/conv_common.
cuh
→
src/include/conv_common.
hip.hpp
View file @
b2888adf
#pragma once
#pragma once
#include "ConstantTensorDescriptor.
cuh
"
#include "ConstantTensorDescriptor.
hip.hpp
"
// this is ugly, only for 4d
// this is ugly, only for 4d
template
<
class
InDesc
,
class
WeiDesc
>
template
<
class
InDesc
,
class
WeiDesc
>
...
...
src/include/device.hpp
View file @
b2888adf
src/include/gridwise_direct_convolution_1.
cuh
→
src/include/gridwise_direct_convolution_1.
hip.hpp
View file @
b2888adf
#pragma once
#pragma once
#include "common.
cuh
"
#include "common.
hip.hpp
"
#include "ConstantTensorDescriptor.
cuh
"
#include "ConstantTensorDescriptor.
hip.hpp
"
#include "blockwise_4d_tensor_op.
cuh
"
#include "blockwise_4d_tensor_op.
hip.hpp
"
#include "blockwise_direct_convolution.
cuh
"
#include "blockwise_direct_convolution.
hip.hpp
"
template
<
class
Float
,
template
<
class
Float
,
class
InGlobalDesc
,
class
InGlobalDesc
,
...
@@ -147,8 +147,7 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
...
@@ -147,8 +147,7 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
c_block_work_begin
+=
CPerBlock
)
c_block_work_begin
+=
CPerBlock
)
{
{
// copy input tensor to LDS
// copy input tensor to LDS
blockwise_in_copy
.
Run
(
p_in_global
+
blockwise_in_copy
.
Run
(
p_in_global
+
in_global_desc
.
Get1dIndex
(
n_block_work_begin
,
in_global_desc
.
Get1dIndex
(
n_block_work_begin
,
c_block_work_begin
,
c_block_work_begin
,
hi_block_work_begin
,
hi_block_work_begin
,
wi_block_work_begin
),
wi_block_work_begin
),
...
@@ -178,9 +177,9 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
...
@@ -178,9 +177,9 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
}
}
// copy output tensor from LDS to device mem
// copy output tensor from LDS to device mem
blockwise_out_copy
.
Run
(
blockwise_out_copy
.
Run
(
p_out_block
,
p_out_block
,
p_out_global
+
out_global_desc
.
Get1dIndex
(
n_block_work_begin
,
p_out_global
+
k_block_work_begin
,
out_global_desc
.
Get1dIndex
(
ho_block_work_begin
,
n_block_work_begin
,
k_block_work_begin
,
ho_block_work_begin
,
wo_block_work_begin
));
wo_block_work_begin
));
}
}
src/include/gridwise_direct_convolution_2.
cuh
→
src/include/gridwise_direct_convolution_2.
hip.hpp
View file @
b2888adf
#pragma once
#pragma once
#include "common.
cuh
"
#include "common.
hip.hpp
"
#include "ConstantTensorDescriptor.
cuh
"
#include "ConstantTensorDescriptor.
hip.hpp
"
#include "blockwise_4d_tensor_op.
cuh
"
#include "blockwise_4d_tensor_op.
hip.hpp
"
#include "blockwise_direct_convolution.
cuh
"
#include "blockwise_direct_convolution.
hip.hpp
"
#include "threadwise_4d_tensor_op.
cuh
"
#include "threadwise_4d_tensor_op.
hip.hpp
"
#include "threadwise_direct_convolution.
cuh
"
#include "threadwise_direct_convolution.
hip.hpp
"
template
<
class
Float
,
template
<
class
Float
,
class
InGlobalDesc
,
class
InGlobalDesc
,
...
@@ -163,8 +163,7 @@ __global__ void gridwise_direct_convolution_2(const Float* const __restrict__ p_
...
@@ -163,8 +163,7 @@ __global__ void gridwise_direct_convolution_2(const Float* const __restrict__ p_
c_block_data_begin
+=
CPerBlock
,
__syncthreads
())
c_block_data_begin
+=
CPerBlock
,
__syncthreads
())
{
{
// copy input tensor to LDS
// copy input tensor to LDS
blockwise_in_copy
.
Run
(
p_in_global
+
blockwise_in_copy
.
Run
(
p_in_global
+
in_global_desc
.
Get1dIndex
(
n_block_data_begin
,
in_global_desc
.
Get1dIndex
(
n_block_data_begin
,
c_block_data_begin
,
c_block_data_begin
,
hi_block_data_begin
,
hi_block_data_begin
,
wi_block_data_begin
),
wi_block_data_begin
),
...
@@ -183,8 +182,7 @@ __global__ void gridwise_direct_convolution_2(const Float* const __restrict__ p_
...
@@ -183,8 +182,7 @@ __global__ void gridwise_direct_convolution_2(const Float* const __restrict__ p_
#if 1
#if 1
threadwise_direct_convolution_2
(
threadwise_direct_convolution_2
(
in_thread_block_desc
,
in_thread_block_desc
,
p_in_block
+
p_in_block
+
in_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
in_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
c_thread_data
,
c_thread_data
,
hi_thread_data_begin
,
hi_thread_data_begin
,
wi_thread_data_begin
),
wi_thread_data_begin
),
...
@@ -195,8 +193,7 @@ __global__ void gridwise_direct_convolution_2(const Float* const __restrict__ p_
...
@@ -195,8 +193,7 @@ __global__ void gridwise_direct_convolution_2(const Float* const __restrict__ p_
#elif 0
#elif 0
threadwise_direct_convolution_3
(
threadwise_direct_convolution_3
(
in_thread_block_desc
,
in_thread_block_desc
,
p_in_block
+
p_in_block
+
in_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
in_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
c_thread_data
,
c_thread_data
,
hi_thread_data_begin
,
hi_thread_data_begin
,
wi_thread_data_begin
),
wi_thread_data_begin
),
...
@@ -213,8 +210,7 @@ __global__ void gridwise_direct_convolution_2(const Float* const __restrict__ p_
...
@@ -213,8 +210,7 @@ __global__ void gridwise_direct_convolution_2(const Float* const __restrict__ p_
out_thread_desc
,
out_thread_desc
,
p_out_thread
,
p_out_thread
,
out_global_desc
,
out_global_desc
,
p_out_global
+
p_out_global
+
out_global_desc
.
Get1dIndex
(
n_block_data_begin
+
n_thread_data_begin
,
out_global_desc
.
Get1dIndex
(
n_block_data_begin
+
n_thread_data_begin
,
k_block_data_begin
+
k_thread_data_begin
,
k_block_data_begin
+
k_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
),
wo_block_data_begin
+
wo_thread_data_begin
),
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment