Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
28325204
"...resnet50_tensorflow.git" did not exist on "589ac399f0777c0c9e71a8a9bdb2b019cefeb536"
Commit
28325204
authored
Mar 18, 2019
by
Chao Liu
Browse files
adding fp16 direct that reads pre-vectorized data
parent
4f0fc72e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
23 deletions
+23
-23
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp
...se_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp
+23
-23
No files found.
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp
View file @
28325204
...
@@ -27,8 +27,8 @@ template <class Float,
...
@@ -27,8 +27,8 @@ template <class Float,
unsigned
BlockSize
,
unsigned
BlockSize
,
unsigned
GridSize
>
unsigned
GridSize
>
__global__
void
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw
(
__global__
void
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw
(
const
typename
vector_type
<
Float
,
ScalarPerVector
>::
VectorType
*
const
__restrict__
p_in_global
,
const
typename
vector_type
<
Float
,
ScalarPerVector
>::
VectorType
*
const
__restrict__
p_in_
vec_
global
,
const
typename
vector_type
<
Float
,
ScalarPerVector
>::
VectorType
*
const
__restrict__
p_wei_global
,
const
typename
vector_type
<
Float
,
ScalarPerVector
>::
VectorType
*
const
__restrict__
p_wei_
vec_
global
,
Float
*
const
__restrict__
p_out_global
)
Float
*
const
__restrict__
p_out_global
)
{
{
using
scalar_t
=
Float
;
using
scalar_t
=
Float
;
...
@@ -76,25 +76,25 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
...
@@ -76,25 +76,25 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
?
InBlockCopyDataPerRead
?
InBlockCopyDataPerRead
:
WeiBlockCopyDataPerRead
;
:
WeiBlockCopyDataPerRead
;
__shared__
Floa
t
p_in_block
[
max_align
*
((
in_block_size
+
max_align
-
1
)
/
max_align
)];
__shared__
vector_
t
p_in_
vec_
block
[
max_align
*
((
in_block_size
+
max_align
-
1
)
/
max_align
)];
__shared__
Floa
t
p_wei_block
[
max_align
*
((
wei_block_size
+
max_align
-
1
)
/
max_align
)];
__shared__
vector_
t
p_wei_
vec_
block
[
max_align
*
((
wei_block_size
+
max_align
-
1
)
/
max_align
)];
// threadwise tensors
// threadwise tensors
constexpr
unsigned
HiPerThread
=
HoPerThread
+
Y
-
1
;
constexpr
unsigned
HiPerThread
=
HoPerThread
+
Y
-
1
;
constexpr
unsigned
WiPerThread
=
WoPerThread
+
X
-
1
;
constexpr
unsigned
WiPerThread
=
WoPerThread
+
X
-
1
;
constexpr
auto
in_nchw_thread_block_desc
=
constexpr
auto
in_nchw_
vec_
thread_block_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
NPerThread
,
CPerThread
,
HiPerThread
,
WiPerThread
>
{},
make_ConstantTensorDescriptor
(
Sequence
<
NPerThread
,
CPerThread
,
HiPerThread
,
WiPerThread
>
{},
in_nchw_vec_block_desc
.
GetStrides
());
in_nchw_vec_block_desc
.
GetStrides
());
constexpr
auto
wei_kcyx_thread_block_desc
=
make_ConstantTensorDescriptor
(
constexpr
auto
wei_kcyx_
vec_
thread_block_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
,
CPerThread
,
Y
,
X
>
{},
wei_kcyx_vec_block_desc
.
GetStrides
());
Sequence
<
KPerThread
,
CPerThread
,
Y
,
X
>
{},
wei_kcyx_vec_block_desc
.
GetStrides
());
constexpr
auto
out_nkhw_thread_desc
=
get_convolution_output_default_4d_tensor_descriptor
(
constexpr
auto
out_nkhw_thread_desc
=
get_convolution_output_default_4d_tensor_descriptor
(
in_nchw_thread_block_desc
,
wei_kcyx_thread_block_desc
);
in_nchw_
vec_
thread_block_desc
,
wei_kcyx_
vec_
thread_block_desc
);
// register
// register
Floa
t
p_out_thread
[
out_nkhw_thread_desc
.
GetElementSpace
()];
scalar_
t
p_out_thread
[
out_nkhw_thread_desc
.
GetElementSpace
()];
// divide block work
// divide block work
constexpr
unsigned
NBlockWork
=
constexpr
unsigned
NBlockWork
=
...
@@ -150,7 +150,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
...
@@ -150,7 +150,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
constexpr
auto
blockwise_in_copy
=
constexpr
auto
blockwise_in_copy
=
Blockwise4dTensorCopy1
<
BlockSize
,
Blockwise4dTensorCopy1
<
BlockSize
,
Floa
t
,
vector_
t
,
decltype
(
in_nchw_vec_global_desc
),
decltype
(
in_nchw_vec_global_desc
),
decltype
(
in_nchw_vec_block_desc
),
decltype
(
in_nchw_vec_block_desc
),
decltype
(
in_nchw_vec_block_desc
.
GetLengths
()),
decltype
(
in_nchw_vec_block_desc
.
GetLengths
()),
...
@@ -159,7 +159,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
...
@@ -159,7 +159,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
#if 0
#if 0
constexpr auto blockwise_wei_copy =
constexpr auto blockwise_wei_copy =
Blockwise4dTensorCopy1<BlockSize,
Blockwise4dTensorCopy1<BlockSize,
Floa
t,
vector_
t,
decltype(wei_kcyx_vec_global_desc),
decltype(wei_kcyx_vec_global_desc),
decltype(wei_kcyx_vec_block_desc),
decltype(wei_kcyx_vec_block_desc),
decltype(wei_kcyx_vec_block_desc.GetLengths()),
decltype(wei_kcyx_vec_block_desc.GetLengths()),
...
@@ -167,7 +167,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
...
@@ -167,7 +167,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
#elif
1
#elif
1
const
auto
blockwise_wei_copy
=
const
auto
blockwise_wei_copy
=
Blockwise2dTensorCopy3
<
BlockSize
,
Blockwise2dTensorCopy3
<
BlockSize
,
Floa
t
,
vector_
t
,
decltype
(
wei_ke_vec_global_desc
),
decltype
(
wei_ke_vec_global_desc
),
decltype
(
wei_ke_vec_block_desc
),
decltype
(
wei_ke_vec_block_desc
),
decltype
(
wei_ke_vec_block_desc
.
GetLengths
()),
decltype
(
wei_ke_vec_block_desc
.
GetLengths
()),
...
@@ -181,16 +181,16 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
...
@@ -181,16 +181,16 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
c_block_data_begin
+=
CPerBlock
,
__syncthreads
())
c_block_data_begin
+=
CPerBlock
,
__syncthreads
())
{
{
// copy input tensor to LDS
// copy input tensor to LDS
blockwise_in_copy
.
Run
(
p_in_global
+
in_nchw_vec_global_desc
.
Get1dIndex
(
n_block_data_begin
,
blockwise_in_copy
.
Run
(
p_in_
vec_
global
+
in_nchw_vec_global_desc
.
Get1dIndex
(
n_block_data_begin
,
c_block_data_begin
,
c_block_data_begin
,
hi_block_data_begin
,
hi_block_data_begin
,
wi_block_data_begin
),
wi_block_data_begin
),
p_in_block
);
p_in_
vec_
block
);
// copy weight tensor to LDS
// copy weight tensor to LDS
blockwise_wei_copy
.
Run
(
p_wei_global
+
wei_kcyx_vec_global_desc
.
Get1dIndex
(
blockwise_wei_copy
.
Run
(
p_wei_
vec_
global
+
wei_kcyx_vec_global_desc
.
Get1dIndex
(
k_block_data_begin
,
c_block_data_begin
,
0
,
0
),
k_block_data_begin
,
c_block_data_begin
,
0
,
0
),
p_wei_block
);
p_wei_
vec_
block
);
__syncthreads
();
__syncthreads
();
...
@@ -199,25 +199,25 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
...
@@ -199,25 +199,25 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
// threadwise convolution
// threadwise convolution
#if 1
#if 1
threadwise_direct_convolution_2
(
threadwise_direct_convolution_2
(
in_nchw_thread_block_desc
,
in_nchw_
vec_
thread_block_desc
,
p_in_block
+
in_nchw_vec_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
p_in_
vec_
block
+
in_nchw_vec_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
c_thread_data
,
c_thread_data
,
hi_thread_data_begin
,
hi_thread_data_begin
,
wi_thread_data_begin
),
wi_thread_data_begin
),
wei_kcyx_thread_block_desc
,
wei_kcyx_
vec_
thread_block_desc
,
p_wei_block
+
p_wei_
vec_
block
+
wei_kcyx_vec_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data
,
0
,
0
),
wei_kcyx_vec_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data
,
0
,
0
),
out_nkhw_thread_desc
,
out_nkhw_thread_desc
,
p_out_thread
);
p_out_thread
);
#elif 0
#elif 0
threadwise_direct_convolution_3
(
threadwise_direct_convolution_3
(
in_nchw_thread_block_desc
,
in_nchw_
vec_
thread_block_desc
,
p_in_block
+
in_nchw_vec_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
p_in_
vec_
block
+
in_nchw_vec_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
c_thread_data
,
c_thread_data
,
hi_thread_data_begin
,
hi_thread_data_begin
,
wi_thread_data_begin
),
wi_thread_data_begin
),
wei_kcyx_thread_block_desc
,
wei_kcyx_
vec_
thread_block_desc
,
p_wei_block
+
p_wei_
vec_
block
+
wei_kcyx_vec_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data
,
0
,
0
),
wei_kcyx_vec_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data
,
0
,
0
),
out_nkhw_thread_desc
,
out_nkhw_thread_desc
,
p_out_thread
);
p_out_thread
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment