Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
98716c83
"...text-generation-inference.git" did not exist on "00e6ce44b165555087265286d416cad4826c3791"
Commit
98716c83
authored
Jan 14, 2020
by
Chao Liu
Browse files
added bwd data v3r1
parent
9750de73
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
183 additions
and
188 deletions
+183
-188
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
+167
-148
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
+13
-7
driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
+0
-30
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+3
-3
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
View file @
98716c83
...
@@ -46,6 +46,34 @@ template <index_t GridSize,
...
@@ -46,6 +46,34 @@ template <index_t GridSize,
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
>
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
>
struct
GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
struct
GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
{
{
// this is a hack, should query this info from gridwise_gemm instead of duplicate its logic
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
constexpr
index_t
max_lds_align
=
math
::
lcm
(
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
a_gemmk_gemmm_block_desc
=
make_native_tensor_descriptor_aligned
(
Sequence
<
GemmKPerBlock
,
GemmMPerBlock
>
{},
Number
<
max_lds_align
>
{});
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
b_gemmk_gemmn_block_desc
=
make_native_tensor_descriptor_aligned
(
Sequence
<
GemmKPerBlock
,
GemmNPerBlock
>
{},
Number
<
max_lds_align
>
{});
// LDS allocation for A and B: be careful of alignment
constexpr
index_t
a_block_space
=
math
::
integer_least_multiple
(
a_gemmk_gemmm_block_desc
.
GetElementSpace
(),
max_lds_align
);
constexpr
index_t
b_block_space
=
math
::
integer_least_multiple
(
b_gemmk_gemmn_block_desc
.
GetElementSpace
(),
max_lds_align
);
return
2
*
(
a_block_space
+
b_block_space
)
*
sizeof
(
Float
);
}
__device__
void
Run
(
Float
*
__restrict__
p_in_global
,
__device__
void
Run
(
Float
*
__restrict__
p_in_global
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_out_global
)
const
const
Float
*
__restrict__
p_out_global
)
const
...
@@ -117,11 +145,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
...
@@ -117,11 +145,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
Embed
<
Y
,
Embed
<
Y
,
Sequence
<
Ydot
,
Ytilda
>
,
Sequence
<
Ydot
,
Ytilda
>
,
Sequence
<
ConvStrideH
/
hcf_stride_dilation_h
,
1
,
0
>
,
Sequence
<
ConvStrideH
/
hcf_stride_dilation_h
,
1
,
0
>
,
fals
e
>
{},
tru
e
>
{},
Embed
<
X
,
Embed
<
X
,
Sequence
<
Xdot
,
Xtilda
>
,
Sequence
<
Xdot
,
Xtilda
>
,
Sequence
<
ConvStrideW
/
hcf_stride_dilation_w
,
1
,
0
>
,
Sequence
<
ConvStrideW
/
hcf_stride_dilation_w
,
1
,
0
>
,
fals
e
>
{}),
tru
e
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
...
@@ -205,155 +233,146 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
...
@@ -205,155 +233,146 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}));
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}));
// get a series of GEMMs
auto
f_get_gemm
=
[
&
](
auto
ytilda_
,
auto
xtilda_
)
{
constexpr
index_t
ytilda
=
decltype
(
ytilda_
){};
constexpr
index_t
xtilda
=
decltype
(
xtilda_
){};
constexpr
index_t
Ydotnonzero
=
(
ytilda
+
1
)
*
Ydot
<=
Y
?
Ydot
:
Y
%
Ydot
;
constexpr
index_t
Xdotnonzero
=
(
xtilda
+
1
)
*
Xdot
<=
X
?
Xdot
:
X
%
Xdot
;
// A matrix
constexpr
auto
wei_k_c_ydotnonzero_1_xdotnonzero_1_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
Trim
<
Sequence
<
Ydot
,
Xdot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
Ydot
-
Ydotnonzero
,
Xdot
-
Xdotnonzero
>>
{},
Trim
<
Sequence
<
Ytilda
,
Xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
Ytilda
-
ytilda
-
1
,
Xtilda
-
xtilda
-
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}));
constexpr
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydotnonzero_1_xdotnonzero_1_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
Ydotnonzero
,
Xdotnonzero
>>
{},
Merge
<
Sequence
<
C
,
1
,
1
>>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// B matrix
constexpr
auto
out_n_k_ydotnonzero_htildatrim_xdotnonzero_wtildatrim_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
Trim
<
Sequence
<
Ydot
,
Xdot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
Ydot
-
Ydotnonzero
,
Xdot
-
Xdotnonzero
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}));
constexpr
auto
out_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydotnonzero_htildatrim_xdotnonzero_wtildatrim_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
Ydotnonzero
,
Xdotnonzero
>>
{},
Merge
<
Sequence
<
N
,
HtildaTrim
,
WtildaTrim
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// C matrix
constexpr
auto
in_n_c_1_htildatrim_1_wtildatrim_global_desc
=
transform_tensor_descriptor
(
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
Trim
<
Sequence
<
Ytilda
,
Xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
Ytilda
-
ytilda
-
1
,
Xtilda
-
xtilda
-
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}));
constexpr
auto
in_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_c_1_htildatrim_1_wtildatrim_global_desc
,
make_tuple
(
Merge
<
Sequence
<
C
,
1
,
1
>>
{},
Merge
<
Sequence
<
N
,
HtildaTrim
,
WtildaTrim
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
constexpr
auto
gridwise_gemm
=
GridwiseGemmTransposedANormalBNormalC_v1
<
GridSize
,
BlockSize
,
Float
,
AccFloat
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
out_gemmk_gemmn_global_desc
),
decltype
(
in_gemmm_gemmn_global_desc
),
InMemoryDataOperation
::
none
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmABlockCopySrcDataPerRead_GemmM
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
return
gridwise_gemm
;
};
// GEMMs
// GEMMs
index_t
shared_mem_size
=
0
;
constexpr
index_t
shared_block_size
=
GetSharedMemoryNumberOfByte
()
/
sizeof
(
Float
);
static_for
<
0
,
Ytilda
,
1
>
{}([
&
](
auto
ytilda
)
{
static_for
<
0
,
Xtilda
,
1
>
{}([
&
](
auto
xtilda
)
{
auto
gemm
=
f_get_gemm
(
ytilda
,
xtilda
);
shared_mem_size
=
math
::
max
(
shared_mem_size
,
gemm
.
GetSharedMemorySize
());
});
});
__shared__
Float
p_shared_
f
lo
at
[
shared_
mem_size
/
sizeof
(
Float
)
];
__shared__
Float
p_shared_
b
lo
ck
[
shared_
block_size
];
// GEMMs
#if 1 // debug
static_for
<
0
,
Ytilda
,
1
>
{}([
&
](
auto
ytilda
)
{
static_for
<
0
,
Ytilda
,
1
>
{}([
&
](
auto
ytilda_
)
{
static_for
<
0
,
Xtilda
,
1
>
{}([
&
](
auto
xtilda
)
{
static_for
<
0
,
Xtilda
,
1
>
{}([
&
](
auto
xtilda_
)
{
auto
gemm
=
f_get_gemm
(
ytilda
,
xtilda
);
#else
static_for
<
0
,
1
,
1
>
{}([
&
](
auto
ytilda_
)
{
gemm
.
Run
(
p_wei_global
,
p_in_global
,
p_out_global
,
p_shared_float
);
static_for
<
0
,
1
,
1
>
{}([
&
](
auto
xtilda_
)
{
#endif
constexpr
index_t
ytilda
=
decltype
(
ytilda_
){};
constexpr
index_t
xtilda
=
decltype
(
xtilda_
){};
constexpr
index_t
YdotNonZero
=
(
ytilda
+
1
)
*
Ydot
<=
Y
?
Ydot
:
Y
%
Ydot
;
constexpr
index_t
XdotNonZero
=
(
xtilda
+
1
)
*
Xdot
<=
X
?
Xdot
:
X
%
Xdot
;
// A matrix
constexpr
auto
wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
Trim
<
Sequence
<
Ydot
,
Xdot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
Ydot
-
YdotNonZero
,
Xdot
-
XdotNonZero
>>
{},
Trim
<
Sequence
<
Ytilda
,
Xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
Ytilda
-
ytilda
-
1
,
Xtilda
-
xtilda
-
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}));
constexpr
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
YdotNonZero
,
XdotNonZero
>>
{},
Merge
<
Sequence
<
C
,
1
,
1
>>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// B matrix
constexpr
auto
out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
Trim
<
Sequence
<
Ydot
,
Xdot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
Ydot
-
YdotNonZero
,
Xdot
-
XdotNonZero
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}));
constexpr
auto
out_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
YdotNonZero
,
XdotNonZero
>>
{},
Merge
<
Sequence
<
N
,
HtildaTrim
,
WtildaTrim
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// C matrix
constexpr
auto
in_n_c_1_htildatrim_1_wtildatrim_global_desc
=
transform_tensor_descriptor
(
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
Trim
<
Sequence
<
Ytilda
,
Xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
Ytilda
-
ytilda
-
1
,
Xtilda
-
xtilda
-
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}));
constexpr
auto
in_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_c_1_htildatrim_1_wtildatrim_global_desc
,
make_tuple
(
Merge
<
Sequence
<
C
,
1
,
1
>>
{},
Merge
<
Sequence
<
N
,
HtildaTrim
,
WtildaTrim
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
constexpr
auto
gridwise_gemm
=
GridwiseGemmTransposedANormalBNormalC_v1
<
GridSize
,
BlockSize
,
Float
,
AccFloat
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
out_gemmk_gemmn_global_desc
),
decltype
(
in_gemmm_gemmn_global_desc
),
InMemoryDataOperation
::
none
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmABlockCopySrcDataPerRead_GemmM
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
gridwise_gemm
.
Run
(
p_wei_global
,
p_out_global
,
p_in_global
,
p_shared_block
);
});
});
});
});
}
}
...
...
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
View file @
98716c83
...
@@ -50,7 +50,7 @@ template <index_t GridSize,
...
@@ -50,7 +50,7 @@ template <index_t GridSize,
index_t
CThreadCopyDstDataPerWrite
>
index_t
CThreadCopyDstDataPerWrite
>
struct
GridwiseGemmTransposedANormalBNormalC_v1
struct
GridwiseGemmTransposedANormalBNormalC_v1
{
{
__host__
__device__
static
constexpr
index_t
GetSharedMemory
Siz
e
()
__host__
__device__
static
constexpr
index_t
GetSharedMemory
NumberOfByt
e
()
{
{
constexpr
index_t
max_lds_align
=
math
::
lcm
(
ABlockCopyDstDataPerWrite_M
,
constexpr
index_t
max_lds_align
=
math
::
lcm
(
ABlockCopyDstDataPerWrite_M
,
BBlockCopyDstDataPerWrite_N
,
BBlockCopyDstDataPerWrite_N
,
...
@@ -80,7 +80,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
...
@@ -80,7 +80,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
__device__
void
Run
(
const
Float
*
__restrict__
p_a_global
,
__device__
void
Run
(
const
Float
*
__restrict__
p_a_global
,
const
Float
*
__restrict__
p_b_global
,
const
Float
*
__restrict__
p_b_global
,
Float
*
__restrict__
p_c_global
,
Float
*
__restrict__
p_c_global
,
void
*
__restrict__
p_shared
)
const
Float
*
__restrict__
p_shared
_block
)
const
{
{
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
...
@@ -92,6 +92,12 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
...
@@ -92,6 +92,12 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
constexpr
auto
M
=
a_k_m_global_desc
.
GetLengths
()[
1
];
constexpr
auto
M
=
a_k_m_global_desc
.
GetLengths
()[
1
];
constexpr
auto
N
=
b_k_n_global_desc
.
GetLengths
()[
1
];
constexpr
auto
N
=
b_k_n_global_desc
.
GetLengths
()[
1
];
// don't do anything if K == 0
if
(
K
==
0
)
{
return
;
}
// lds max alignment
// lds max alignment
constexpr
index_t
max_lds_align
=
math
::
lcm
(
ABlockCopyDstDataPerWrite_M
,
constexpr
index_t
max_lds_align
=
math
::
lcm
(
ABlockCopyDstDataPerWrite_M
,
BBlockCopyDstDataPerWrite_N
,
BBlockCopyDstDataPerWrite_N
,
...
@@ -212,8 +218,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
...
@@ -212,8 +218,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
constexpr
index_t
b_block_space
=
constexpr
index_t
b_block_space
=
math
::
integer_least_multiple
(
b_k_n_block_desc
.
GetElementSpace
(),
max_lds_align
);
math
::
integer_least_multiple
(
b_k_n_block_desc
.
GetElementSpace
(),
max_lds_align
);
Float
*
p_a_block_double
=
reinterpret_cast
<
Float
*>
(
p_shared
)
;
Float
*
p_a_block_double
=
p_shared
_block
;
Float
*
p_b_block_double
=
p_
a_block_double
+
2
*
a_block_space
;
Float
*
p_b_block_double
=
p_
shared_block
+
2
*
a_block_space
;
// register allocation for output
// register allocation for output
AccFloat
p_c_thread
[
c_m0m1_n0n1_thread_mtx_desc
.
GetElementSpace
()];
AccFloat
p_c_thread
[
c_m0m1_n0n1_thread_mtx_desc
.
GetElementSpace
()];
...
@@ -362,11 +368,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
...
@@ -362,11 +368,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
const
Float
*
__restrict__
p_b_global
,
const
Float
*
__restrict__
p_b_global
,
Float
*
__restrict__
p_c_global
)
const
Float
*
__restrict__
p_c_global
)
const
{
{
constexpr
index_t
shared_
mem
_size
=
GetSharedMemory
Size
(
);
constexpr
index_t
shared_
block
_size
=
GetSharedMemory
NumberOfByte
()
/
sizeof
(
Float
);
__shared__
Float
p_shared_
f
lo
at
[
shared_
mem_size
/
sizeof
(
Float
)
];
__shared__
Float
p_shared_
b
lo
ck
[
shared_
block_size
];
Run
(
p_a_global
,
p_b_global
,
p_c_global
,
p_shared_
f
lo
at
);
Run
(
p_a_global
,
p_b_global
,
p_c_global
,
p_shared_
b
lo
ck
);
}
}
};
};
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
View file @
98716c83
...
@@ -84,36 +84,6 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
...
@@ -84,36 +84,6 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#elif 1
// BlockSize = 256, each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
=
Sequence
<
2
,
1
>
;
using
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmM
=
1
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
1
;
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
1
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#endif
#endif
...
...
driver/src/conv_bwd_data_driver.cpp
View file @
98716c83
...
@@ -25,10 +25,10 @@ int main(int argc, char* argv[])
...
@@ -25,10 +25,10 @@ int main(int argc, char* argv[])
#if 1
#if 1
// 3x3 filter, 2x2 stride, 35x35 input
// 3x3 filter, 2x2 stride, 35x35 input
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1
28
;
constexpr
index_t
C
=
1
024
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
1
28
;
constexpr
index_t
K
=
1
024
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
...
@@ -157,7 +157,7 @@ int main(int argc, char* argv[])
...
@@ -157,7 +157,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
2
,
2
>
;
using
LeftPads
=
Sequence
<
2
,
2
>
;
using
RightPads
=
Sequence
<
2
,
2
>
;
using
RightPads
=
Sequence
<
2
,
2
>
;
#elif
0
#elif
1
// 1x7 filter, 0x3 pad, 17x17 input
// 1x7 filter, 0x3 pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
C
=
128
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment