Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2558d019
Commit
2558d019
authored
Feb 21, 2021
by
Chao Liu
Browse files
making dynamic multi-index transform support compile-time info
parent
1e55a3b1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
208 additions
and
178 deletions
+208
-178
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+112
-104
composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp
...lude/tensor_description/dynamic_multi_index_transform.hpp
+68
-46
composable_kernel/include/utility/config.amd.hpp.in
composable_kernel/include/utility/config.amd.hpp.in
+1
-0
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+27
-28
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
2558d019
...
@@ -36,14 +36,20 @@ template <index_t BlockSize,
...
@@ -36,14 +36,20 @@ template <index_t BlockSize,
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
>
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
>
struct
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
struct
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
{
{
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
Out
>
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
>
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
MultiIndex
<
2
>
conv_strides
,
const
ConvStrides
&
conv_strides
,
const
MultiIndex
<
2
>
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
MultiIndex
<
2
>
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const
MultiIndex
<
2
>
in_right_pads
,
const
InRightPads
&
in_right_pads
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_in_global
,
const
Float
*
__restrict__
p_in_global
,
Float
*
__restrict__
p_out_global
)
const
Float
*
__restrict__
p_out_global
)
const
...
@@ -53,30 +59,30 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
...
@@ -53,30 +59,30 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
const
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I0
);
const
auto
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I0
);
const
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I1
);
const
auto
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I1
);
const
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I1
);
const
auto
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I1
);
const
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I2
);
const
auto
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I2
);
const
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I3
);
const
auto
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I3
);
const
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I2
);
const
auto
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I2
);
const
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I3
);
const
auto
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I3
);
const
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
);
const
auto
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
);
const
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
);
const
auto
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
);
const
index_t
ConvStrideH
=
conv_strides
[
I0
];
const
auto
ConvStrideH
=
conv_strides
[
I0
];
const
index_t
ConvStrideW
=
conv_strides
[
I1
];
const
auto
ConvStrideW
=
conv_strides
[
I1
];
const
index_t
ConvDilationH
=
conv_dilations
[
I0
];
const
auto
ConvDilationH
=
conv_dilations
[
I0
];
const
index_t
ConvDilationW
=
conv_dilations
[
I1
];
const
auto
ConvDilationW
=
conv_dilations
[
I1
];
const
index_t
InLeftPadH
=
in_left_pads
[
I0
];
const
auto
InLeftPadH
=
in_left_pads
[
I0
];
const
index_t
InLeftPadW
=
in_left_pads
[
I1
];
const
auto
InLeftPadW
=
in_left_pads
[
I1
];
const
index_t
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
index_t
InRightPadW
=
in_right_pads
[
I1
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
// weight tensor
// weight tensor
const
auto
wei_gemmk_gemmm_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
wei_gemmk_gemmm_global_desc
=
transform_dynamic_tensor_descriptor
(
...
@@ -95,8 +101,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
...
@@ -95,8 +101,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
index_t
Hip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
I2
);
const
auto
Hip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
I2
);
const
index_t
Wip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
I3
);
const
auto
Wip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
I3
);
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_dynamic_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
in_n_c_hip_wip_global_desc
,
...
@@ -123,9 +129,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
...
@@ -123,9 +129,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
index_t
GemmM
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I0
);
const
auto
GemmM
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I0
);
const
index_t
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
index_t
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
if
(
!
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
if
(
!
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
))
GemmK
%
GemmKPerBlock
==
0
))
...
@@ -133,21 +139,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
...
@@ -133,21 +139,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
}
constexpr
index_t
GemmM1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
auto
GemmM1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
GemmN1
=
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
constexpr
auto
GemmN1
=
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
const
index_t
GemmM0
=
GemmM
/
GemmM1
;
const
auto
GemmM0
=
GemmM
/
GemmM1
;
const
index_t
GemmN0
=
GemmN
/
GemmN1
;
const
auto
GemmN0
=
GemmN
/
GemmN1
;
#if 0
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(DynamicUnMerge<2>{make_multi_index(GemmM0, GemmM1)},
DynamicUnMerge<2>{make_multi_index(GemmN0, GemmN1)}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
#else
const
auto
GemmM0_GemmM1
=
make_tuple
(
GemmM0
,
Number
<
GemmM1
>
{});
const
auto
GemmM0_GemmM1
=
make_tuple
(
GemmM0
,
Number
<
GemmM1
>
{});
const
auto
GemmN0_GemmN1
=
make_tuple
(
GemmN0
,
Number
<
GemmN1
>
{});
const
auto
GemmN0_GemmN1
=
make_tuple
(
GemmN0
,
Number
<
GemmN1
>
{});
...
@@ -159,7 +156,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
...
@@ -159,7 +156,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
DynamicUnMerge
<
2
,
false
,
remove_cv_t
<
decltype
(
GemmN0_GemmN1
)
>>
{
GemmN0_GemmN1
}),
DynamicUnMerge
<
2
,
false
,
remove_cv_t
<
decltype
(
GemmN0_GemmN1
)
>>
{
GemmN0_GemmN1
}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
#endif
// hack to control index calculation when iterating over a_k_m_global tensor
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr
auto
a_k_m_global_iterator_hacks
=
constexpr
auto
a_k_m_global_iterator_hacks
=
...
@@ -235,7 +231,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
...
@@ -235,7 +231,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
decltype
(
a_k_m_global_move_slice_window_iterator_hack
),
decltype
(
a_k_m_global_move_slice_window_iterator_hack
),
decltype
(
b_k_n_global_move_slice_window_iterator_hack
)
>
;
decltype
(
b_k_n_global_move_slice_window_iterator_hack
)
>
;
const
index_t
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
auto
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
bool
has_main_k_block_loop
=
(
GemmK
+
GemmKPerBlock
)
/
(
2
*
GemmKPerBlock
)
>
1
;
const
bool
has_main_k_block_loop
=
(
GemmK
+
GemmKPerBlock
)
/
(
2
*
GemmKPerBlock
)
>
1
;
...
@@ -724,14 +720,20 @@ template <index_t BlockSize,
...
@@ -724,14 +720,20 @@ template <index_t BlockSize,
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
>
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
>
struct
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
struct
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
{
{
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
Out
>
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
>
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
MultiIndex
<
2
>
conv_strides
,
const
ConvStrides
&
conv_strides
,
const
MultiIndex
<
2
>
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
MultiIndex
<
2
>
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const
MultiIndex
<
2
>
in_right_pads
,
const
InRightPads
&
in_right_pads
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_in_global
,
const
Float
*
__restrict__
p_in_global
,
Float
*
__restrict__
p_out_global
)
const
Float
*
__restrict__
p_out_global
)
const
...
@@ -741,30 +743,30 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
...
@@ -741,30 +743,30 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
const
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I0
);
const
auto
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I0
);
const
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I1
);
const
auto
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I1
);
const
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I1
);
const
auto
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I1
);
const
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I2
);
const
auto
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I2
);
const
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I3
);
const
auto
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I3
);
const
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I2
);
const
auto
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I2
);
const
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I3
);
const
auto
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I3
);
const
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
);
const
auto
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
);
const
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
);
const
auto
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
);
const
index_t
ConvStrideH
=
conv_strides
[
I0
];
const
auto
ConvStrideH
=
conv_strides
[
I0
];
const
index_t
ConvStrideW
=
conv_strides
[
I1
];
const
auto
ConvStrideW
=
conv_strides
[
I1
];
const
index_t
ConvDilationH
=
conv_dilations
[
I0
];
const
auto
ConvDilationH
=
conv_dilations
[
I0
];
const
index_t
ConvDilationW
=
conv_dilations
[
I1
];
const
auto
ConvDilationW
=
conv_dilations
[
I1
];
const
index_t
InLeftPadH
=
in_left_pads
[
I0
];
const
auto
InLeftPadH
=
in_left_pads
[
I0
];
const
index_t
InLeftPadW
=
in_left_pads
[
I1
];
const
auto
InLeftPadW
=
in_left_pads
[
I1
];
const
index_t
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
index_t
InRightPadW
=
in_right_pads
[
I1
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
if
(
!
(
InLeftPadH
==
0
&&
InLeftPadW
==
0
&&
InRightPadH
==
0
&&
InRightPadW
==
0
))
if
(
!
(
InLeftPadH
==
0
&&
InLeftPadW
==
0
&&
InRightPadH
==
0
&&
InRightPadW
==
0
))
{
{
...
@@ -791,8 +793,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
...
@@ -791,8 +793,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
// debug: don't do padding
// debug: don't do padding
const
auto
in_n_c_hip_wip_global_desc
=
in_n_c_hi_wi_global_desc
;
const
auto
in_n_c_hip_wip_global_desc
=
in_n_c_hi_wi_global_desc
;
const
index_t
Hip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
I2
);
const
auto
Hip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
I2
);
const
index_t
Wip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
I3
);
const
auto
Wip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
I3
);
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_dynamic_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
in_n_c_hip_wip_global_desc
,
...
@@ -828,9 +830,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
...
@@ -828,9 +830,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
#endif
#endif
const
index_t
GemmM
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I0
);
const
auto
GemmM
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I0
);
const
index_t
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
index_t
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
if
(
!
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
if
(
!
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
))
GemmK
%
GemmKPerBlock
==
0
))
...
@@ -838,11 +840,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
...
@@ -838,11 +840,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
}
constexpr
index_t
GemmM1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
auto
GemmM1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
GemmN1
=
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
constexpr
auto
GemmN1
=
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
const
index_t
GemmM0
=
GemmM
/
GemmM1
;
const
auto
GemmM0
=
GemmM
/
GemmM1
;
const
index_t
GemmN0
=
GemmN
/
GemmN1
;
const
auto
GemmN0
=
GemmN
/
GemmN1
;
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
transform_dynamic_tensor_descriptor
(
transform_dynamic_tensor_descriptor
(
...
@@ -924,7 +926,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
...
@@ -924,7 +926,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
decltype
(
a_k_m_global_move_slice_window_iterator_hack
),
decltype
(
a_k_m_global_move_slice_window_iterator_hack
),
decltype
(
b_k_n_global_move_slice_window_iterator_hack
)
>
;
decltype
(
b_k_n_global_move_slice_window_iterator_hack
)
>
;
const
index_t
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
auto
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
bool
has_main_k_block_loop
=
(
GemmK
+
GemmKPerBlock
)
/
(
2
*
GemmKPerBlock
)
>
1
;
const
bool
has_main_k_block_loop
=
(
GemmK
+
GemmKPerBlock
)
/
(
2
*
GemmKPerBlock
)
>
1
;
...
@@ -1410,14 +1412,20 @@ template <index_t BlockSize,
...
@@ -1410,14 +1412,20 @@ template <index_t BlockSize,
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
>
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
>
struct
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
struct
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
{
{
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
Out
>
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
>
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
MultiIndex
<
2
>
conv_strides
,
const
ConvStrides
&
conv_strides
,
const
MultiIndex
<
2
>
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
MultiIndex
<
2
>
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const
MultiIndex
<
2
>
in_right_pads
,
const
InRightPads
&
in_right_pads
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_in_global
,
const
Float
*
__restrict__
p_in_global
,
Float
*
__restrict__
p_out_global
)
const
Float
*
__restrict__
p_out_global
)
const
...
@@ -1427,30 +1435,30 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
...
@@ -1427,30 +1435,30 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
const
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I0
);
const
auto
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I0
);
const
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I1
);
const
auto
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I1
);
const
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I1
);
const
auto
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I1
);
const
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I2
);
const
auto
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I2
);
const
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I3
);
const
auto
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I3
);
const
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I2
);
const
auto
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I2
);
const
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I3
);
const
auto
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I3
);
const
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
);
const
auto
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
);
const
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
);
const
auto
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
);
const
index_t
ConvStrideH
=
conv_strides
[
I0
];
const
auto
ConvStrideH
=
conv_strides
[
I0
];
const
index_t
ConvStrideW
=
conv_strides
[
I1
];
const
auto
ConvStrideW
=
conv_strides
[
I1
];
const
index_t
ConvDilationH
=
conv_dilations
[
I0
];
const
auto
ConvDilationH
=
conv_dilations
[
I0
];
const
index_t
ConvDilationW
=
conv_dilations
[
I1
];
const
auto
ConvDilationW
=
conv_dilations
[
I1
];
const
index_t
InLeftPadH
=
in_left_pads
[
I0
];
const
auto
InLeftPadH
=
in_left_pads
[
I0
];
const
index_t
InLeftPadW
=
in_left_pads
[
I1
];
const
auto
InLeftPadW
=
in_left_pads
[
I1
];
const
index_t
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
index_t
InRightPadW
=
in_right_pads
[
I1
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
if
(
!
(
Y
==
1
&&
X
==
1
&&
ConvStrideH
==
1
&&
ConvStrideW
==
1
&&
ConvDilationH
==
1
&&
if
(
!
(
Y
==
1
&&
X
==
1
&&
ConvStrideH
==
1
&&
ConvStrideW
==
1
&&
ConvDilationH
==
1
&&
ConvDilationW
==
1
&&
InLeftPadH
==
0
&&
InLeftPadW
==
0
&&
InRightPadH
==
0
&&
ConvDilationW
==
1
&&
InLeftPadH
==
0
&&
InLeftPadW
==
0
&&
InRightPadH
==
0
&&
...
@@ -1480,9 +1488,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
...
@@ -1480,9 +1488,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
index_t
GemmM
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I0
);
const
auto
GemmM
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I0
);
const
index_t
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
index_t
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
if
(
!
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
if
(
!
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
))
GemmK
%
GemmKPerBlock
==
0
))
...
@@ -1490,11 +1498,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
...
@@ -1490,11 +1498,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
}
constexpr
index_t
GemmM1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
auto
GemmM1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
GemmN1
=
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
constexpr
auto
GemmN1
=
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
const
index_t
GemmM0
=
GemmM
/
GemmM1
;
const
auto
GemmM0
=
GemmM
/
GemmM1
;
const
index_t
GemmN0
=
GemmN
/
GemmN1
;
const
auto
GemmN0
=
GemmN
/
GemmN1
;
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
transform_dynamic_tensor_descriptor
(
transform_dynamic_tensor_descriptor
(
...
@@ -1574,7 +1582,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
...
@@ -1574,7 +1582,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
decltype
(
a_k_m_global_move_slice_window_iterator_hack
),
decltype
(
a_k_m_global_move_slice_window_iterator_hack
),
decltype
(
b_k_n_global_move_slice_window_iterator_hack
)
>
;
decltype
(
b_k_n_global_move_slice_window_iterator_hack
)
>
;
const
index_t
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
auto
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
bool
has_main_k_block_loop
=
(
GemmK
+
GemmKPerBlock
)
/
(
2
*
GemmKPerBlock
)
>
1
;
const
bool
has_main_k_block_loop
=
(
GemmK
+
GemmKPerBlock
)
/
(
2
*
GemmKPerBlock
)
>
1
;
...
...
composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp
View file @
2558d019
...
@@ -6,17 +6,20 @@
...
@@ -6,17 +6,20 @@
namespace
ck
{
namespace
ck
{
template
<
typename
LowLength
=
index_t
>
struct
DynamicPassThrough
struct
DynamicPassThrough
{
{
using
LowerIndex
=
MultiIndex
<
1
>
;
using
LowerIndex
=
MultiIndex
<
1
>
;
using
UpperIndex
=
MultiIndex
<
1
>
;
using
UpperIndex
=
MultiIndex
<
1
>
;
UpperIndex
up_lengths_
;
using
UpLengths
=
decltype
(
make_tuple
(
LowLength
{}));
UpLengths
up_lengths_
;
__host__
__device__
constexpr
DynamicPassThrough
()
=
default
;
__host__
__device__
constexpr
DynamicPassThrough
()
=
default
;
__host__
__device__
constexpr
DynamicPassThrough
(
const
index_t
&
low_length
)
__host__
__device__
constexpr
DynamicPassThrough
(
const
LowLength
&
low_length
)
:
up_lengths_
{
make_
multi_index
(
low_length
)}
:
up_lengths_
{
make_
tuple
(
low_length
)}
{
{
}
}
...
@@ -75,27 +78,33 @@ struct DynamicPassThrough
...
@@ -75,27 +78,33 @@ struct DynamicPassThrough
{
{
printf
(
"{"
);
printf
(
"{"
);
printf
(
"DynamicPassThrough, "
);
printf
(
"DynamicPassThrough, "
);
printf
(
"up_lengths_"
);
print_multi_index
(
up_lengths_
);
print_multi_index
(
up_lengths_
);
printf
(
"}"
);
printf
(
"}"
);
}
}
};
};
template
<
bool
SkipIsValidCheck
=
false
>
template
<
bool
SkipIsValidCheck
=
false
,
typename
LowLength
=
index_t
,
typename
LeftPad
=
index_t
,
typename
RightPad
=
index_t
>
struct
DynamicPad
struct
DynamicPad
{
{
using
LowerIndex
=
MultiIndex
<
1
>
;
using
LowerIndex
=
MultiIndex
<
1
>
;
using
UpperIndex
=
MultiIndex
<
1
>
;
using
UpperIndex
=
MultiIndex
<
1
>
;
UpperIndex
up_lengths_
;
using
UpLengths
=
decltype
(
make_tuple
(
LowLength
{}
+
LeftPad
{}
+
RightPad
{}));
index_t
left_pad_
;
index_t
right_pad_
;
UpLengths
up_lengths_
;
LeftPad
left_pad_
;
RightPad
right_pad_
;
__host__
__device__
constexpr
DynamicPad
()
=
default
;
__host__
__device__
constexpr
DynamicPad
()
=
default
;
__host__
__device__
constexpr
DynamicPad
(
const
index_t
&
low_length
,
__host__
__device__
constexpr
DynamicPad
(
const
LowLength
&
low_length
,
const
index_t
&
left_pad
,
const
LeftPad
&
left_pad
,
const
index_t
&
right_pad
)
const
RightPad
&
right_pad
)
:
up_lengths_
{
make_
multi_index
(
low_length
+
left_pad
+
right_pad
)},
:
up_lengths_
{
make_
tuple
(
low_length
+
left_pad
+
right_pad
)},
left_pad_
{
left_pad
},
left_pad_
{
left_pad
},
right_pad_
{
right_pad
}
right_pad_
{
right_pad
}
{
{
...
@@ -158,27 +167,30 @@ struct DynamicPad
...
@@ -158,27 +167,30 @@ struct DynamicPad
{
{
printf
(
"{"
);
printf
(
"{"
);
printf
(
"DynamicPad, "
);
printf
(
"DynamicPad, "
);
printf
(
"up_lengths_"
);
print_multi_index
(
up_lengths_
);
print_multi_index
(
up_lengths_
);
printf
(
"left_pad_ %d"
,
left_pad_
);
printf
(
"left_pad_ %d"
,
index_t
{
left_pad_
});
printf
(
", "
);
printf
(
"right_pad_ %d"
,
index_t
{
right_pad_
});
printf
(
"right_pad_ %d"
,
right_pad_
);
printf
(
"}"
);
printf
(
"}"
);
}
}
};
};
template
<
bool
SkipIsValidCheck
=
false
>
template
<
bool
SkipIsValidCheck
=
false
,
typename
LowLength
=
index_t
,
typename
LeftPad
=
index_t
>
struct
DynamicLeftPad
struct
DynamicLeftPad
{
{
using
LowerIndex
=
MultiIndex
<
1
>
;
using
LowerIndex
=
MultiIndex
<
1
>
;
using
UpperIndex
=
MultiIndex
<
1
>
;
using
UpperIndex
=
MultiIndex
<
1
>
;
UpperIndex
up_lengths_
;
using
UpLengths
=
decltype
(
make_tuple
(
LowLength
{}
+
LeftPad
{}));
index_t
left_pad_
;
UpLengths
up_lengths_
;
LeftPad
left_pad_
;
__host__
__device__
constexpr
DynamicLeftPad
()
=
default
;
__host__
__device__
constexpr
DynamicLeftPad
()
=
default
;
__host__
__device__
constexpr
DynamicLeftPad
(
const
index_t
&
low_length
,
const
index_t
&
left_pad
)
__host__
__device__
constexpr
DynamicLeftPad
(
const
LowLength
&
low_length
,
:
up_lengths_
{
make_multi_index
(
low_length
+
left_pad
)},
left_pad_
{
left_pad
}
const
LeftPad
&
left_pad
)
:
up_lengths_
{
make_tuple
(
low_length
+
left_pad
)},
left_pad_
{
left_pad
}
{
{
}
}
...
@@ -238,27 +250,30 @@ struct DynamicLeftPad
...
@@ -238,27 +250,30 @@ struct DynamicLeftPad
{
{
printf
(
"{"
);
printf
(
"{"
);
printf
(
"DynamicLeftPad, "
);
printf
(
"DynamicLeftPad, "
);
printf
(
"up_lengths_"
);
print_multi_index
(
up_lengths_
);
print_multi_index
(
up_lengths_
);
printf
(
"left_pad_ %d"
,
left_pad_
);
printf
(
"left_pad_ %d"
,
index_t
{
left_pad_
}
);
printf
(
"}"
);
printf
(
"}"
);
}
}
};
};
template
<
bool
SkipIsValidCheck
=
false
>
template
<
bool
SkipIsValidCheck
=
false
,
typename
LowLength
=
index_t
,
typename
RightPad
=
index_t
>
struct
DynamicRightPad
struct
DynamicRightPad
{
{
using
LowerIndex
=
MultiIndex
<
1
>
;
using
LowerIndex
=
MultiIndex
<
1
>
;
using
UpperIndex
=
MultiIndex
<
1
>
;
using
UpperIndex
=
MultiIndex
<
1
>
;
UpperIndex
up_lengths_
;
using
UpLengths
=
decltype
(
make_tuple
(
LowLength
{}
+
RightPad
{}));
index_t
low_length_
;
index_t
right_pad_
;
UpLengths
up_lengths_
;
LowLength
low_length_
;
RightPad
right_pad_
;
__host__
__device__
constexpr
DynamicRightPad
()
=
default
;
__host__
__device__
constexpr
DynamicRightPad
()
=
default
;
__host__
__device__
constexpr
DynamicRightPad
(
const
index_t
&
low_length
,
__host__
__device__
constexpr
DynamicRightPad
(
const
LowLength
&
low_length
,
const
index_t
&
right_pad
)
const
RightPad
&
right_pad
)
:
up_lengths_
{
make_
multi_index
(
low_length
+
right_pad
)},
:
up_lengths_
{
make_
tuple
(
low_length
+
right_pad
)},
low_length_
{
low_length
},
low_length_
{
low_length
},
right_pad_
{
right_pad
}
right_pad_
{
right_pad
}
{
{
...
@@ -320,8 +335,10 @@ struct DynamicRightPad
...
@@ -320,8 +335,10 @@ struct DynamicRightPad
{
{
printf
(
"{"
);
printf
(
"{"
);
printf
(
"DynamicRightPad, "
);
printf
(
"DynamicRightPad, "
);
printf
(
"up_lengths_"
);
print_multi_index
(
up_lengths_
);
print_multi_index
(
up_lengths_
);
printf
(
"left_pad_ %d"
,
right_pad_
);
printf
(
"low_length_ %d"
,
index_t
{
low_length_
});
printf
(
"left_pad_ %d"
,
index_t
{
right_pad_
});
printf
(
"}"
);
printf
(
"}"
);
}
}
};
};
...
@@ -422,24 +439,29 @@ struct DynamicEmbed
...
@@ -422,24 +439,29 @@ struct DynamicEmbed
}
}
};
};
template
<
index_t
NDimLow
>
template
<
index_t
NDimLow
,
typename
LowLengths
=
MultiIndex
<
NDimLow
>
>
struct
DynamicMerge
struct
DynamicMerge
{
{
using
LowerIndex
=
MultiIndex
<
NDimLow
>
;
using
LowerIndex
=
MultiIndex
<
NDimLow
>
;
using
UpperIndex
=
MultiIndex
<
1
>
;
using
UpperIndex
=
MultiIndex
<
1
>
;
LowerIndex
low_lengths_
;
using
LowLengthsScan
=
decltype
(
LowerIndex
low_lengths_scan_
;
container_reverse_exclusive_scan
(
LowLengths
{},
math
::
multiplies_v2
{},
Number
<
1
>
{}));
UpperIndex
up_lengths_
;
using
UpLengths
=
decltype
(
make_tuple
(
container_reduce
(
LowLengths
{},
math
::
multiplies_v2
{},
Number
<
1
>
{})));
LowLengths
low_lengths_
;
LowLengthsScan
low_lengths_scan_
;
UpLengths
up_lengths_
;
__host__
__device__
constexpr
DynamicMerge
()
=
default
;
__host__
__device__
constexpr
DynamicMerge
()
=
default
;
__host__
__device__
constexpr
DynamicMerge
(
const
Low
erIndex
&
low_lengths
)
__host__
__device__
constexpr
DynamicMerge
(
const
Low
Lengths
&
low_lengths
)
:
low_lengths_
{
low_lengths
},
:
low_lengths_
{
low_lengths
},
low_lengths_scan_
{
container_reverse_exclusive_scan
(
low_lengths_scan_
{
low_lengths
,
math
::
multiplies
<
index_t
>
{},
index_t
{
1
})},
container_reverse_exclusive_scan
(
low_lengths
,
math
::
multiplies_v2
{},
Number
<
1
>
{})},
up_lengths_
{
make_multi_index
(
up_lengths_
{
make_tuple
(
container_reduce
(
low_lengths
,
math
::
multiplies_v2
{},
Number
<
1
>
{}))}
container_reduce
(
low_lengths
,
math
::
multiplies
<
index_t
>
(),
index_t
{
1
}))}
{
{
static_assert
(
LowerIndex
::
Size
()
==
NDimLow
,
"wrong!"
);
static_assert
(
LowerIndex
::
Size
()
==
NDimLow
,
"wrong!"
);
}
}
...
@@ -1017,31 +1039,27 @@ struct DynamicUnMerge
...
@@ -1017,31 +1039,27 @@ struct DynamicUnMerge
{
{
printf
(
"{"
);
printf
(
"{"
);
printf
(
"DynamicUnMerge, "
);
printf
(
"DynamicUnMerge, "
);
printf
(
"up_lengths_"
);
print_multi_index
(
up_lengths_
);
print_multi_index
(
up_lengths_
);
print_multi_index
(
up_lengths_scan_
);
print_multi_index
(
up_lengths_scan_
);
printf
(
"}"
);
printf
(
"}"
);
}
}
};
};
template
<
typename
LowerIndex
=
index_t
>
struct
DynamicFreeze
struct
DynamicFreeze
{
{
using
LowerIndex
=
MultiIndex
<
1
>
;
using
UpperIndex
=
MultiIndex
<
0
>
;
LowerIndex
low_idx_
;
LowerIndex
low_idx_
;
__host__
__device__
constexpr
DynamicFreeze
()
=
default
;
__host__
__device__
constexpr
DynamicFreeze
()
=
default
;
__host__
__device__
constexpr
DynamicFreeze
(
const
index_t
&
low_idx
)
__host__
__device__
constexpr
DynamicFreeze
(
const
LowerIndex
&
low_idx
)
:
low_idx_
{
low_idx
}
{}
:
low_idx_
{
make_multi_index
(
low_idx
)}
{
}
__host__
__device__
static
constexpr
index_t
GetNumOfLowerDimension
()
{
return
1
;
}
__host__
__device__
static
constexpr
index_t
GetNumOfLowerDimension
()
{
return
1
;
}
__host__
__device__
static
constexpr
index_t
GetNumOfUpperDimension
()
{
return
0
;
}
__host__
__device__
static
constexpr
index_t
GetNumOfUpperDimension
()
{
return
0
;
}
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
return
UpperIndex
{};
}
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
return
Tuple
<>
{};
}
template
<
typename
LowIdx
,
typename
UpIdx
>
template
<
typename
LowIdx
,
typename
UpIdx
>
__host__
__device__
constexpr
void
CalculateLowerIndex
(
LowIdx
&
idx_low
,
__host__
__device__
constexpr
void
CalculateLowerIndex
(
LowIdx
&
idx_low
,
...
@@ -1081,7 +1099,11 @@ struct DynamicFreeze
...
@@ -1081,7 +1099,11 @@ struct DynamicFreeze
return
true
;
return
true
;
}
}
__host__
__device__
void
Print
()
const
{
printf
(
"DynamicFreeze"
);
}
__host__
__device__
void
Print
()
const
{
printf
(
"DynamicFreeze"
);
printf
(
"low_idx_ %d"
,
index_t
{
low_idx_
});
}
};
};
}
// namespace ck
}
// namespace ck
...
...
composable_kernel/include/utility/config.amd.hpp.in
View file @
2558d019
...
@@ -118,6 +118,7 @@ enum InMemoryDataOperation
...
@@ -118,6 +118,7 @@ enum InMemoryDataOperation
AtomicAdd
AtomicAdd
};
};
// index type
using index_t = int32_t;
using index_t = int32_t;
typedef int32_t int32x2_t __attribute__((ext_vector_type(2)));
typedef int32_t int32x2_t __attribute__((ext_vector_type(2)));
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
2558d019
...
@@ -3,6 +3,19 @@
...
@@ -3,6 +3,19 @@
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
template
<
typename
T
>
__host__
__device__
constexpr
auto
sequence_to_tuple_of_number
(
const
T
&
x
)
{
using
namespace
ck
;
return
generate_tuple
(
[
&
](
auto
i
)
{
constexpr
index_t
tmp
=
T
::
At
(
i
);
return
Number
<
tmp
>
{};
},
T
::
Size
());
}
template
<
class
T
,
template
<
class
T
,
class
InDesc
,
class
InDesc
,
class
WeiDesc
,
class
WeiDesc
,
...
@@ -27,11 +40,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
...
@@ -27,11 +40,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
using
TDevice
=
typename
conditional
<
is_same
<
half_float
::
half
,
T
>::
value
,
half_t
,
T
>::
type
;
using
TDevice
=
typename
conditional
<
is_same
<
half_float
::
half
,
T
>::
value
,
half_t
,
T
>::
type
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
std
::
size_t
data_sz
=
sizeof
(
T
);
std
::
size_t
data_sz
=
sizeof
(
T
);
DeviceMem
in_nchw_device_buf
(
data_sz
*
in_nchw
.
mDesc
.
GetElementSpace
());
DeviceMem
in_nchw_device_buf
(
data_sz
*
in_nchw
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_kcyx_device_buf
(
data_sz
*
wei_kcyx
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_kcyx_device_buf
(
data_sz
*
wei_kcyx
.
mDesc
.
GetElementSpace
());
...
@@ -41,7 +49,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
...
@@ -41,7 +49,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
// assume packed tensor
#if 1
const
auto
in_n_c_hi_wi_desc
=
const
auto
in_n_c_hi_wi_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
to_multi_index
(
InDesc
::
GetLengths
()));
make_dynamic_naive_tensor_descriptor_packed_v2
(
to_multi_index
(
InDesc
::
GetLengths
()));
const
auto
wei_k_c_y_x_desc
=
const
auto
wei_k_c_y_x_desc
=
...
@@ -53,6 +61,19 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
...
@@ -53,6 +61,19 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
const
auto
conv_dilations
=
to_multi_index
(
ConvDilations
{});
const
auto
conv_dilations
=
to_multi_index
(
ConvDilations
{});
const
auto
in_left_pads
=
to_multi_index
(
InLeftPads
{});
const
auto
in_left_pads
=
to_multi_index
(
InLeftPads
{});
const
auto
in_right_pads
=
to_multi_index
(
InRightPads
{});
const
auto
in_right_pads
=
to_multi_index
(
InRightPads
{});
#else
const
auto
in_n_c_hi_wi_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
sequence_to_tuple_of_number
(
InDesc
::
GetLengths
()));
const
auto
wei_k_c_y_x_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
sequence_to_tuple_of_number
(
WeiDesc
::
GetLengths
()));
const
auto
out_n_k_ho_wo_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
sequence_to_tuple_of_number
(
OutDesc
::
GetLengths
()));
const
auto
conv_strides
=
sequence_to_tuple_of_number
(
ConvStrides
{});
const
auto
conv_dilations
=
sequence_to_tuple_of_number
(
ConvDilations
{});
const
auto
in_left_pads
=
sequence_to_tuple_of_number
(
InLeftPads
{});
const
auto
in_right_pads
=
sequence_to_tuple_of_number
(
InRightPads
{});
#endif
#if 0
#if 0
// cdata = 64, BlockSize = 256, 128x128x2
// cdata = 64, BlockSize = 256, 128x128x2
...
@@ -210,28 +231,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
...
@@ -210,28 +231,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
4
;
#endif
#endif
const
index_t
N
=
out_n_k_ho_wo_desc
.
GetLength
(
I0
);
const
index_t
K
=
out_n_k_ho_wo_desc
.
GetLength
(
I1
);
const
index_t
Ho
=
out_n_k_ho_wo_desc
.
GetLength
(
I2
);
const
index_t
Wo
=
out_n_k_ho_wo_desc
.
GetLength
(
I3
);
const
index_t
C
=
wei_k_c_y_x_desc
.
GetLength
(
I1
);
const
index_t
Y
=
wei_k_c_y_x_desc
.
GetLength
(
I2
);
const
index_t
X
=
wei_k_c_y_x_desc
.
GetLength
(
I3
);
const
index_t
GemmM
=
K
;
const
index_t
GemmN
=
N
*
Ho
*
Wo
;
const
index_t
GemmK
=
C
*
Y
*
X
;
if
(
!
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
))
{
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
const
index_t
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
constexpr
auto
conv_driver
=
constexpr
auto
conv_driver
=
#if 1
#if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment