Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
318db82b
Commit
318db82b
authored
May 31, 2021
by
Chao Liu
Browse files
overhauling fwd-v4r4
parent
d99e020d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
69 deletions
+62
-69
composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
...m_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
+3
-29
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
...nvolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
+59
-40
No files found.
composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
View file @
318db82b
...
@@ -10,11 +10,7 @@ namespace ck {
...
@@ -10,11 +10,7 @@ namespace ck {
// GemmM = K
// GemmM = K
// GemmN = N * Ho * Wo
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
// GemmK = C * Y * X
template
<
index_t
GemmMPerBlock
,
template
<
typename
...
Wei
,
index_t
GemmNPerBlock
,
index_t
GemmM1
,
index_t
GemmN1
,
typename
...
Wei
,
typename
...
In
,
typename
...
In
,
typename
...
Out
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvStrides
,
...
@@ -101,30 +97,8 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
...
@@ -101,30 +97,8 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
GemmM
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I0
);
return
make_tuple
(
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
wei_gemmk_gemmm_global_desc
,
in_gemmk_gemmn_global_desc
,
out_gemmm_gemmn_global_desc
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
);
const
auto
GemmM0
=
GemmM
/
Number
<
GemmM1
>
{};
const
auto
GemmN0
=
GemmN
/
Number
<
GemmN1
>
{};
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
transform_dynamic_tensor_descriptor
(
out_gemmm_gemmn_global_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM0
,
GemmM1
)),
make_unmerge_transform
(
make_tuple
(
GemmN0
,
GemmN1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
// out_gemm_block_cluster_desc
const
auto
out_gemm_block_cluster_desc
=
make_cluster_descriptor_v2
(
make_tuple
(
GemmM
/
Number
<
GemmMPerBlock
>
{},
GemmN
/
Number
<
GemmNPerBlock
>
{}));
return
make_tuple
(
wei_gemmk_gemmm_global_desc
,
in_gemmk_gemmn_global_desc
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
out_gemm_block_cluster_desc
);
}
}
}
// namespace ck
}
// namespace ck
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
View file @
318db82b
...
@@ -469,42 +469,61 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
...
@@ -469,42 +469,61 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
4
;
#endif
#endif
const
auto
descs
=
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad
(
wei_k_c_y_x_desc
,
in_n_c_hi_wi_desc
,
out_n_k_ho_wo_desc
,
conv_strides
,
conv_dilations
,
in_left_pads
,
in_right_pads
);
const
auto
wei_gemmk_gemmm_grid_desc
=
descs
[
I0
];
const
auto
in_gemmk_gemmn_grid_desc
=
descs
[
I1
];
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
GemmM
=
out_gemmm_gemmn_grid_desc
.
GetLength
(
I0
);
const
auto
GemmN
=
out_gemmm_gemmn_grid_desc
.
GetLength
(
I1
);
const
auto
GemmK
=
wei_gemmk_gemmm_grid_desc
.
GetLength
(
I0
);
constexpr
index_t
GemmM1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
GemmM1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
GemmN1
=
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
constexpr
index_t
GemmN1
=
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
const
auto
descs
=
assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
);
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad
<
GemmMPerBlock
,
GemmNPerBlock
,
const
auto
GemmM0
=
GemmM
/
Number
<
GemmM1
>
{};
GemmM1
,
const
auto
GemmN0
=
GemmN
/
Number
<
GemmN1
>
{};
GemmN1
>
(
wei_k_c_y_x_desc
,
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc
=
transform_dynamic_tensor_descriptor
(
in_n_c_hi_wi_desc
,
out_gemmm_gemmn_grid_desc
,
out_n_k_ho_wo_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM0
,
GemmM1
)),
conv_strides
,
make_unmerge_transform
(
make_tuple
(
GemmN0
,
GemmN1
))),
conv_dilations
,
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
in_left_pads
,
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
in_right_pads
);
// out_gemm_block_cluster_desc
// hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
const
auto
out_gemm_block_cluster_desc
=
make_cluster_descriptor_v2
(
constexpr
auto
wei_gemmk_gemmm_global_iterator_hacks
=
make_tuple
(
GemmM
/
Number
<
GemmMPerBlock
>
{},
GemmN
/
Number
<
GemmNPerBlock
>
{}));
// hack to control index calculation when iterating over wei_gemmk_gemmm_grid tensor
constexpr
auto
wei_gemmk_gemmm_grid_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
constexpr
auto
wei_gemmk_gemmm_g
lobal
_move_slice_window_iterator_hacks
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
wei_gemmk_gemmm_g
rid
_move_slice_window_iterator_hacks
=
Sequence
<
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
// hack to control index calculation when iterating over in_gemmk_gemmn_grid tensor
constexpr
auto
in_gemmk_gemmn_g
lobal
_iterator_hacks
=
constexpr
auto
in_gemmk_gemmn_g
rid
_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
>
{}));
constexpr
auto
in_gemmk_gemmn_g
lobal
_move_slice_window_iterator_hacks
=
constexpr
auto
in_gemmk_gemmn_g
rid
_move_slice_window_iterator_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
>
{};
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_grid
// tensor hack for NKHW format
constexpr
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks
=
constexpr
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
...
@@ -522,10 +541,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
...
@@ -522,10 +541,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
TAcc
,
TAcc
,
TOut
,
TOut
,
InMemoryDataOperation
::
Set
,
InMemoryDataOperation
::
Set
,
decltype
(
descs
[
I0
]
),
decltype
(
wei_gemmk_gemmm_grid_desc
),
decltype
(
descs
[
I1
]
),
decltype
(
in_gemmk_gemmn_grid_desc
),
decltype
(
descs
[
I2
]
),
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc
),
decltype
(
descs
[
I3
]
),
decltype
(
out_gemm_block_cluster_desc
),
GemmMPerBlock
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmKPerBlock
,
...
@@ -556,25 +575,25 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
...
@@ -556,25 +575,25 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
Sequence
<
2
,
3
,
0
,
1
>
,
Sequence
<
2
,
3
,
0
,
1
>
,
3
,
3
,
GemmCThreadTransferDstScalarPerVector_GemmN1
,
GemmCThreadTransferDstScalarPerVector_GemmN1
,
decltype
(
wei_gemmk_gemmm_g
lobal
_iterator_hacks
),
decltype
(
wei_gemmk_gemmm_g
rid
_iterator_hacks
),
decltype
(
in_gemmk_gemmn_g
lobal
_iterator_hacks
),
decltype
(
in_gemmk_gemmn_g
rid
_iterator_hacks
),
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_g
lobal
_iterator_hacks
),
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_g
rid
_iterator_hacks
),
decltype
(
wei_gemmk_gemmm_g
lobal
_move_slice_window_iterator_hacks
),
decltype
(
wei_gemmk_gemmm_g
rid
_move_slice_window_iterator_hacks
),
decltype
(
in_gemmk_gemmn_g
lobal
_move_slice_window_iterator_hacks
)
>
(
decltype
(
in_gemmk_gemmn_g
rid
_move_slice_window_iterator_hacks
)
>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
descs
[
I0
]
,
wei_gemmk_gemmm_grid_desc
,
descs
[
I1
]
,
in_gemmk_gemmn_grid_desc
,
descs
[
I2
]
,
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc
,
descs
[
I3
]
,
out_gemm_block_cluster_desc
,
wei_gemmk_gemmm_g
lobal
_iterator_hacks
,
wei_gemmk_gemmm_g
rid
_iterator_hacks
,
in_gemmk_gemmn_g
lobal
_iterator_hacks
,
in_gemmk_gemmn_g
rid
_iterator_hacks
,
out_gemmm0_gemmm1_gemmn0_gemmn1_g
lobal
_iterator_hacks
,
out_gemmm0_gemmm1_gemmn0_gemmn1_g
rid
_iterator_hacks
,
wei_gemmk_gemmm_g
lobal
_move_slice_window_iterator_hacks
,
wei_gemmk_gemmm_g
rid
_move_slice_window_iterator_hacks
,
in_gemmk_gemmn_g
lobal
_move_slice_window_iterator_hacks
,
in_gemmk_gemmn_g
rid
_move_slice_window_iterator_hacks
,
nrepeat
);
nrepeat
);
float
perf
=
(
float
)
calculate_convolution_flops
(
float
perf
=
(
float
)
calculate_convolution_flops
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment