Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3c4fb1dd
Commit
3c4fb1dd
authored
Nov 23, 2023
by
Umang Yadav
Browse files
Merge remote-tracking branch 'origin/develop' into migx_merge
parents
57cdd70b
e8cddfdc
Changes
385
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1283 additions
and
533 deletions
+1283
-533
include/ck/host_utility/kernel_launch.hpp
include/ck/host_utility/kernel_launch.hpp
+13
-2
include/ck/stream_config.hpp
include/ck/stream_config.hpp
+2
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp
.../ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp
+0
-370
include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp
include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp
+348
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+94
-41
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+71
-72
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
+214
-0
include/ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp
.../tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp
+33
-0
include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp
..._operation/gpu/device/device_contraction_multiple_abd.hpp
+61
-0
include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp
...or_operation/gpu/device/device_contraction_multiple_d.hpp
+2
-1
include/ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp
...sor_operation/gpu/device/device_conv_tensor_rearrange.hpp
+78
-0
include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp
.../tensor_operation/gpu/device/device_elementwise_scale.hpp
+55
-0
include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp
.../tensor_operation/gpu/device/device_gemm_multiple_abd.hpp
+60
-0
include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp
...ude/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp
+6
-3
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp
...on/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp
+3
-1
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
...r_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
+3
-1
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
...ation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
+132
-0
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp
...eration/gpu/device/device_grouped_conv_fwd_multiple_d.hpp
+42
-40
include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp
...sor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp
+63
-0
include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp
...de/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp
+3
-2
No files found.
Too many changes to show.
To preserve performance only
385 of 385+
files are displayed.
Plain diff
Email patch
include/ck/host_utility/kernel_launch.hpp
View file @
3c4fb1dd
...
...
@@ -33,9 +33,13 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
printf
(
"Warm up 1 time
\n
"
);
#endif
// warm up
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
for
(
int
i
=
0
;
i
<
stream_config
.
cold_niters_
;
++
i
)
{
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
hip_check_error
(
hipGetLastError
());
}
const
int
nrepeat
=
10
;
const
int
nrepeat
=
stream_config
.
nrepeat_
;
#if DEBUG_LOG
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
#endif
...
...
@@ -50,6 +54,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
{
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
hip_check_error
(
hipGetLastError
());
}
hip_check_error
(
hipEventRecord
(
stop
,
stream_config
.
stream_id_
));
...
...
@@ -64,11 +69,13 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
else
{
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
hip_check_error
(
hipGetLastError
());
return
0
;
}
#else
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
hip_check_error
(
hipGetLastError
());
return
0
;
#endif
...
...
@@ -101,6 +108,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
// warm up
preprocess
();
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
hip_check_error
(
hipGetLastError
());
const
int
nrepeat
=
10
;
#if DEBUG_LOG
...
...
@@ -118,6 +126,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
{
preprocess
();
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
hip_check_error
(
hipGetLastError
());
}
hip_check_error
(
hipEventRecord
(
stop
,
stream_config
.
stream_id_
));
...
...
@@ -133,11 +142,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
{
preprocess
();
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
hip_check_error
(
hipGetLastError
());
return
0
;
}
#else
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
hip_check_error
(
hipGetLastError
());
return
0
;
#endif
...
...
include/ck/stream_config.hpp
View file @
3c4fb1dd
...
...
@@ -11,4 +11,6 @@ struct StreamConfig
hipStream_t
stream_id_
=
nullptr
;
bool
time_kernel_
=
false
;
int
log_level_
=
0
;
int
cold_niters_
=
50
;
int
nrepeat_
=
200
;
};
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp
deleted
100644 → 0
View file @
57cdd70b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp"
namespace
ck
{
/**
* DPP8 version of blockwise GEMM algorithm. It uses DPP8 instruction modifier to limit
* the data loaded from LDS to registers.
*
* The algorithm groups threads into groups of size `dpp8::lane_group_size` and splits the matrix C
* between them in such a way that threads from the same group need the same chunk of either
* matrix A (or B, respectively). Without the usage of DPP8, each thread would need to load the
* whole chunk from LDS to its own register space.
* Usage of DPP8 modifiers allow each thread to load less data, exactly `1 / dpp8::lane_group_size`
* of the chunk, and then share that data with other threads from the same lane group.
*
* Assumptions coming from the usage of DPP8:
* 1. `BM10BN10ThreadClusterBM10Xs[1] == dpp8::lane_group_size` or
* `BM10BN10ThreadClusterBN10Xs[1] == dpp8::lane_group_size` -
* - it makes consecutive `dpp8::lane_group_size` threads use the same chunk of either
* matrix A or B;
* - based on these values we determine which matrix to share.
* 2. `BM1PerThreadBM11 % dpp8::lane_group_size == 0` (if sharing A) or
* `BN1PerThreadBN11 % dpp8::lane_group_size == 0` (if sharing B) -
* - we have to make sure that the data to split is divisible by the number of
* threads in the group.
*
* General algorithm:
* C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
* A and B are visible to the whole block, C is distributed among each thread
* Assume:
* 1. A:
* 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
* 2. ABlockBuffer is DynamicBuffer
* 2. B:
* 1. BBlockDesc_BK0_BN_BK1 is known at compile-time
* 2. BBlockBuffer is DynamicBuffer
* 3. C:
* 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
* 2. CThreadBuffer is StaticBuffer
* 4. BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2
*/
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
ABlockDesc_BK0_BM_BK1
,
typename
BBlockDesc_BK0_BN_BK1
,
index_t
BM1PerThreadBM11
,
index_t
BN1PerThreadBN11
,
index_t
BK0PerThread
,
typename
BM10BN10ThreadClusterBM10Xs
,
// Sequence<BM10BN10ThreadClusterBM100,
// BM10BN10ThreadClusterBM101, ...>
typename
BM10BN10ThreadClusterBN10Xs
,
// Sequence<BM10BN10ThreadClusterBN100,
// BM10BN10ThreadClusterBN101, ...>
index_t
AThreadCopyScalarPerVector_BM11
,
index_t
BThreadCopyScalarPerVector_BN11
,
typename
enable_if
<
ABlockDesc_BK0_BM_BK1
::
IsKnownAtCompileTime
()
&&
BBlockDesc_BK0_BN_BK1
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
struct
BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0
{
using
AIndex
=
MultiIndex
<
4
>
;
using
BIndex
=
MultiIndex
<
4
>
;
using
CIndex
=
MultiIndex
<
4
>
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
index_t
BK0
=
ABlockDesc_BK0_BM_BK1
{}.
GetLength
(
I0
);
static
constexpr
index_t
BK1
=
ABlockDesc_BK0_BM_BK1
{}.
GetLength
(
I2
);
static
constexpr
index_t
BM
=
ABlockDesc_BK0_BM_BK1
{}.
GetLength
(
I1
);
static
constexpr
index_t
BN
=
BBlockDesc_BK0_BN_BK1
{}.
GetLength
(
I1
);
static
constexpr
index_t
BM100
=
BM10BN10ThreadClusterBM10Xs
{}[
I0
];
static
constexpr
index_t
BN100
=
BM10BN10ThreadClusterBN10Xs
{}[
I0
];
static
constexpr
index_t
BM101
=
BM10BN10ThreadClusterBM10Xs
{}[
I1
];
static
constexpr
index_t
BN101
=
BM10BN10ThreadClusterBN10Xs
{}[
I1
];
static
constexpr
index_t
BM11
=
BM1PerThreadBM11
;
static
constexpr
index_t
BN11
=
BN1PerThreadBN11
;
static
constexpr
index_t
BM1
=
BM100
*
BM101
*
BM11
;
static
constexpr
index_t
BN1
=
BN100
*
BN101
*
BN11
;
static
constexpr
index_t
BM0
=
BM
/
BM1
;
static
constexpr
index_t
BN0
=
BN
/
BN1
;
// We assume that either `BM101` or `BN101` is equal to `dpp8::lane_group_size`. It makes all
// threads in a lane group need the same chunk of B or A matrices and we can share them using
// DPP.
static_assert
(
BM101
==
dpp8
::
lane_group_size
||
BN101
==
dpp8
::
lane_group_size
);
static
constexpr
bool
ShareB
=
BM101
==
dpp8
::
lane_group_size
?
true
:
false
;
static
constexpr
bool
ShareA
=
!
ShareB
;
// If DPP shares A (B, respectively), lane group gets `BM1PerThreadBM11` (`BN1PerThreadBN11`,
// respectively) elements, so we split them between threads in lane group so each thread loads
// less data from LDS.
static
constexpr
index_t
BM1PerThread
=
ShareA
?
BM1PerThreadBM11
/
dpp8
::
lane_group_size
:
BM1PerThreadBM11
;
static
constexpr
index_t
BN1PerThread
=
ShareB
?
BN1PerThreadBN11
/
dpp8
::
lane_group_size
:
BN1PerThreadBN11
;
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_BK0_BM0_BM1_BK1
(
const
ABlockDesc_BK0_BM_BK1
&
a_block_desc_bk0_bm_bk1
)
{
const
auto
a_block_bk0_bm0_bm1_bk1
=
transform_tensor_descriptor
(
a_block_desc_bk0_bm_bk1
,
make_tuple
(
make_pass_through_transform
(
Number
<
BK0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BM0
>
{},
Number
<
BM1
>
{})),
make_pass_through_transform
(
Number
<
BK1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
return
a_block_bk0_bm0_bm1_bk1
;
}
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_BK0_BN0_BN1_BK1
(
const
BBlockDesc_BK0_BN_BK1
&
b_block_desc_bk0_bn_bk1
)
{
const
auto
b_block_desc_bk0_bn0_bn1_bk1
=
transform_tensor_descriptor
(
b_block_desc_bk0_bn_bk1
,
make_tuple
(
make_pass_through_transform
(
Number
<
BK0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BN0
>
{},
Number
<
BN1
>
{})),
make_pass_through_transform
(
Number
<
BK1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
return
b_block_desc_bk0_bn0_bn1_bk1
;
}
__host__
__device__
static
constexpr
auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN
()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM, BN]
constexpr
auto
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
BM0
>
{},
Number
<
BM100
>
{},
Number
<
BM101
>
{},
Number
<
BM11
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
BN0
>
{},
Number
<
BN100
>
{},
Number
<
BN101
>
{},
Number
<
BN11
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
,
5
,
6
,
7
>
{}));
return
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n
;
}
__host__
__device__
static
constexpr
auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1
()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM0, BM1, BN0, BN1]
constexpr
auto
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_pass_through_transform
(
Number
<
BM0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BM100
>
{},
Number
<
BM101
>
{},
Number
<
BM11
>
{})),
make_pass_through_transform
(
Number
<
BN0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BN100
>
{},
Number
<
BN101
>
{},
Number
<
BN11
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
,
7
>
{}));
return
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1
;
}
__host__
__device__
static
constexpr
auto
GetCThreadTensorLengths_BM0_BM1_BN0_BN1
()
{
return
Sequence
<
BM0
,
BM11
,
BN0
,
BN11
>
{};
}
static
constexpr
auto
a_block_desc_bk0_bm0_bm1_bk1_
=
MakeABlockDescriptor_BK0_BM0_BM1_BK1
(
ABlockDesc_BK0_BM_BK1
{});
static
constexpr
auto
b_block_desc_bk0_bn0_bn1_bk1_
=
MakeBBlockDescriptor_BK0_BN0_BN1_BK1
(
BBlockDesc_BK0_BN_BK1
{});
public:
__device__
BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0
()
:
c_thread_origin_data_idx_
{
CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1
(
get_thread_local_1d_id
())},
a_thread_copy_
{
CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1
()},
b_thread_copy_
{
CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1
()}
{
static_assert
(
ABlockDesc_BK0_BM_BK1
::
IsKnownAtCompileTime
()
&&
BBlockDesc_BK0_BN_BK1
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
BM
%
BM1
==
0
&&
BN
%
BN1
==
0
,
"wrong!"
);
static_assert
(
ABlockDesc_BK0_BM_BK1
{}.
GetLength
(
I0
)
==
BBlockDesc_BK0_BN_BK1
{}.
GetLength
(
I0
),
"wrong! K dimension not consistent"
);
static_assert
(
BM10BN10ThreadClusterBM10Xs
::
Size
()
==
2
&&
BM10BN10ThreadClusterBN10Xs
::
Size
()
==
2
,
"wrong!"
);
}
__device__
static
CIndex
CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1
(
index_t
thread_id
)
{
// lower: [BM0, BM1, BN0, BN1]
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
constexpr
auto
adaptor0
=
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1
();
// lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// upper: [Tid, BM0, BM11, BN0, BN11]
constexpr
auto
adaptor1
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
BM100
,
BN100
,
BM101
,
BN101
)),
make_pass_through_transform
(
BM0
),
make_pass_through_transform
(
BM11
),
make_pass_through_transform
(
BN0
),
make_pass_through_transform
(
BN11
)),
make_tuple
(
Sequence
<
1
,
5
,
2
,
6
>
{},
Sequence
<
0
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
constexpr
auto
adaptor
=
chain_tensor_adaptors
(
adaptor0
,
adaptor1
);
return
adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
,
0
,
0
,
0
,
0
));
}
__device__
AIndex
CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1
()
{
const
auto
offsetBM0
=
c_thread_origin_data_idx_
[
I0
];
// If sharing matrix A, we need a separate BM1 offset for each thread in lane group.
const
auto
offsetBM1
=
ShareA
?
c_thread_origin_data_idx_
[
I1
]
+
dpp8
::
get_thread_idx_in_lane_group
()
*
BM1PerThread
:
c_thread_origin_data_idx_
[
I1
];
return
make_tuple
(
0
,
offsetBM0
,
offsetBM1
,
0
);
}
__device__
BIndex
CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1
()
{
const
auto
offsetBN0
=
c_thread_origin_data_idx_
[
I2
];
// If sharing matrix B, we need a separate BN1 offset for each thread in lane group.
const
auto
offsetBN1
=
ShareB
?
c_thread_origin_data_idx_
[
I3
]
+
dpp8
::
get_thread_idx_in_lane_group
()
*
BN1PerThread
:
c_thread_origin_data_idx_
[
I3
];
return
make_tuple
(
0
,
offsetBN0
,
offsetBN1
,
0
);
}
template
<
typename
CThreadDesc_BM0_BM11_BN0_BN11
,
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
CThreadDesc_BM0_BM11_BN0_BN11
&
,
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
static_assert
(
CThreadDesc_BM0_BM11_BN0_BN11
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
a_thread_desc_bk0_bm0_bm1_bk1_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatB
>
(
b_thread_desc_bk0_bn0_bn1_bk1_
.
GetElementSpaceSize
());
constexpr
auto
threadwise_contraction
=
ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
<
FloatA
,
FloatB
,
FloatC
,
decltype
(
a_thread_desc_bk0_bm0_bm1_bk1_
),
decltype
(
b_thread_desc_bk0_bn0_bn1_bk1_
),
CThreadDesc_BM0_BM11_BN0_BN11
,
Sequence
<
BK0PerThread
,
BK1
>
,
Sequence
<
1
,
BM1PerThreadBM11
>
,
Sequence
<
1
,
BN1PerThreadBN11
>
,
ShareA
>
{};
static_for
<
0
,
BN0
,
1
>
{}([
&
](
auto
bn0
)
{
static_for
<
0
,
BM0
,
1
>
{}([
&
](
auto
bm0
)
{
a_thread_copy_
.
Run
(
a_block_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
bm0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
b_thread_copy_
.
Run
(
b_block_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
bn0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
bm0
,
I0
,
bn0
,
I0
));
static_for
<
BK0PerThread
,
BK0
,
BK0PerThread
>
{}([
&
](
auto
bk0
)
{
a_thread_copy_
.
Run
(
a_block_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
bk0
,
bm0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
b_thread_copy_
.
Run
(
b_block_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
bk0
,
bn0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
bm0
,
I0
,
bn0
,
I0
));
});
});
});
}
private:
// A[BK0, BM0, BM1, BK1]
static
constexpr
auto
a_thread_desc_bk0_bm0_bm1_bk1_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BK0PerThread
>
{},
Number
<
BM0
>
{},
Number
<
BM1PerThread
>
{},
Number
<
BK1
>
{}));
// B[BK0, BN0, BN1, BK1]
static
constexpr
auto
b_thread_desc_bk0_bn0_bn1_bk1_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BK0PerThread
>
{},
Number
<
BN0
>
{},
Number
<
BN1PerThread
>
{},
Number
<
BK1
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4r1
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_bk0_bm0_bm1_bk1_
),
decltype
(
a_thread_desc_bk0_bm0_bm1_bk1_
),
Sequence
<
BK0PerThread
,
1
,
BM1PerThread
,
BK1
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
>
,
// DimAccessOrder
Sequence
<
1
,
1
,
BM1PerThread
,
BK1
>
,
// SrcVectorTensorLengths
Sequence
<
0
,
1
,
2
,
3
>>
;
// SrcVectorTensorContiguousDimOrder
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4r1
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_bk0_bn0_bn1_bk1_
),
decltype
(
b_thread_desc_bk0_bn0_bn1_bk1_
),
Sequence
<
BK0PerThread
,
1
,
BN1PerThread
,
BK1
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
>
,
// DimAccessOrder
Sequence
<
1
,
1
,
BN1PerThread
,
BK1
>
,
// SrcVectorTensorLengths
Sequence
<
0
,
1
,
2
,
3
>>
;
// SrcVectorTensorContiguousDimOrder
CIndex
c_thread_origin_data_idx_
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp
0 → 100644
View file @
3c4fb1dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/dpp_gemm.hpp"
namespace
ck
{
/**
* Blockwise GEMM that uses DPP instruction modifier to limit the amount of data loaded for each
* thread by sharing the data between threads in a lanegroup.
*
* In every iteration, each wave calculates a C tile of size `MPerDpp` * `NPerDpp`, there are
* `MRepeat` iterations for `M` dimension and `NRepeat` for `N` one.
* In total, the algorithm runs using
* `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves.
*/
template
<
index_t
BlockSize
,
typename
ABDataType
,
typename
AccDataType
,
typename
AK0MK1BlockDesc
,
typename
BK0NK1BlockDesc
,
index_t
MPerDpp
,
index_t
NPerDpp
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
>
struct
BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
static
constexpr
index_t
WaveSize
=
get_warp_size
();
static
constexpr
index_t
MPerBlock
=
AK0MK1BlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
NPerBlock
=
BK0NK1BlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
KPerBlock
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
)
*
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
A_K0
=
AK0MK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
B_K0
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
A_K1
=
AK0MK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
dpp_gemm
=
DppGemm
<
ABDataType
,
MPerDpp
,
NPerDpp
,
KPack
>
{};
static
constexpr
index_t
KPerThread
=
KPerBlock
/
dpp_gemm
.
K0PerDpp
;
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerDpp
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerDpp
);
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MRepeat
*
NRepeat
,
dpp_gemm
.
GetRegSizePerDpp
(),
true
>
c_thread_buf_
;
__host__
__device__
constexpr
auto
&
GetCThreadBuffer
()
{
return
c_thread_buf_
;
}
__device__
static
auto
GetWaveIdx
()
{
const
index_t
thread_id
=
ThisThreadBlock
::
GetThreadId
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
MWaves
,
NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
CalculateAThreadOriginDataIndex_M0_M1_M2_K
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
dpp_a_idx
=
dpp_gemm
.
CalculateAThreadOriginDataIndex_K_M
();
const
auto
dpp_a_idx_k
=
dpp_a_idx
[
I0
];
const
auto
dpp_a_idx_m
=
dpp_a_idx
[
I1
];
return
make_tuple
(
0
,
waveId_m
,
dpp_a_idx_m
,
KPerThread
*
dpp_a_idx_k
);
}
__device__
static
auto
CalculateBThreadOriginDataIndex_N0_N1_N2_K
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
dpp_b_idx
=
dpp_gemm
.
CalculateBThreadOriginDataIndex_K_N
();
const
auto
dpp_b_idx_k
=
dpp_b_idx
[
I0
];
const
auto
dpp_b_idx_n
=
dpp_b_idx
[
I1
];
return
make_tuple
(
0
,
waveId_n
,
dpp_b_idx_n
,
KPerThread
*
dpp_b_idx_k
);
}
template
<
index_t
m0
,
index_t
n0
>
__device__
static
auto
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
dpp_gemm
.
GetBeginOfThreadBlk
();
const
auto
blk_m_offset
=
blk_idx
[
I0
];
const
auto
blk_n_offset
=
blk_idx
[
I1
];
constexpr
auto
mrepeat_mwave_MPerDpp_to_m_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerDpp
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
constexpr
auto
nrepeat_nwave_NPerDpp_to_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerDpp
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
const
index_t
c_thread_m
=
mrepeat_mwave_MPerDpp_to_m_adaptor
.
CalculateBottomIndex
(
make_tuple
(
m0
,
waveId_m
,
blk_m_offset
))[
I0
];
const
index_t
c_thread_n
=
nrepeat_nwave_NPerDpp_to_n_adaptor
.
CalculateBottomIndex
(
make_tuple
(
n0
,
waveId_n
,
blk_n_offset
))[
I0
];
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
__host__
__device__
BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2
()
{
static_assert
(
AK0MK1BlockDesc
::
IsKnownAtCompileTime
()
&&
BK0NK1BlockDesc
::
IsKnownAtCompileTime
(),
"Wrong! Block descriptors should be known at the time of compilation."
);
#if defined(__HIP_DEVICE_COMPILE__)
// Host wave size can be different than the device one and this assert could fail for host,
// but it does matter only for device.
static_assert
(
ThisThreadBlock
::
GetNumOfThread
()
==
MWaves
*
NWaves
*
WaveSize
,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize
\n
"
);
#endif
static_assert
(
MPerBlock
%
(
MPerDpp
*
MRepeat
)
==
0
,
"Invalid parameters. MPerBlock must be divisible by MPerDpp * MRepeat."
);
static_assert
(
NPerBlock
%
(
NPerDpp
*
NRepeat
)
==
0
,
"Invalid parameters. NPerBlock must be divisible by NPerDpp * NRepeat."
);
}
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2
()
{
constexpr
auto
c_m_n_tblk_lens
=
dpp_gemm
.
GetCMNThreadBlkLengths
();
constexpr
auto
M
=
c_m_n_tblk_lens
[
I0
];
constexpr
auto
N
=
c_m_n_tblk_lens
[
I1
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M
,
N
));
}
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2
()
{
constexpr
auto
c_m_n_tblk_lens
=
dpp_gemm
.
GetCMNThreadBlkLengths
();
constexpr
auto
M
=
c_m_n_tblk_lens
[
I0
];
constexpr
auto
N
=
c_m_n_tblk_lens
[
I1
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M
,
N
));
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2
()
{
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerDpp
>
{},
Number
<
NPerDpp
>
{}));
return
c_block_desc_m0_n0_m1_n1_m2_n2
;
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2
()
{
constexpr
auto
c_block_desc_g_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerDpp
>
{},
Number
<
NPerDpp
>
{}));
return
c_block_desc_g_m0_n0_m1_n1_m2_n2
;
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerDpp
),
MWaves
,
MPerDpp
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerDpp
),
NWaves
,
NPerDpp
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
return
c_grid_desc_m0_n0_m1_n1_m2_n2
;
}
template
<
typename
CGridDesc_G_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2
(
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
)
{
const
auto
G
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
const
auto
M
=
c_grid_desc_g_m_n
.
GetLength
(
I1
);
const
auto
N
=
c_grid_desc_g_m_n
.
GetLength
(
I2
);
const
auto
c_grid_desc_g_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_g_m_n
,
make_tuple
(
make_pass_through_transform
(
G
),
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerDpp
),
MWaves
,
MPerDpp
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerDpp
),
NWaves
,
NPerDpp
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
2
,
4
,
6
>
{}));
return
c_grid_desc_g_m0_n0_m1_n1_m2_n2
;
}
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_M0_M1_M2_K
()
{
return
transform_tensor_descriptor
(
AK0MK1BlockDesc
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
A_K0
>
{},
Number
<
A_K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerDpp
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_N0_N1_N2_K
()
{
return
transform_tensor_descriptor
(
BK0NK1BlockDesc
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
B_K0
>
{},
Number
<
B_K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerDpp
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
static
constexpr
auto
a_block_desc_m0_m1_m2_k
=
MakeABlockDescriptor_M0_M1_M2_K
();
static
constexpr
auto
b_block_desc_n0_n1_n2_k
=
MakeBBlockDescriptor_N0_N1_N2_K
();
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ABDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ABDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
vector_type
<
ABDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ABDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
ABDataType
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
ABDataType
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
});
using
dpp_input_type
=
typename
vector_type
<
ABDataType
,
dpp_gemm
.
K1PerDpp
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
dpp_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
dpp_input_type
>(),
b_thread_vec
.
template
AsType
<
dpp_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
}
protected:
// A[M0, M1, M2, KPerThread]
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
Number
<
KPerThread
>
{}));
// B[N0, N1, N2, KPerThread]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
Number
<
KPerThread
>
{}));
// C[M, N, NumRegDpp]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
dpp_gemm
.
GetRegSizePerDpp
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
ABDataType
,
ABDataType
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
ABDataType
,
ABDataType
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
B_K1
,
B_K1
>
;
AThreadCopy
a_thread_copy_
{
CalculateAThreadOriginDataIndex_M0_M1_M2_K
()};
BThreadCopy
b_thread_copy_
{
CalculateBThreadOriginDataIndex_N0_N1_N2_K
()};
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
3c4fb1dd
...
...
@@ -221,49 +221,102 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}([
&
](
auto
k
)
{
// k=0,1,2 instead of k=0,kpack*1, ...
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
A_K1
>
{},
m0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
B_K1
>
{},
n0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
),
b_thread_buf
);
vector_type
<
FloatA
,
WmmaK
>
a_thread_vec
;
vector_type
<
FloatB
,
WmmaK
>
b_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
A_K1
,
m0
,
0
,
0
,
i
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
,
n0
,
0
,
0
,
i
%
B_K1
))
>
{}];
// basic intrinsic to determine loopover direction
if
constexpr
(
MRepeat
<
NRepeat
)
{
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}(
[
&
](
auto
k
)
{
// k=0,1,2 instead of k=0,kpack*1, ...
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
A_K1
>
{},
m0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
B_K1
>
{},
n0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
),
b_thread_buf
);
vector_type
<
FloatA
,
WmmaK
>
a_thread_vec
;
vector_type
<
FloatB
,
WmmaK
>
b_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
A_K1
,
m0
,
0
,
0
,
i
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
,
n0
,
0
,
0
,
i
%
B_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>()(
Number
<
0
>{}),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>()(
Number
<
0
>
{}),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>()(
Number
<
0
>{}),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>()(
Number
<
0
>
{}),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
}
else
{
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}(
[
&
](
auto
k
)
{
// k=0,1,2 instead of k=0,kpack*1, ...
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
B_K1
>
{},
n0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
),
b_thread_buf
);
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
A_K1
>
{},
m0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
),
a_thread_buf
);
vector_type
<
FloatA
,
WmmaK
>
a_thread_vec
;
vector_type
<
FloatB
,
WmmaK
>
b_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
A_K1
,
m0
,
0
,
0
,
i
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
,
n0
,
0
,
0
,
i
%
B_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>()(
Number
<
0
>{}),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>()(
Number
<
0
>
{}),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
}
}
protected:
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
3c4fb1dd
...
...
@@ -4,27 +4,13 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace
ck
{
enum
struct
LoopScheduler
{
Default
,
Interwave
,
};
constexpr
LoopScheduler
make_default_loop_scheduler
()
{
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
return
LoopScheduler
::
Interwave
;
#else
return
LoopScheduler
::
Default
;
#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
}
template
<
index_t
MNXdlPerWave
,
index_t
MNWaves
,
index_t
MNPerXdl
,
typename
TileDesc_K0_MN_K1
>
__host__
__device__
static
constexpr
auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
(
const
TileDesc_K0_MN_K1
&
)
...
...
@@ -42,7 +28,8 @@ MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
}
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatA
,
typename
FloatB
,
typename
FloatAcc
,
typename
AK0MK1BlockDesc
,
typename
BK0NK1BlockDesc
,
...
...
@@ -72,7 +59,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static
constexpr
index_t
A_K1
=
AK0MK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatA
B
,
MPerXDL
,
NPerXDL
,
KPack
>
{};
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatA
,
MPerXDL
,
NPerXDL
,
KPack
,
FloatB
>
{};
static
constexpr
index_t
KPerThread
=
KPerBlock
/
xdlops_gemm
.
K0PerXdlops
;
...
...
@@ -308,9 +295,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
B
>
(
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
A
B
>
(
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
...
...
@@ -332,25 +319,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
b_thread_buf
);
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
vector_type
<
FloatA
B
,
KPack
>
a_thread_vec
;
vector_type
<
Float
A
B
,
KPack
>
b_thread_vec
;
vector_type
<
FloatA
,
KPack
>
a_thread_vec
;
vector_type
<
FloatB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatA
B
>()(
i
)
=
a_thread_buf
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
Float
A
B
>()(
i
)
=
b_thread_buf
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
using
mfma_input_type_a
=
typename
vector_type
<
FloatA
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
using
mfma_input_type_b
=
typename
vector_type
<
FloatB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
_a
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
...
...
@@ -370,8 +359,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
xdlops_gemm
.
GetRegSizePerXdlops
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
B
,
FloatA
B
,
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
...
...
@@ -380,8 +369,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
Float
A
B
,
Float
A
B
,
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
...
...
@@ -399,7 +388,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the
// default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatA
,
typename
FloatB
,
typename
FloatAcc
,
typename
AK0MK1BlockDesc
,
typename
BK0NK1BlockDesc
,
...
...
@@ -411,7 +401,8 @@ template <index_t BlockSize,
index_t
NumMacClusters
=
CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
>
struct
BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
:
public
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatA
,
FloatB
,
FloatAcc
,
AK0MK1BlockDesc
,
BK0NK1BlockDesc
,
...
...
@@ -422,7 +413,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatA
,
FloatB
,
FloatAcc
,
AK0MK1BlockDesc
,
BK0NK1BlockDesc
,
...
...
@@ -454,9 +446,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
B
>
(
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
A
B
>
(
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
KPerThread
,
KPerInnerLoop
>
{}([
&
](
auto
k
)
{
...
...
@@ -493,20 +485,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
FloatA
B
,
KPack
>
a_thread_vec
;
vector_type
<
Float
A
B
,
KPack
>
b_thread_vec
;
vector_type
<
FloatA
,
KPack
>
a_thread_vec
;
vector_type
<
FloatB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatA
B
>()(
i
)
=
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
0
,
0
,
k_
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
Float
A
B
>()(
i
)
=
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
0
,
0
,
k_
+
i
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
using
mfma_input_type_a
=
typename
vector_type
<
FloatA
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
using
mfma_input_type_b
=
typename
vector_type
<
FloatB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
...
...
@@ -528,8 +522,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// TODO: insert setprio in more precise manner since we
// could have more than >1 MFMA instructions in single call
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
_a
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
...
...
@@ -555,8 +549,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
I1
,
Number
<
KPerInnerLoop
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
B
,
FloatA
B
,
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
...
...
@@ -565,8 +559,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
Float
A
B
,
Float
A
B
,
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
...
...
@@ -582,7 +576,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
};
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatA
,
typename
FloatB
,
typename
FloatAcc
,
typename
AK0MK1BlockDesc
,
typename
BK0NK1BlockDesc
,
...
...
@@ -597,7 +592,8 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
if
constexpr
(
LoopSched
==
LoopScheduler
::
Default
)
{
return
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatA
,
FloatB
,
FloatAcc
,
AK0MK1BlockDesc
,
BK0NK1BlockDesc
,
...
...
@@ -610,7 +606,8 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
else
if
constexpr
(
LoopSched
==
LoopScheduler
::
Interwave
)
{
return
BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatA
,
FloatB
,
FloatAcc
,
AK0MK1BlockDesc
,
BK0NK1BlockDesc
,
...
...
@@ -632,26 +629,27 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
* 3. configurable k index starting position and step size after each FMA/XDL instruction
*/
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
,
bool
TransposeC
=
false
,
index_t
AMmaKStride
=
KPack
*
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
,
TransposeC
>{}.
K0PerXdlops
,
index_t
BMmaKStride
=
KPack
*
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
,
TransposeC
>
{}.
K0PerXdlops
>
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
,
bool
TransposeC
=
false
,
index_t
AMmaKStride
=
KPack
*
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
,
FloatAB
,
TransposeC
>{}.
K0PerXdlops
,
index_t
BMmaKStride
=
KPack
*
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
,
FloatAB
,
TransposeC
>
{}.
K0PerXdlops
>
struct
BlockwiseGemmXdlops_v2
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -668,7 +666,8 @@ struct BlockwiseGemmXdlops_v2
static
constexpr
index_t
A_K1
=
ATileDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BTileDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
,
TransposeC
>
{};
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
,
FloatAB
,
TransposeC
>
{};
static
constexpr
index_t
KPerThread
=
KPerBlock
/
xdlops_gemm
.
K0PerXdlops
;
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
0 → 100644
View file @
3c4fb1dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp"
#include "ck/utility/is_detected.hpp"
namespace
ck
{
// Thread-group level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalerPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
//
// Does following things to avoid scratch memory issue
// 1. Pass tensor descritpors by reference (or tuple of references)
// 2. Does not keep reference to tensor descriptor
// 3. Does not construct new tensor coordinate when call Run()
template
<
typename
ThreadGroup
,
typename
SrcDatas
,
typename
DstDatas
,
typename
SrcDescs
,
typename
DstDescs
,
typename
ElementwiseOperation
,
typename
DstInMemOps
,
// Sequence<InMemoryDataOperationEnum ...>
typename
SliceLengths
,
typename
ThreadClusterLengths
,
typename
ThreadClusterArrangeOrder
,
typename
SrcDimAccessOrder
,
typename
DstDimAccessOrder
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
DstScalarPerVector
,
typename
ThreadTransferSrcResetCoordinateAfterRunFlags
,
typename
ThreadTransferDstResetCoordinateAfterRunFlags
>
struct
ThreadGroupTensorSliceTransfer_v7r2
{
static
constexpr
index_t
nDim
=
remove_cvref_t
<
tuple_element_t
<
0
,
SrcDescs
>>::
GetNumOfDimension
();
static
constexpr
index_t
nSrc
=
remove_cvref_t
<
SrcDescs
>::
Size
();
static
constexpr
index_t
nDst
=
remove_cvref_t
<
DstDescs
>::
Size
();
using
Index
=
MultiIndex
<
nDim
>
;
static
constexpr
auto
thread_slice_lengths
=
SliceLengths
{}
/
ThreadClusterLengths
{};
__device__
constexpr
ThreadGroupTensorSliceTransfer_v7r2
(
const
SrcDescs
&
src_descs
,
const
StaticallyIndexedArray
<
Index
,
nSrc
>&
src_block_slice_origins
,
const
DstDescs
&
dst_descs
,
const
StaticallyIndexedArray
<
Index
,
nDst
>&
dst_block_slice_origins
,
const
ElementwiseOperation
&
element_op
)
:
threadwise_transfer_
(
src_descs
,
StaticallyIndexedArray
<
Index
,
nSrc
>
{},
dst_descs
,
StaticallyIndexedArray
<
Index
,
nDst
>
{},
element_op
)
{
static_assert
(
nSrc
==
SrcDatas
::
Size
()
&&
nSrc
==
SrcDescs
::
Size
()
&&
nSrc
==
ThreadTransferSrcResetCoordinateAfterRunFlags
::
Size
()
&&
nDst
==
DstDatas
::
Size
()
&&
nDst
==
DstDescs
::
Size
()
&&
nDst
==
ThreadTransferDstResetCoordinateAfterRunFlags
::
Size
(),
"wrong!"
);
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
static_assert
(
nDim
==
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
SrcDescs
>>::
GetNumOfDimension
(),
"wrong!"
);
});
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
static_assert
(
nDim
==
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DstDescs
>>::
GetNumOfDimension
(),
"wrong!"
);
});
static_assert
(
nDim
==
ThreadClusterLengths
::
Size
()
&&
nDim
==
ThreadClusterArrangeOrder
::
Size
()
&&
nDim
==
SrcDimAccessOrder
::
Size
()
&&
nDim
==
DstDimAccessOrder
::
Size
(),
"wrong! nDim not consistent"
);
static_assert
(
is_same
<
SliceLengths
,
decltype
(
thread_slice_lengths
*
ThreadClusterLengths
{})
>
{},
"wrong! threads should be mapped to cover entire slicing window"
);
static_assert
(
ThreadGroup
::
GetNumOfThread
()
>=
thread_cluster_desc_
.
GetElementSize
(),
"wrong! ThreadGroup::GetNumOfThread() too small"
);
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
const
auto
src_thread_slice_origins
=
generate_tuple
(
[
&
](
auto
i
)
{
return
src_block_slice_origins
[
i
]
+
thread_data_idx_begin
;
},
Number
<
nSrc
>
{});
const
auto
dst_thread_slice_origins
=
generate_tuple
(
[
&
](
auto
i
)
{
return
dst_block_slice_origins
[
i
]
+
thread_data_idx_begin
;
},
Number
<
nDst
>
{});
threadwise_transfer_
.
SetSrcSliceOrigins
(
src_descs
,
src_thread_slice_origins
);
threadwise_transfer_
.
SetDstSliceOrigins
(
dst_descs
,
dst_thread_slice_origins
);
}
}
template
<
typename
SrcBuffers
>
__device__
void
RunRead
(
const
SrcDescs
&
src_descs
,
const
SrcBuffers
&
src_bufs
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunRead
(
src_descs
,
src_bufs
);
}
}
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
template
<
typename
DstBuffers
>
__device__
void
RunWrite
(
const
DstDescs
&
dst_descs
,
DstBuffers
dst_bufs
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
if
constexpr
(
is_detected
<
is_tuple
,
decltype
(
dst_bufs
)
>::
value
)
threadwise_transfer_
.
RunWrite
(
dst_descs
,
dst_bufs
);
else
threadwise_transfer_
.
RunWrite
(
dst_descs
,
tie
(
dst_bufs
));
}
}
template
<
typename
SrcBuffers
,
typename
DstBuffers
>
__device__
void
Run
(
const
SrcDescs
&
src_descs
,
const
SrcBuffers
&
src_bufs
,
const
DstDescs
&
dst_descs
,
DstBuffers
dst_bufs
)
{
RunRead
(
src_descs
,
src_bufs
);
RunWrite
(
dst_descs
,
dst_bufs
);
}
template
<
index_t
ISrc
>
__device__
void
MoveSrcSliceWindow
(
const
SrcDescs
&
src_descs
,
Number
<
ISrc
>
iSrc
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveSrcSliceWindow
(
src_descs
,
iSrc
,
step
);
}
}
__device__
void
MoveSrcSliceWindow
(
const
SrcDescs
&
src_descs
,
const
Index
&
step
)
{
static_for
<
0
,
SrcDescs
::
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
MoveSrcSliceWindow
(
src_descs
,
i
,
step
);
});
}
template
<
index_t
IDst
>
__device__
void
MoveDstSliceWindow
(
const
DstDescs
&
dst_descs
,
Number
<
IDst
>
iDst
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveDstSliceWindow
(
dst_descs
,
iDst
,
step
);
}
}
__device__
void
MoveDstSliceWindow
(
const
DstDescs
&
dst_descs
,
const
Index
&
step
)
{
static_for
<
0
,
DstDescs
::
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
MoveDstSliceWindow
(
dst_descs
,
i
,
step
);
});
}
private:
static
constexpr
auto
thread_cluster_desc_
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
ThreadwiseTensorSliceTransfer_v7r2
<
SrcDatas
,
DstDatas
,
SrcDescs
,
DstDescs
,
ElementwiseOperation
,
DstInMemOps
,
decltype
(
thread_slice_lengths
),
SrcDimAccessOrder
,
DstDimAccessOrder
,
SrcVectorDim
,
DstVectorDim
,
SrcScalarPerVector
,
DstScalarPerVector
,
ThreadTransferSrcResetCoordinateAfterRunFlags
,
ThreadTransferDstResetCoordinateAfterRunFlags
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp
0 → 100644
View file @
3c4fb1dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck
{
namespace
conv_tensor_rearrange_op
{
struct
BaseConvTensorRearrangeOp
{
};
struct
ImageToColumn
:
public
BaseConvTensorRearrangeOp
{
static
constexpr
const
char
*
name
=
"Image to Column"
;
};
struct
ColumnToImage
:
public
BaseConvTensorRearrangeOp
{
static
constexpr
const
char
*
name
=
"Column to Image"
;
};
template
<
typename
Op
,
typename
std
::
enable_if
<
std
::
is_base_of
<
BaseConvTensorRearrangeOp
,
Op
>
::
value
,
bool
>::
type
=
false
>
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
BaseConvTensorRearrangeOp
&
)
{
os
<<
Op
::
name
;
return
os
;
}
}
// namespace conv_tensor_rearrange_op
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp
0 → 100644
View file @
3c4fb1dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// GEMM:
// input : A0[M0, M1, ... K0, K1, ...], ...
// input : B0[N0, N1, ... K0, K1, ...], ...
// input : D0[M0, M1, ... N0, N1, ...], D1[M0, M1, ... N0, N1, ...], ...
// output : E[M0, M1, ... N0, N1, ...]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template
<
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
typename
AsDataType
,
typename
BsDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
struct
DeviceContractionMultipleABD
:
public
BaseOperator
{
static
constexpr
index_t
NumATensor
=
AsDataType
::
Size
();
static
constexpr
index_t
NumBTensor
=
BsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
array
<
const
void
*
,
NumATensor
>
p_as
,
std
::
array
<
const
void
*
,
NumBTensor
>
p_bs
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
a_ms_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
a_ms_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
b_ns_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
b_ns_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
d_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
d_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
e_ms_ns_length
,
const
std
::
vector
<
index_t
>&
e_ms_ns_stride
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp
View file @
3c4fb1dd
...
...
@@ -33,7 +33,8 @@ template <index_t NumDimM,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
typename
CDEElementwiseOperation
,
typename
ComputeDataType
=
ADataType
>
struct
DeviceContractionMultipleD
:
public
BaseOperator
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
...
...
include/ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp
0 → 100644
View file @
3c4fb1dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
/**
* \brief Convolution Tensor Rearrange.
*
* This Device operator supports converting an image to
* the GEMM representation (Image to Column) and
* converting a GEMM form to the image (Column to Image).
* Supported layouts:
* [G, N, Di, Hi, Wi, C] <-> [G, N * Do * Ho * Wo, Z * Y * X * C]
* [N, Di, Hi, Wi, G, C] <-> [N * Do * Ho * Wo, G, Z * Y * X * C]
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Input Layout.
* \tparam InputDataType Input Data Type.
* \tparam OutputDataType Output Data Type.
* \tparam ConvTensorRearrangeOp Operation type: ImageToColumn, ColumnToImage.
*/
template
<
index_t
NDimSpatial
,
typename
ImageLayout
,
typename
InputDataType
,
typename
OutputDataType
,
typename
ConvTensorRearrangeOp
>
struct
DeviceConvTensorRearrange
:
public
BaseOperator
{
/**
* \brief Make argument pointer for image to column.
*
* \param p_in A pointer to the device memory of the input image.
* \param p_out A pointer to the device memory of the output.
* \param G Convolution number of groups.
* \param N Convolution batch size.
* \param C Convolution number of channels.
* \param input_spatial_lengths Input spatial lengths.
* \param filter_spatial_lengths Filter spatial lengths.
* \param output_spatial_lengths Output spatial lengths.
* \param image_g_n_c_wis_strides Image strides in order [G, N, C, D, H, W].
* \param gemm_g_m_k_strides Gemm form strides.
* \param conv_filter_strides Convolution filter strides.
* \param conv_filter_dilations Convolution filter dilations.
* \param input_left_pads Convolution left pads.
* \param input_right_pads Convolution right pads.
* \return Pointer to the argument.
*/
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in
,
void
*
p_out
,
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
3
>&
gemm_g_m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp
0 → 100644
View file @
3c4fb1dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <memory>
#include <array>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
InDataTypeTuple
,
typename
OutDataTypeTuple
,
typename
ElementwiseOperation
,
typename
UnaryOperation
,
typename
Scale
,
index_t
NumDim
>
struct
DeviceElementwise
:
public
BaseOperator
{
static
constexpr
int
NumInput
=
InDataTypeTuple
::
Size
();
static
constexpr
int
NumOutput
=
OutDataTypeTuple
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
NumDim
>
lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray
,
const
std
::
array
<
const
void
*
,
NumInput
>
in_dev_buffers
,
const
std
::
array
<
void
*
,
NumOutput
>
out_dev_buffers
,
ElementwiseOperation
elementwise_op
,
UnaryOperation
unary_op
,
Scale
scale_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
// namespace device
template
<
typename
InDataTypeTuple
,
typename
OutDataTypeTuple
,
typename
ElementwiseOperation
,
typename
UnaryOperation
,
typename
Scale
,
index_t
NumDim
>
using
DeviceElementwisePtr
=
std
::
unique_ptr
<
DeviceElementwise
<
InDataTypeTuple
,
OutDataTypeTuple
,
ElementwiseOperation
,
UnaryOperation
,
Scale
,
NumDim
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp
0 → 100644
View file @
3c4fb1dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// GEMM:
// input : A0[M, K], B0[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template
<
typename
AsLayout
,
typename
BsLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
AsDataType
,
typename
BsDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
struct
DeviceGemmMultipleABD
:
public
BaseOperator
{
static
constexpr
index_t
NumATensor
=
AsDataType
::
Size
();
static
constexpr
index_t
NumBTensor
=
BsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
array
<
const
void
*
,
NumATensor
>
p_as
,
std
::
array
<
const
void
*
,
NumBTensor
>
p_bs
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
std
::
array
<
ck
::
index_t
,
NumATensor
>
StrideAs
,
std
::
array
<
ck
::
index_t
,
NumBTensor
>
StrideBs
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
ck
::
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp
View file @
3c4fb1dd
...
...
@@ -20,7 +20,8 @@ template <typename ALayout,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
,
typename
ComputeType
=
CDataType
>
struct
DeviceGemmSplitK
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
...
...
@@ -48,7 +49,8 @@ template <typename ALayout,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
,
typename
ComputeType
=
CDataType
>
using
DeviceGemmSplitKPtr
=
std
::
unique_ptr
<
DeviceGemmSplitK
<
ALayout
,
BLayout
,
CLayout
,
...
...
@@ -57,7 +59,8 @@ using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
CElementwiseOperation
,
ComputeType
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp
View file @
3c4fb1dd
...
...
@@ -29,7 +29,9 @@ template <ck::index_t NDimSpatial,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
typename
CDEElementwiseOperation
,
typename
AComputeType
=
ADataType
,
typename
BComputeType
=
AComputeType
>
struct
DeviceGroupedConvBwdDataMultipleD
:
public
BaseOperator
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
View file @
3c4fb1dd
...
...
@@ -20,7 +20,9 @@ template <ck::index_t NDimSpatial,
typename
OutDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
typename
OutElementwiseOperation
,
typename
ComputeTypeA
=
InDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
DeviceGroupedConvBwdWeight
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
0 → 100644
View file @
3c4fb1dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/utility/is_detected.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
/**
* \brief Grouped Convolution Forward
*
* \details
* input : input image A[G, N, C, Hi, Wi], A1[G, N, C, Hi, Wi]...
* input : weight B[G, K, C, Y, X], B1[G, K, C, Y, X]...
* input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
* output : output image E[G, N, K, Ho, Wo]
*
* C = a_op(A, A1...) * b_op(B, B1...)
* E = cde_op(C, D0, D1, ...)
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ALayout Input layout (also for a1, a2...).
* \tparam BLayout Weight layout (also for b1, b2...).
* \tparam DsLayout Ds layouts.
* \tparam ELayout Output layout.
* \tparam ADataType Input data type. Pass tuple if there is multiple A.
* \tparam BDataType Weight data type. Pass tuple if there is multiple B.
* \tparam DsDataType D data types.
* \tparam EDataType Output data type.
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CDEElementwiseOperation CDE elementwise operation.
* \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
*/
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
ComputeType
=
decltype
(
UnpackDataType
<
is_detected
<
is_tuple
,
ADataType
>
::
value
,
Number
<
0
>
,
ADataType
>
())
>
// ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
struct
DeviceGroupedConvFwdMultipleABD
:
public
BaseOperator
{
static
constexpr
bool
isMultiA
=
is_detected
<
is_tuple
,
ADataType
>::
value
;
static
constexpr
bool
isMultiB
=
is_detected
<
is_tuple
,
BDataType
>::
value
;
static
constexpr
index_t
NumATensor
=
GetNumABTensors
<
isMultiA
,
ADataType
>
();
static
constexpr
index_t
NumBTensor
=
GetNumABTensors
<
isMultiB
,
BDataType
>
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static_assert
(
NumDTensor
==
DsLayout
::
Size
(),
"wrong! Inconsistent NumDTensor"
);
// If DataType is tuple, user has to pass std::array with pointers.
using
APointers
=
std
::
conditional_t
<
isMultiA
,
std
::
array
<
const
void
*
,
NumATensor
>&
,
const
void
*>
;
using
BPointers
=
std
::
conditional_t
<
isMultiB
,
std
::
array
<
const
void
*
,
NumBTensor
>&
,
const
void
*>
;
/**
* \brief Make argument pointer for grouped conv fwd.
*
* \param p_a A pointer to the input (std::array<const void*, NumA> with
pointers for multiple A).
* \param p_b A pointer to the weight (std::array<const void*, NumA> with
pointers for multiple B).
* \param p_ds A pointers to the Ds.
* \param p_e A pointers to the output.
* \param a_g_n_c_wis_lengths Input lengths [G, N, C, Spatial...] (for 3d).
* \param a_g_n_c_wis_strides Input strides [G, N, C, Spatial...] (for 3d).
* \param b_g_k_c_xs_lengths Weight lengths [G, K, C, Spatial...] (for 3d).
* \param b_g_k_c_xs_strides Weight strides [G, K, C, Spatial...] (for 3d).
* \param ds_g_n_k_wos_lengths Ds lengths [G, N, K, Spatial...] (for 3d).
* \param ds_g_n_k_wos_strides Ds strides [G, N, K, Spatial...] (for 3d).
* \param e_g_n_k_wos_lengths Output lengths [G, N, K, Spatial...] (for 3d).
* \param e_g_n_k_wos_strides Output strides [G, N, K, Spatial...] (for 3d).
* \param conv_filter_strides Convolution filter strides.
* \param conv_filter_dilations Convolution filter dilations.
* \param input_left_pads Input left paddings.
* \param input_right_pads Input right paddings.
* \param a_element_op A elementwise operation object.
* \param b_element_op B elementwise operation object.
* \param cde_element_op CDE elementwise operation object.
* \return Pointer to the argument.
*/
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
APointers
p_a
,
BPointers
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp
View file @
3c4fb1dd
...
...
@@ -3,21 +3,33 @@
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Convolution Forward:
// input : input image A[G, N, C, Hi, Wi],
// input : weight B[G, K, C, Y, X],
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
// output : output image E[G, N, K, Ho, Wo]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
/**
* \brief Grouped Convolution Forward
*
* \note This structure is deprecated (left for backwards compatibility). Please use
* DeviceGroupedConvFwdMultipleABD.
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ALayout Input layout (also for a1, a2...).
* \tparam BLayout Weight layout (also for b1, b2...).
* \tparam DsLayout Ds layouts.
* \tparam ELayout Output layout.
* \tparam ADataType Input data type. Pass tuple if there is multiple A.
* \tparam BDataType Weight data type. Pass tuple if there is multiple B.
* \tparam DsDataType D data types.
* \tparam EDataType Output data type.
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CDEElementwiseOperation CDE elementwise operation.
* \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
*/
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
...
...
@@ -29,36 +41,26 @@ template <index_t NDimSpatial,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
struct
DeviceGroupedConvFwdMultipleD
:
public
BaseOperator
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static_assert
(
NumDTensor
==
DsLayout
::
Size
(),
"wrong! Inconsistent NumDTensor"
);
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
// input image
const
void
*
p_b
,
// weight
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
// output image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
typename
CDEElementwiseOperation
,
typename
ComputeType
=
decltype
(
UnpackDataType
<
is_detected
<
is_tuple
,
ADataType
>
::
value
,
Number
<
0
>
,
ADataType
>
())
>
// ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
using
DeviceGroupedConvFwdMultipleD
=
DeviceGroupedConvFwdMultipleABD
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
ComputeType
>
;
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp
0 → 100644
View file @
3c4fb1dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <array>
#include "device_grouped_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
NumDTensor
=
0
>
struct
GroupedGemmKernelArgument
{
const
void
*
p_a_grid
;
const
void
*
p_b_grid
;
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
;
void
*
p_e_grid
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
;
index_t
StrideE
;
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGroupedGemmFixedNK
:
DeviceGroupedGemm
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
const
void
*
kernel_args
)
const
=
0
;
virtual
size_t
GetDeviceKernelArgSize
(
const
BaseArgument
*
p_arg
)
const
=
0
;
virtual
void
SetKBatch
(
BaseArgument
*
p_arg
,
index_t
k_batch
)
const
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_
inde
x_pool_bwd.hpp
→
include/ck/tensor_operation/gpu/device/device_
ma
x_pool_bwd.hpp
View file @
3c4fb1dd
...
...
@@ -13,7 +13,7 @@ namespace device {
// For pooling which used indexable operation, such as MaxPool, MinPool...etc
template
<
typename
DOutDataType
,
typename
IndexDataType
,
typename
DInDataType
>
struct
Device
Inde
xPoolBwd
:
public
BaseOperator
struct
Device
Ma
xPoolBwd
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_dout
,
...
...
@@ -22,7 +22,8 @@ struct DeviceIndexPoolBwd : public BaseOperator
index_t
dout_length
,
index_t
din_length
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
)
=
0
;
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_dilations
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment