gaoqiong / composable_kernel · Commits

Commit 8a4b5978, authored May 22, 2019 by Chao Liu
    adding implicit gemm v3
Parent: 2a48812e

Showing 6 changed files with 103 additions and 62 deletions:
  src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp   +6  -6
  src/include/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hip.hpp                       +73 -42
  src/include/integral_constant.hip.hpp                                                          +6  -0
  src/include/threadwise_gemm.hip.hpp                                                            +1  -1
  src/include/threadwise_tensor_slice_op.hip.hpp                                                 +16 -12
  src/include/vector_type.hip.hpp                                                                +1  -1
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp  (View file @ 8a4b5978)
@@ -86,7 +86,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
         constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
         constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);

-        constexpr auto block_work_desc = make_packed_ConstantTensorDescriptor(
+        constexpr auto block_work_desc = make_ConstantTensorDescriptor_default_rank_packed(
             Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});

         const auto block_work_multi_id =
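As context for the hunk above: mod_conv::integer_divide_ceil is a rounded-up integer division, so the output H/W extents are fully covered by per-block tiles even when they are not exact multiples of HoPerBlock/WoPerBlock. A minimal standalone sketch of that arithmetic (plain C++, not the library's implementation; index_t is assumed to be an unsigned integer alias as in the repo):

    #include <cstdint>

    using index_t = std::uint32_t; // assumption: index_t is an unsigned integer alias

    // rounded-up integer division: number of HoPerBlock-sized tiles needed to cover Ho
    constexpr index_t integer_divide_ceil(index_t x, index_t y) { return (x + y - 1) / y; }

    static_assert(integer_divide_ceil(28, 8) == 4, "28 rows need 4 tiles of 8");
    static_assert(integer_divide_ceil(32, 8) == 4, "exact multiples are not over-counted");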
@@ -110,7 +110,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
                                                       GemmDataPerReadA,
                                                       GemmDataPerReadB);

-        constexpr auto in_c_h_w_n_block_desc = make_ranked_ConstantTensorDescriptor_with_alignment(
+        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_default_rank_aligned(
             Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
             Number<InBlockReorderDataPerWrite_N>{});
@@ -119,12 +119,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
         static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                       "GemmDataPerReadB alignment requirement is not meet");

-        constexpr auto wei_c_k_block_desc = make_ranked_ConstantTensorDescriptor_with_alignment(
+        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_default_rank_aligned(
             Sequence<CPerBlock, KPerBlock>{},
             Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

         // tensor view of threadwise output in register
-        constexpr auto out_k_h_w_n_thread_desc = make_packed_ConstantTensorDescriptor(
+        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_default_rank_packed(
             Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

         // blockwise copy
@@ -152,7 +152,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
                                                decltype(wei_c_k_global_desc),
                                                decltype(wei_c_k_block_desc),
                                                decltype(wei_c_k_block_desc.GetLengths()),
-                                               WeiBlockCopyDataPerRead_K>{};
+                                               WeiBlockCopyDataPerRead_K>({0, 0}, {0, 0});

         // a series of blockwise batched GEMM
         // C_matrix += transpose(A_matrix) * B_matrix
@@ -196,7 +196,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
         // choose GEMM implementation here
         const auto run_blockwise_batch_gemm = [&](auto... Xs) {
-#if 0
+#if 1
             return blockwise_batch_gemm.Run(Xs...);
 #elif 0
             return blockwise_batch_gemm.Run_asm(Xs...);
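The #if 0 → #if 1 flip above re-selects the generic Run path over the alternative implementations. The selection is routed through a variadic forwarding lambda; a self-contained sketch of that pattern (the stub class and empty bodies are placeholders, not the repo's blockwise batched GEMM):

    // Sketch of the implementation-chooser pattern: a generic lambda forwards its
    // arguments to whichever member function the preprocessor selects.
    struct BlockwiseBatchGemmStub
    {
        template <class... Xs> void Run(Xs...) {}     // generic implementation
        template <class... Xs> void Run_asm(Xs...) {} // hand-tuned alternative
    };

    void choose_gemm_example()
    {
        BlockwiseBatchGemmStub blockwise_batch_gemm;

        const auto run_blockwise_batch_gemm = [&](auto... Xs) {
    #if 1
            return blockwise_batch_gemm.Run(Xs...);
    #else
            return blockwise_batch_gemm.Run_asm(Xs...);
    #endif
        };

        run_blockwise_batch_gemm(1, 2.0f); // any argument pack is forwarded unchanged
    }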
src/include/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hip.hpp  (View file @ 8a4b5978)

 #pragma once
 #include "common.hip.hpp"
 #include "ConstantTensorDescriptor.hip.hpp"
+#include "ConstantMergedTensorDescriptor.hip.hpp"
 #include "ConstantMatrixDescriptor.hip.hpp"
+#include "blockwise_merged_tensor_slice_op.hip.hpp"
 #include "blockwise_gemm.hip.hpp"
+#include "threadwise_tensor_slice_op.hip.hpp"

 // define B = merge(N, Ho, Wo)
 template <index_t GridSize,
@@ -24,7 +27,12 @@ template <index_t GridSize,
...
@@ -24,7 +27,12 @@ template <index_t GridSize,
index_t
GemmNLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmKPerThreadLoop
,
index_t
GemmKPerThreadLoop
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadB
>
index_t
GemmDataPerReadB
,
class
InBlockCopySubLengths_N1_N2_C_B
,
class
InBlockCopyClusterLengths_N1_N2_C_B
,
index_t
InBlockCopySrcDataPerRead_B
,
index_t
InBlockCopyDstDataPerWrite_N2
,
index_t
WeiBlockCopyDataPerAccess_K
>
struct
GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
struct
GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
{
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
...
@@ -34,12 +42,10 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -34,12 +42,10 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
// this is a mess
// this is a mess
// TODO: more elegent way of specifying (or calculating) performance variables
// TODO: more elegent way of specifying (or calculating) performance variables
static_assert
(
N2
==
GemmNPerThreadSubC
,
"wrong!"
);
static_assert
(
N2
==
GemmNPerThreadSubC
,
"wrong!"
);
static_assert
(
KPerBlock
==
static_assert
((
N1
*
N2
*
BPerBlock
)
%
N1
*
GemmNPerThreadSubC
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
,
(
GemmNPerThreadSubC
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
)
==
0
,
"wrong!"
);
"wrong!"
);
static_assert
(
KPerBlock
%
(
N1
*
GemmNPerThreadSubC
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
)
==
0
,
"wrong!"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
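These asserts tie the block tile to the thread-cluster decomposition of the GEMM: the number of output columns a block owns (N1 * N2 * BPerBlock) has to split evenly over GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster. A tiny compile-time check in the same spirit, with illustrative numbers that are not taken from the commit:

    // assumed example tile / cluster sizes, chosen only so the check passes
    constexpr unsigned N1 = 2, N2 = 4, BPerBlock = 16;
    constexpr unsigned GemmNPerThreadSubC = 4, GemmNLevel0Cluster = 4, GemmNLevel1Cluster = 8;

    static_assert((N1 * N2 * BPerBlock) %
                      (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
                  "block tile along the GEMM N dimension must be divisible by the thread cluster");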
@@ -73,15 +79,14 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
         constexpr index_t B = N0 * Ho * Wo;

         // divide block work by [K, B]
-        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0,
-                      C % CPerBlock == 0,
-                      "wrong! cannot divide work evenly among block");
+        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && C % CPerBlock == 0,
+                      "wrong! cannot divide work evenly among block");

         constexpr index_t KBlockWork = K / KPerBlock;
         constexpr index_t BBlockWork = B / BPerBlock;

         constexpr auto block_work_desc =
-            make_ConstantTensorDescriptor(Sequence<KBlockWork, BBlockWork>{});
+            make_ConstantTensorDescriptor_default_rank_packed(Sequence<KBlockWork, BBlockWork>{});

         const auto block_work_multi_id =
             block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
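GetMultiIndexFrom1dIndex maps the flat get_block_1d_id() onto the [KBlockWork, BBlockWork] work grid declared just above. For a packed row-major 2-D descriptor that mapping is a divide/modulo pair; a hypothetical standalone version (not the library code):

    #include <array>

    using index_t = unsigned;

    // decompose a flat block id into [k_work_id, b_work_id] for a row-major [KBlockWork, BBlockWork] grid
    constexpr std::array<index_t, 2> multi_index_from_1d(index_t id, index_t BBlockWork)
    {
        return {id / BBlockWork, id % BBlockWork};
    }

    constexpr std::array<index_t, 2> work_id = multi_index_from_1d(7, 4);
    static_assert(work_id[0] == 1 && work_id[1] == 3, "block 7 maps to [1, 3] when BBlockWork == 4");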
@@ -95,16 +100,20 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
             in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{});

         // merged tensor descriptor in device memory [N1, N2, C, B], src of blockwise copy
         constexpr auto in_n1_n2_c_b_global_merged_desc =
-            in_n0_n1_n2_c_h_w_global_desc.ReorderGivenNew2Old(Sequence<1, 2, 3, 0, 4, 5>{})
-                .Slice(I4, Number<Ho>{})
-                .Slice(I5, Number<Wo>{})
-                .Merge(I3, I5);
+            make_ConstantMergedTensorDescriptor(
+                in_n0_n1_n2_c_h_w_global_mem_desc.ReorderGivenNew2Old(Sequence<1, 2, 3, 0, 4, 5>{})
+                    .Slice(I4, Number<Ho>{})
+                    .Slice(I5, Number<Wo>{}),
+                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{});

         // memory layout descriptor in LDS [C, N1, B, N2]
         // be careful of LDS alignment
-        constexpr auto in_c_n1_b_n2_block_mem_desc = make_ConstantTensorDescriptor_aligned(
-            Sequence<CPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
+        constexpr auto in_c_n1_b_n2_block_mem_desc =
+            make_ConstantTensorDescriptor_default_rank_aligned(
+                Sequence<CPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});

         // tensor descriptor in LDS [N1, N2, C, B], dst of blockwise copy
         constexpr auto in_n1_n2_c_b_block_desc =
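The merged descriptor treats B = merge(N0, Ho, Wo) as a single dimension, which means a linear coordinate along B has to be split back into (n0, ho, wo) before the original strides can be applied. A rough sketch of that index arithmetic (plain C++, illustrative only; the real ConstantMergedTensorDescriptor does the equivalent at compile time):

    using index_t = unsigned;

    struct MergedIndex { index_t n0, ho, wo; };

    // split a linear index b in [0, N0*Ho*Wo) into its (n0, ho, wo) components
    constexpr MergedIndex split_b(index_t b, index_t Ho, index_t Wo)
    {
        return {b / (Ho * Wo), (b / Wo) % Ho, b % Wo};
    }

    constexpr MergedIndex m = split_b(/*b=*/123, /*Ho=*/14, /*Wo=*/14);
    static_assert(m.n0 == 0 && m.ho == 8 && m.wo == 11, "123 == 0*196 + 8*14 + 11");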
@@ -112,7 +121,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
         // this check is ad-hoc
         // TODO: need to properly implement tensor descriptor with alignment
-        static_assert(in_c_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
+        static_assert(in_c_n1_b_n2_block_mem_desc.GetStride(I1) % GemmDataPerReadB == 0,
                       "GemmDataPerReadB alignment requirement is not satisfied");

         // input blockwise copy
@@ -129,7 +138,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -129,7 +138,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
Sequence
<
2
,
0
,
1
,
3
>
,
// thread_arrange_order [C, N1, N2, B]
Sequence
<
2
,
0
,
1
,
3
>
,
// thread_arrange_order [C, N1, N2, B]
Sequence
<
0
,
1
,
2
,
3
>
,
// src_access_order [N1, N2, C, B]
Sequence
<
0
,
1
,
2
,
3
>
,
// src_access_order [N1, N2, C, B]
Sequence
<
2
,
0
,
3
,
1
>
,
// dst_access_order [C, N1, B, N2]
Sequence
<
2
,
0
,
3
,
1
>
,
// dst_access_order [C, N1, B, N2]
>
({
0
,
0
,
0
,
b_block_data_on_global
},
{
0
,
0
,
0
,
0
});
InBlockCopySrcDataPerRead_B
,
InBlockCopyDstDataPerWrite_N2
>
({
0
,
0
,
0
,
b_block_data_on_global
},
{
0
,
0
,
0
,
0
});
// weight tensor
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
// tensor descriptor in device memory, src of blockwise copy
...
@@ -137,9 +147,9 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -137,9 +147,9 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
// tensor descriptor in LDS, dst of blockwise copy
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
wei_c_k_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
wei_c_k_block_desc
=
make_ConstantTensorDescriptor_
default_rank_
aligned
(
Sequence
<
CPerBlock
,
KPerBlock
>
{},
Sequence
<
CPerBlock
,
KPerBlock
>
{},
Number
<
mod_conv
::
max
(
WeiBlockCopyDataPer
Read
_K
,
GemmDataPerReadA
)
>
{});
Number
<
mod_conv
::
max
(
WeiBlockCopyDataPer
Access
_K
,
GemmDataPerReadA
)
>
{});
// operator for blockwise copy of weight into LDS
// operator for blockwise copy of weight into LDS
// slicing a tensor
// slicing a tensor
...
@@ -150,7 +160,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -150,7 +160,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
decltype
(
wei_c_k_global_desc
),
decltype
(
wei_c_k_global_desc
),
decltype
(
wei_c_k_block_desc
),
decltype
(
wei_c_k_block_desc
),
decltype
(
wei_c_k_block_desc
.
GetLengths
()),
decltype
(
wei_c_k_block_desc
.
GetLengths
()),
WeiBlockCopyDataPerRead_K
>
({
0
,
k_block_data_on_global
},
{
0
,
0
});
WeiBlockCopyDataPerAccess_K
>
({
0
,
k_block_data_on_global
},
{
0
,
0
});
// GEMM definition
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// c_mtx += transpose(a_mtx) * b_mtx
...
@@ -167,7 +178,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -167,7 +178,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
Number
<
in_c_n1_b_n2_block_mem_desc
.
GetStride
(
I0
)
>
{});
Number
<
in_c_n1_b_n2_block_mem_desc
.
GetStride
(
I0
)
>
{});
// sanity check
// sanity check
static_assert
(
KPerBlock
%
(
GemmMPerThreadSubC
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
),
static_assert
(
KPerBlock
%
(
GemmMPerThreadSubC
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
)
==
0
,
"wrong!"
);
"wrong!"
);
constexpr
index_t
GemmMRepeat
=
constexpr
index_t
GemmMRepeat
=
...
@@ -194,8 +206,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -194,8 +206,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
GemmDataPerReadB
>
{};
GemmDataPerReadB
>
{};
// LDS allocation for input and weight: be careful of alignment
// LDS allocation for input and weight: be careful of alignment
constexpr
index_t
max_align
=
mod_conv
::
max
(
InBlock
Reorder
DataPerWrite_N
,
constexpr
index_t
max_align
=
mod_conv
::
max
(
InBlock
CopyDst
DataPerWrite_N
2
,
WeiBlockCopyDataPer
Read
_K
,
WeiBlockCopyDataPer
Access
_K
,
GemmDataPerReadA
,
GemmDataPerReadA
,
GemmDataPerReadB
);
GemmDataPerReadB
);
@@ -211,7 +223,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
         Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];

         // zero out threadwise output
-        threadwise_matrix_set_zero(out_k0_k1_k2_n1_n0_h_w_n2_thread_desc, p_out_thread);
+        threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);

         // do work
         for(index_t y = 0; y < Y; ++y)
@@ -229,15 +241,15 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -229,15 +241,15 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
c_block_data_on_global
=
0
;
c_block_data_on_global
=
0
;
c_block_data_on_global
<
C
;
c_block_data_on_global
<
C
;
c_block_data_on_global
+=
CPerBlock
,
c_block_data_on_global
+=
CPerBlock
,
p_in_block_on
t
_global
+=
CPerBlock
*
in_n_c_h_w_global_desc
.
GetStride
(
I1
),
p_in_block_on_global
+=
CPerBlock
*
in_n_c_h_w_global_desc
.
GetStride
(
I1
),
p_wei_block_on_global
+=
CPerBlock
*
wei_c_y_x_k_global_desc
.
GetStride
(
I0
))
p_wei_block_on_global
+=
CPerBlock
*
wei_c_y_x_k_global_desc
.
GetStride
(
I0
))
{
{
blockwise_in_copy
.
r
un
(
p_in_block_on_global
,
p_in_block
);
blockwise_in_copy
.
R
un
(
p_in_block_on_global
,
p_in_block
);
blockwise_wei_copy
.
r
un
(
p_wei_block_on_global
,
p_wei_block
);
blockwise_wei_copy
.
R
un
(
p_wei_block_on_global
,
p_wei_block
);
__syncthreads
();
__syncthreads
();
blockwise_gemm
.
r
un
(
p_wei_block
,
p_in_block
,
p_out_thread
);
blockwise_gemm
.
R
un
(
p_wei_block
,
p_in_block
,
p_out_thread
);
__syncthreads
();
__syncthreads
();
}
}
...
@@ -253,19 +265,26 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -253,19 +265,26 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
// define tensor descriptor for threadwise copy
// define tensor descriptor for threadwise copy
// output tensor (also, memory layout) descriptor in register, src of threadwise
// output tensor (also, memory layout) descriptor in register, src of threadwise
// copy
// copy
constexpr
auto
out_k0_k1_k2_n1_b_n2_thread_mem_desc
=
make_ConstantTensorDescriptor
(
constexpr
auto
out_k0_k1_k2_n1_b_n2_thread_mem_desc
=
Sequence
<
KPerBlock
/
(
K1
*
K2
),
1
,
K2
,
N1
,
1
,
1
,
1
,
N2
>
{});
make_ConstantTensorDescriptor_default_rank_packed
(
Sequence
<
KPerBlock
/
(
K1
*
K2
),
1
,
K2
,
N1
,
1
,
N2
>
{});
// output memory layout descriptor in device memory
// output memory layout descriptor in device memory
constexpr
auto
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc
=
constexpr
auto
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc
=
out_n_k_h_w_global
.
Fold
(
I1
,
Number
<
K1
>
{},
Number
<
K2
>
{})
out_n_k_h_w_global
_desc
.
Fold
(
I1
,
Number
<
K1
>
{},
Number
<
K2
>
{})
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{});
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{});
// output merged tensor descriptor in device memory, dst of threadwise copy
// output merged tensor descriptor in device memory, dst of threadwise copy
constexpr
auto
out_k0_k1_k2_n1_b_n2_global_merged_desc
=
constexpr
auto
out_k0_k1_k2_n1_b_n2_global_merged_desc
=
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc
make_ConstantMergedTensorDescriptor
(
.
ReorderGivenNew2Old
(
Sequence
<
3
,
4
,
5
,
1
,
0
,
6
,
7
,
2
>
{})
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc
.
ReorderGivenNew2Old
(
.
Merge
(
I4
,
I6
);
Sequence
<
3
,
4
,
5
,
1
,
0
,
6
,
7
,
2
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
,
5
,
6
>
{},
Sequence
<
7
>
{});
// calculate origin of thread output tensor on global memory
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
// blockwise GEMM c matrix starting index
...
@@ -273,18 +292,30 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -273,18 +292,30 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
blockwise_gemm
.
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
blockwise_gemm
.
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
// origin of thread tensor on global
// origin of thread tensor on global
const
index_t
k_thread_data_on_global
k_block_data_on_global
+
const
index_t
k_thread_data_on_global
=
c_thread_mtx_on_block
.
row
;
k_block_data_on_global
+
c_thread_mtx_on_block
.
row
;
const
index_t
b_thread_data_on_global
=
const
index_t
b_thread_data_on_global
=
b_block_data_on_global
+
c_thread_mtx_on_block
.
col
;
b_block_data_on_global
+
c_thread_mtx_on_block
.
col
;
// output merged global tensor descriptor, for calculating origin of thread tensor
// output merged global tensor descriptor, for calculating origin of thread tensor
// in global memory
// in global memory
#if 0 // unfold a merged tensor is not implemented yet
constexpr auto out_k_n1_b_n2_global_merged_desc =
constexpr auto out_k_n1_b_n2_global_merged_desc =
out_k0_k1_k2_n1_b_n2_global_merged_desc
.
Unfold
(
I1
,
I2
);
out_k0_k1_k2_n1_b_n2_global_merged_desc.Unfold(I0, I2);
#else
constexpr
auto
out_k_n1_b_n2_global_merged_desc
=
make_ConstantMergedTensorDescriptor
(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc
.
ReorderGivenNew2Old
(
Sequence
<
3
,
4
,
5
,
1
,
0
,
6
,
7
,
2
>
{})
.
Unfold
(
I0
,
I2
),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
,
4
>
{},
Sequence
<
5
>
{});
#endif
// origin of thread tensor in global memory
// origin of thread tensor in global memory
const
index_t
p_out_thread_on_global
=
Float
*
p_out_thread_on_global
=
p_out_global
+
p_out_global
+
out_k_n1_b_n2_global_merged_desc
.
GetOffsetFromMultiIndex
(
out_k_n1_b_n2_global_merged_desc
.
GetOffsetFromMultiIndex
(
k_thread_data_on_global
,
0
,
0
,
0
);
// dst origin on merged global tensor
k_thread_data_on_global
,
0
,
0
,
0
);
// dst origin on merged global tensor
...
@@ -303,8 +334,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -303,8 +334,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
0
,
0
,
b_thread_data_on_global
,
b_thread_data_on_global
,
0
},
// starting point of slice w.r.t. origin of dst
0
},
// starting point of slice w.r.t. origin of dst
out_k0_k1_k2_n1_b_n2_thread_desc
.
GetLengths
(),
// slice lengths
out_k0_k1_k2_n1_b_n2_thread_
mem_
desc
.
GetLengths
(),
// slice lengths
Sequence
<
2
,
3
,
4
,
0
,
5
,
1
>
{}
// order of dimension access
Sequence
<
2
,
3
,
4
,
0
,
5
,
1
>
{}
// order of dimension access
);
);
}
}
}
}
...
...
src/include/constant_integral.hip.hpp → src/include/integral_constant.hip.hpp  (View file @ 8a4b5978)

@@ -8,5 +8,11 @@ struct integral_constant
     __host__ __device__ constexpr T Get() const { return value; }
 };

+template <class T, index_t X, index_t Y>
+__host__ __device__ constexpr auto operator+(integral_constant<T, X>, integral_constant<T, Y>)
+{
+    return integral_constant<T, X + Y>{};
+}
+
 template <index_t N>
 using Number = integral_constant<index_t, N>;
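The new operator+ keeps the sum of two Number<> constants in the type domain, so the result can still be used wherever a compile-time constant is required (for example as a template argument). A minimal self-contained illustration that mirrors the struct shown above (re-declared here only for the example; the real header also marks everything __host__ __device__):

    using index_t = unsigned;

    template <class T, T v>
    struct integral_constant
    {
        static constexpr T value = v;
    };

    template <class T, index_t X, index_t Y>
    constexpr auto operator+(integral_constant<T, X>, integral_constant<T, Y>)
    {
        return integral_constant<T, X + Y>{};
    }

    template <index_t N>
    using Number = integral_constant<index_t, N>;

    constexpr auto seven = Number<3>{} + Number<4>{};
    static_assert(seven.value == 7, "the sum is still a compile-time constant");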
src/include/threadwise_gemm.hip.hpp  (View file @ 8a4b5978)

@@ -10,7 +10,7 @@ __device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
         for(index_t j = 0; j < Matrix::NCol(); ++j)
         {
             const index_t id = Matrix::GetOffsetFromMultiIndex(i, j);
-            p_thread[id] = 0;
+            p_thread[id] = Float(0);
         }
     }
 }
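Writing the zero as Float(0) instead of a bare literal keeps the zero-fill generic over the element type: the value is constructed through Float itself rather than relying on an implicit int conversion. A standalone sketch of the same pattern:

    // zero-fill a row-major NRow x NCol register tile held in p_thread
    template <class Float, int NRow, int NCol>
    void set_zero(Float* p_thread)
    {
        for(int i = 0; i < NRow; ++i)
            for(int j = 0; j < NCol; ++j)
                p_thread[i * NCol + j] = Float(0);
    }

    // usage: float buf[4 * 8]; set_zero<float, 4, 8>(buf);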
src/include/threadwise_tensor_slice_op.hip.hpp  (View file @ 8a4b5978)

@@ -19,7 +19,7 @@ __device__ void threadwise_tensor_slice_copy(SrcDesc,
     constexpr auto src_desc = SrcDesc{};
     constexpr auto dst_desc = DstDesc{};
-    constexpr auto ref_desc = make_packed_ConstantTensorDescriptor(SrcOpLengths{});
+    constexpr auto ref_desc = make_ConstantTensorDescriptor_default_rank_packed(SrcOpLengths{});

 #if 0
     if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -194,16 +194,19 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
...
@@ -194,16 +194,19 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
}
}
template
<
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
SliceLengths
,
class
DimAccessOrder
>
template
<
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
SliceLengths
,
class
DimAccessOrder
>
__device__
void
__device__
void
threadwise_tensor_slice_copy_generic
(
threadwise_tensor_slice_copy_generic
(
SrcDesc
,
SrcDesc
,
const
Float
*
__restrict__
p_src
,
const
Float
*
__restrict__
p_src
,
Array
<
index_t
,
SrcDesc
::
GetNumOfDimension
()
>
src_multi_
offset
,
Array
<
index_t
,
SrcDesc
::
GetNumOfDimension
()
>
src_multi_
id_begin
,
DstDesc
,
DstDesc
,
Float
*
__restrict__
p_dst
,
Float
*
__restrict__
p_dst
,
Array
<
index_t
,
DstDesc
::
GetNumOfDimension
()
>
dst_multi_
offset
,
Array
<
index_t
,
DstDesc
::
GetNumOfDimension
()
>
dst_multi_
id_begin
,
SliceLengths
,
SliceLengths
,
DimAccessOrder
)
DimAccessOrder
)
{
{
static_assert
(
SrcDesc
::
GetNumOfDimension
()
==
DstDesc
::
GetNumOfDimension
(),
"wrong! # of dimensions not the same"
);
constexpr
auto
src_desc
=
SrcDesc
{};
constexpr
auto
src_desc
=
SrcDesc
{};
constexpr
auto
dst_desc
=
DstDesc
{};
constexpr
auto
dst_desc
=
DstDesc
{};
...
@@ -215,9 +218,10 @@ threadwise_tensor_slice_copy_generic(SrcDesc,
...
@@ -215,9 +218,10 @@ threadwise_tensor_slice_copy_generic(SrcDesc,
reorder_array_given_old2new
(
data_multi_id_in_access_order
,
DimAccessOrder
{});
reorder_array_given_old2new
(
data_multi_id_in_access_order
,
DimAccessOrder
{});
const
index_t
dst_index
=
const
index_t
dst_index
=
dst_desc
.
GetOffsetFromMultiIndex
(
src_multi_offset
+
data_multi_id
);
dst_desc
.
GetOffsetFromMultiIndex
(
src_multi_id_begin
+
data_multi_id
);
const
index_t
src_index
=
const
index_t
src_index
=
src_desc
.
GetOffsetFromMultiIndex
(
dst_multi_
offset
+
data_multi_id
);
src_desc
.
GetOffsetFromMultiIndex
(
dst_multi_
id_begin
+
data_multi_id
);
p_dst
[
dst_index
]
=
p_src
[
src_index
];
p_dst
[
dst_index
]
=
p_src
[
src_index
];
});
});
...
...
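The renamed src_multi_id_begin / dst_multi_id_begin parameters are starting coordinates that get added to every per-element index before it is linearized against the corresponding descriptor. A simplified 2-D version of that idea, with the descriptor machinery replaced by explicit strides (illustrative only):

    using index_t = unsigned;

    // copy a len0 x len1 window: element (i, j) of the slice is read from
    // (src_begin0 + i, src_begin1 + j) and written to (dst_begin0 + i, dst_begin1 + j)
    template <class Float>
    void slice_copy_2d(const Float* p_src, index_t src_stride, index_t src_begin0, index_t src_begin1,
                       Float* p_dst, index_t dst_stride, index_t dst_begin0, index_t dst_begin1,
                       index_t len0, index_t len1)
    {
        for(index_t i = 0; i < len0; ++i)
            for(index_t j = 0; j < len1; ++j)
            {
                const index_t src_index = (src_begin0 + i) * src_stride + (src_begin1 + j);
                const index_t dst_index = (dst_begin0 + i) * dst_stride + (dst_begin1 + j);
                p_dst[dst_index] = p_src[src_index];
            }
    }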
src/include/vector_type.hip.hpp  (View file @ 8a4b5978)

 #pragma once

 #include "config.h"
-#include "constant_integral.hip.hpp"
+#include "integral_constant.hip.hpp"

 template <class T, index_t N>
 struct vector_type