gaoqiong / composable_kernel

Commit 4774d863
authored Jun 01, 2021 by Chao Liu

refactor

parent 5dd45128
Showing 5 changed files with 186 additions and 236 deletions
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp (+33 −136)
composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp (+2 −1)
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp (+126 −98)
composable_kernel/include/utility/config.amd.hpp.in (+1 −1)
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp (+24 −0)
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp
...
...
@@ -98,9 +98,26 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
    using CM0M1N0N1GridDesc = decltype(c_m0_m1_n0_n1_grid_desc);

#if 0
    const auto c_m0_m10_m
#endif

    constexpr auto M1 = Number<MPerBlock>{};
    constexpr auto N1 = Number<NPerBlock>{};

    const auto M0 = M / M1;
    const auto N0 = N / N1;

    constexpr auto M11 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
    constexpr auto N11 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};

    constexpr auto M10 = M1 / M11;
    constexpr auto N10 = N1 / N11;

    const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
        c_m_n_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
                   make_unmerge_transform(make_tuple(N0, N10, N11))),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);

    // out_gemm_block_cluster_desc
    const auto c_block_cluster_desc =
...
...
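Note on the hunk above: the new M0/M10/M11 and N0/N10/N11 descriptors split each GEMM dimension into block tile, thread-cluster tile, and per-thread pieces, and make_unmerge_transform is just a mixed-radix split of the flat index. A minimal standalone sketch of that arithmetic (the tile sizes are hypothetical example values; this does not use CK's descriptor machinery):

// Standalone sketch (not CK code): how unmerging M into (M0, M10, M11) maps a
// flat row index onto block-tile / thread-cluster / in-thread coordinates.
#include <cassert>
#include <cstdio>

int main()
{
    constexpr int MPerBlock            = 128; // M1: rows per block tile (hypothetical)
    constexpr int M1PerThread          = 4;   // hypothetical
    constexpr int M1N1ThreadClusterM10 = 8;   // hypothetical
    constexpr int M1N1ThreadClusterM11 = 2;   // hypothetical

    constexpr int M1  = MPerBlock;
    constexpr int M11 = M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10; // 64
    constexpr int M10 = M1 / M11;                                                  // 2

    const int M  = 1024;
    const int M0 = M / M1; // number of block tiles along M

    // Flat index m -> (m0, m10, m11), in the same order the unmerge transform uses.
    const int m   = 777;
    const int m0  = m / M1;
    const int m10 = (m % M1) / M11;
    const int m11 = (m % M1) % M11;

    assert((m0 * M10 + m10) * M11 + m11 == m); // unmerge is exactly this mixed-radix split
    std::printf("M0=%d M10=%d M11=%d  m=%d -> (%d,%d,%d)\n", M0, M10, M11, m, m0, m10, m11);
}

Applying the same split to N gives the six-dimensional (m0, m10, m11, n0, n10, n11) view used by the new grid descriptor.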
@@ -119,6 +136,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                            BKNGridDesc,
                                            CM0M1N0N1GridDesc,
                                            CBlockClusterDesc,
                                            CM0M10M11N0N10N11GridDesc,
                                            MPerBlock,
                                            NPerBlock,
                                            KPerBlock,
...
...
@@ -160,7 +178,6 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
    const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
...
...
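The branches that follow pick one of four instantiations of kernel_dynamic_gemm_v1r2 from the two runtime flags, so main-loop and tail handling are resolved at compile time inside the kernel. A minimal sketch of that dispatch pattern, with a hypothetical run_gemm stand-in for the kernel launch:

// Standalone sketch (not CK code): dispatching two runtime flags onto
// compile-time bool template parameters, as the branches below do for
// kernel_dynamic_gemm_v1r2<..., HasMainKBlockLoop, HasDoubleTailKBlockLoop>.
#include <cstdio>

template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
void run_gemm() // hypothetical stand-in for launching one kernel specialization
{
    std::printf("specialization: main=%d, double_tail=%d\n",
                int(HasMainKBlockLoop), int(HasDoubleTailKBlockLoop));
}

// Map the two runtime flags onto the two trailing bool template parameters.
void dispatch(bool has_main_k_block_loop, bool has_double_tail_k_block_loop)
{
    if(has_main_k_block_loop && has_double_tail_k_block_loop)
        run_gemm<true, true>();
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
        run_gemm<true, false>();
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
        run_gemm<false, true>();
    else
        run_gemm<false, false>();
}

int main()
{
    const int K = 256, KPerBlock = 8;
    // has_double_tail_k_block_loop uses the same predicate as the diff;
    // the value of has_main_k_block_loop here is just an example.
    dispatch(/*has_main_k_block_loop*/ true, (K / KPerBlock) % 2 == 0);
}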
@@ -173,6 +190,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                     remove_reference_t<BKNGridDesc>,
                                     remove_reference_t<CM0M1N0N1GridDesc>,
                                     remove_reference_t<CBlockClusterDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     true,
                                     true>;
...
...
@@ -188,7 +206,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
            a_k_m_grid_desc,
            b_k_n_grid_desc,
            c_m0_m1_n0_n1_grid_desc,
-           c_block_cluster_desc);
+           c_block_cluster_desc,
+           c_m0_m10_m11_n0_n10_n11_grid_desc);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
...
...
@@ -200,6 +219,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                     remove_reference_t<BKNGridDesc>,
                                     remove_reference_t<CM0M1N0N1GridDesc>,
                                     remove_reference_t<CBlockClusterDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     true,
                                     false>;
...
...
@@ -215,7 +235,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
            a_k_m_grid_desc,
            b_k_n_grid_desc,
            c_m0_m1_n0_n1_grid_desc,
-           c_block_cluster_desc);
+           c_block_cluster_desc,
+           c_m0_m10_m11_n0_n10_n11_grid_desc);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
...
...
@@ -227,6 +248,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                     remove_reference_t<BKNGridDesc>,
                                     remove_reference_t<CM0M1N0N1GridDesc>,
                                     remove_reference_t<CBlockClusterDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     false,
                                     true>;
...
...
@@ -242,7 +264,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
            a_k_m_grid_desc,
            b_k_n_grid_desc,
            c_m0_m1_n0_n1_grid_desc,
-           c_block_cluster_desc);
+           c_block_cluster_desc,
+           c_m0_m10_m11_n0_n10_n11_grid_desc);
    }
    else
    {
...
...
@@ -254,6 +277,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                     remove_reference_t<BKNGridDesc>,
                                     remove_reference_t<CM0M1N0N1GridDesc>,
                                     remove_reference_t<CBlockClusterDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     false,
                                     false>;
...
...
@@ -269,138 +293,11 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
            a_k_m_grid_desc,
            b_k_n_grid_desc,
            c_m0_m1_n0_n1_grid_desc,
-           c_block_cluster_desc);
+           c_block_cluster_desc,
+           c_m0_m10_m11_n0_n10_n11_grid_desc);
    }

    return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_k_m_grid_desc_device_buf(sizeof(AKMGridDesc));
    DeviceMem b_k_n_grid_desc_device_buf(sizeof(BKNGridDesc));
    DeviceMem c_m0_m1_n0_n1_grid_desc_device_buf(sizeof(CM0M1N0N1GridDesc));
    DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));

    a_k_m_grid_desc_device_buf.ToDevice(&a_k_m_grid_desc);
    b_k_n_grid_desc_device_buf.ToDevice(&b_k_n_grid_desc);
    c_m0_m1_n0_n1_grid_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_grid_desc);
    c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
                                                     FloatAB,
                                                     FloatAB,
                                                     FloatC,
                                                     remove_reference_t<AKMGridDesc>,
                                                     remove_reference_t<BKNGridDesc>,
                                                     remove_reference_t<CM0M1N0N1GridDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     true,
                                                     true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            (void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
                                                     FloatAB,
                                                     FloatAB,
                                                     FloatC,
                                                     remove_reference_t<AKMGridDesc>,
                                                     remove_reference_t<BKNGridDesc>,
                                                     remove_reference_t<CM0M1N0N1GridDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     true,
                                                     false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            (void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
                                                     FloatAB,
                                                     FloatAB,
                                                     FloatC,
                                                     remove_reference_t<AKMGridDesc>,
                                                     remove_reference_t<BKNGridDesc>,
                                                     remove_reference_t<CM0M1N0N1GridDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     false,
                                                     true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            (void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else
    {
        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
                                                     FloatAB,
                                                     FloatAB,
                                                     FloatC,
                                                     remove_reference_t<AKMGridDesc>,
                                                     remove_reference_t<BKNGridDesc>,
                                                     remove_reference_t<CM0M1N0N1GridDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     false,
                                                     false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            (void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    return ave_time;
#endif
}

} // namespace ck
...
...
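The CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER path shown in the last hunk (largely removed by this commit) copies each descriptor into a device buffer and hands the kernel an opaque pointer instead of a by-value argument. A minimal HIP sketch of that general pattern, using a hypothetical trivially copyable Desc struct rather than CK's DeviceMem and descriptor types:

// Standalone HIP sketch (not CK code): passing a descriptor to a kernel as an
// opaque device pointer instead of by value.
#include <hip/hip_runtime.h>
#include <cstdio>

struct Desc // must be trivially copyable so a raw copy to the GPU is valid
{
    int lengths[2];
    int strides[2];
};

__global__ void kernel(const void* p_desc_void)
{
    const Desc& desc = *static_cast<const Desc*>(p_desc_void);
    if(threadIdx.x == 0 && blockIdx.x == 0)
        printf("lengths = %d x %d\n", desc.lengths[0], desc.lengths[1]);
}

int main()
{
    Desc h_desc{{1024, 512}, {512, 1}};

    void* d_desc = nullptr;
    hipMalloc(&d_desc, sizeof(Desc));                                 // analogous to DeviceMem(sizeof(...))
    hipMemcpy(d_desc, &h_desc, sizeof(Desc), hipMemcpyHostToDevice);  // analogous to ToDevice(...)

    // The pointer plays the role of (void __CONSTANT__*)buf.GetDeviceBuffer().
    hipLaunchKernelGGL(kernel, dim3(1), dim3(64), 0, 0, d_desc);
    hipDeviceSynchronize();
    hipFree(d_desc);
}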
composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
...
...
@@ -93,7 +93,8 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
// output tensor
    const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
-       make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
+       make_tuple(make_pass_through_transform(K),
+                  make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));
...
...
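For reference, the descriptor built above views the NKHW output tensor as a GEMM matrix: GemmM is the pass-through K dimension, and GemmN merges (N, Ho*Wo). A standalone sketch of that index map, with hypothetical sizes:

// Standalone sketch (not CK code): the index map encoded by
// out_gemmm_gemmn_global_desc.
//   GemmM = k                              (pass-through K)
//   GemmN = n * (Ho * Wo) + ho * Wo + wo   (merge of N and Ho*Wo)
#include <cassert>

int main()
{
    const int N = 2, K = 16, Ho = 28, Wo = 28; // hypothetical example sizes

    for(int n = 0; n < N; ++n)
        for(int k = 0; k < K; ++k)
            for(int ho = 0; ho < Ho; ++ho)
                for(int wo = 0; wo < Wo; ++wo)
                {
                    // packed (N, K, Ho*Wo) offset of out[n][k][ho][wo]
                    const int packed_offset = (n * K + k) * (Ho * Wo) + ho * Wo + wo;

                    // GEMM-view coordinates
                    const int gemmm = k;
                    const int gemmn = n * (Ho * Wo) + ho * Wo + wo;

                    // The merge/pass-through transforms recover the same element.
                    const int n_back  = gemmn / (Ho * Wo);
                    const int hw_back = gemmn % (Ho * Wo);
                    assert((n_back * K + gemmm) * (Ho * Wo) + hw_back == packed_offset);
                }
}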
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
This diff is collapsed.
composable_kernel/include/utility/config.amd.hpp.in
...
...
@@ -28,7 +28,7 @@
#endif
// launch bounds
-#define CK_USE_LAUNCH_BOUNDS 1
+#define CK_USE_LAUNCH_BOUNDS 0
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
...
...
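The config change above only flips the default value of CK_USE_LAUNCH_BOUNDS from 1 to 0. This diff does not show how the macro is consumed, but a toggle like this is typically used to guard a __launch_bounds__ attribute on kernels; an illustrative HIP sketch (KERNEL_BOUNDS and dummy_kernel are hypothetical, not CK's actual wiring):

// Illustrative HIP sketch (not CK code): a launch-bounds toggle usually guards
// a __launch_bounds__ attribute, trading a register-usage cap for scheduling
// freedom when the hint is disabled.
#include <hip/hip_runtime.h>

#define CK_USE_LAUNCH_BOUNDS 0
#define CK_MAX_THREAD_PER_BLOCK 256

#if CK_USE_LAUNCH_BOUNDS
#define KERNEL_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK)
#else
#define KERNEL_BOUNDS // no occupancy hint: the compiler may use more registers per thread
#endif

__global__ void KERNEL_BOUNDS dummy_kernel(float* p, float v)
{
    p[blockIdx.x * blockDim.x + threadIdx.x] = v;
}

int main()
{
    float* d_p = nullptr;
    hipMalloc(&d_p, CK_MAX_THREAD_PER_BLOCK * sizeof(float));
    hipLaunchKernelGGL(dummy_kernel, dim3(1), dim3(CK_MAX_THREAD_PER_BLOCK), 0, 0, d_p, 1.0f);
    hipDeviceSynchronize();
    hipFree(d_p);
}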
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
...
...
@@ -499,6 +499,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
    constexpr auto in_gemmk_gemmn_grid_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
#if 0
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_grid
constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
...
...
@@ -509,6 +510,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
#else
    constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 1, 0, 0>{},
                              Sequence<0, 0, 1, 0, 0>{},
                              Sequence<0, 0, 1, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{}));
#endif
    for(index_t i = 0; i < 5; ++i)
{
...
...
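The iterator-hack parameters above are plain compile-time lists of integers. A minimal sketch of such a Sequence-style type, analogous in spirit to the one used here but not CK's actual implementation:

// Minimal sketch (not CK's Sequence): a compile-time integer list like the
// ones used for the iterator-hack parameters above.
#include <cstddef>
#include <cstdio>

template <int... Is>
struct Sequence
{
    static constexpr std::size_t Size() { return sizeof...(Is); }

    // Read element i (usable at compile time or run time).
    static constexpr int At(std::size_t i)
    {
        constexpr int data[] = {Is...};
        return data[i];
    }
};

int main()
{
    // Same values as the move-slice-window hacks above.
    using Hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>;

    std::printf("size = %zu, last = %d\n", Hacks::Size(), Hacks::At(Hacks::Size() - 1));
}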
@@ -553,7 +569,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
        GemmCThreadTransferDstScalarPerVector_GemmN1,
        decltype(wei_gemmk_gemmm_grid_iterator_hacks),
        decltype(in_gemmk_gemmn_grid_iterator_hacks),
#if 0
        decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks),
#else
        decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
#endif
        decltype(wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks),
        decltype(in_gemmk_gemmn_grid_move_slice_window_iterator_hacks)>(
        static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
...
...
@@ -566,7 +586,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
            out_gemmm_gemmn_grid_desc,
            wei_gemmk_gemmm_grid_iterator_hacks,
            in_gemmk_gemmn_grid_iterator_hacks,
#if 0
            out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks,
#else
            out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
#endif
            wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks,
            in_gemmk_gemmn_grid_move_slice_window_iterator_hacks,
            nrepeat);
...
...