gaoqiong / composable_kernel · Commit 8053bca3

Commit 8053bca3, authored Jun 20, 2023 by aska-0096
Parent: b3770638

    separate array base and vector base attention tensor transformation

Changes: 5 files
Showing 5 changed files with 461 additions and 251 deletions:

  include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp         +3   -2
  include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp   +2   -2
  include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp    +34  -113
  include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp                         +31  -134
  include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp (new file)    +391 -0
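The commit message contrasts two flavors of the contraction-to-GEMM transform interface: the vector-base flavor passes tensor lengths/strides as std::vector<index_t> (host-only, size checked at run time), while the array-base flavor introduced for the WMMA path passes std::array<index_t, NumDimG + NumDimM + NumDimN> (element count fixed in the type, usable from __host__ __device__ code). A minimal standalone sketch of the two calling conventions; make_desc_vec and make_desc_arr are hypothetical helpers for illustration only, not composable_kernel API:

#include <array>
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <vector>

using index_t = int;

// Vector-base flavor: dimensionality is only known at run time, so it must be checked.
index_t make_desc_vec(const std::vector<index_t>& lengths) // hypothetical helper, not CK API
{
    if(lengths.size() != 4)
        throw std::runtime_error("wrong! dimension must match input lengths");
    return std::accumulate(lengths.begin(), lengths.end(), 1, std::multiplies<index_t>{});
}

// Array-base flavor: the element count is part of the type, so no runtime check is needed,
// and a function like this could also be compiled for the device.
template <std::size_t NDim>
index_t make_desc_arr(const std::array<index_t, NDim>& lengths) // hypothetical helper, not CK API
{
    return std::accumulate(lengths.begin(), lengths.end(), 1, std::multiplies<index_t>{});
}

int main()
{
    std::vector<index_t> v{2, 3, 4, 5};
    std::array<index_t, 4> a{2, 3, 4, 5};
    std::cout << make_desc_vec(v) << ' ' << make_desc_arr(a) << '\n'; // 120 120
}

In this commit the vector flavor stays in transform_contraction_to_gemm.hpp (used by the XDL device ops), and the array flavor moves into the new transform_contraction_to_gemm_arraybase.hpp used by the WMMA device op.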
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp

@@ -44,7 +44,7 @@ __global__ void
             const CElementwiseOperation c_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();

@@ -682,7 +682,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-             ck::get_device_name() == "gfx940"))
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }
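Both hunks extend the same two-level gating pattern to gfx941/gfx942: a preprocessor guard keeps the kernel body out of device compilation for unsupported targets, and IsSupportedArgument repeats the check at run time against the detected device name. A simplified, standalone sketch of that pattern; device_is_supported and its name list are illustrative stand-ins, not the CK functions:

#include <array>
#include <iostream>
#include <string>

// Run-time half of the gate: reject devices the kernel was not written for.
// In CK the name comes from ck::get_device_name(); here it is a plain parameter.
bool device_is_supported(const std::string& device_name)
{
    static const std::array<std::string, 5> supported{
        "gfx908", "gfx90a", "gfx940", "gfx941", "gfx942"};
    for(const auto& name : supported)
        if(device_name == name)
            return true;
    return false;
}

int main()
{
    std::cout << device_is_supported("gfx942") << '\n';  // 1 after this commit
    std::cout << device_is_supported("gfx1030") << '\n'; // 0
}

// Compile-time half of the gate (shape only; the macro names are HIP's target macros):
// #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
//     defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
//     ... kernel body ...
// #endif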
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp

@@ -18,7 +18,7 @@
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp"
-#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
+#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"

@@ -572,7 +572,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
     static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1);
     static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1);

-    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
+    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
        Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
        Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
        GemmSpec,
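The WMMA device op now aliases the array-base transform, whose first two template arguments pack the dimension counts and per-block tile sizes into Sequence<> parameters (NumDims_G_M_N_K_O and PerBlock_M_N_K_O in the new header) and are read back with At(). A self-contained sketch of that packing idea using a hand-rolled Sequence; the real ck::Sequence has far more machinery and the numbers below are arbitrary, so treat this as an illustration only:

#include <cstddef>
#include <iostream>

using index_t = int;

// Minimal stand-in for ck::Sequence<...>: a compile-time integer list with At().
template <index_t... Is>
struct Sequence
{
    static constexpr index_t data[sizeof...(Is)] = {Is...};
    template <typename Number>
    static constexpr index_t At(Number) { return data[Number::value]; }
};

template <index_t N>
struct Number { static constexpr index_t value = N; };

// A transform-like struct that unpacks its configuration the way
// TransformBatchedContractionContractionToBatchedGemmGemm_Wmma does.
template <typename NumDims_G_M_N_K_O, typename PerBlock_M_N_K_O>
struct TransformSketch
{
    static constexpr index_t NumDimG   = NumDims_G_M_N_K_O::At(Number<0>{});
    static constexpr index_t NumDimM   = NumDims_G_M_N_K_O::At(Number<1>{});
    static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(Number<0>{});
};

int main()
{
    using T = TransformSketch<Sequence<2, 1, 1, 1, 1>, Sequence<128, 128, 32, 64>>;
    std::cout << T::NumDimG << ' ' << T::NumDimM << ' ' << T::MPerBlock << '\n'; // 2 1 128
}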
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp

@@ -68,7 +68,7 @@ __global__ void
             const C0MatrixMask c0_matrix_mask)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

@@ -252,27 +252,25 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                                        B1Spec,
                                                        CSpec>;

     static auto MakeAGridDescriptor_AK0_M_AK1(
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
+        const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
+        const std::vector<index_t>& a_gs_ms_ks_strides_vec)
     {
         return Transform::MakeAGridDescriptor_AK0_M_AK1(
             Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
             Number<AK1>{});
     }

     static auto MakeBGridDescriptor_BK0_N_BK1(
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths_vec,
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides_vec)
+        const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
+        const std::vector<index_t>& b_gs_ns_ks_strides_vec)
     {
         return Transform::MakeB0GridDescriptor_BK0_N_BK1(
             Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
             Number<BK1>{});
     }

     static auto MakeB1GridDescriptor_BK0_N_BK1(
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>&
-            b1_gs_gemm1ns_gemm1ks_lengths_vec,
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>&
-            b1_gs_gemm1ns_gemm1ks_strides_vec)
+        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
+        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
     {
         return Transform::MakeB1GridDescriptor_BK0_N_BK1(
             Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,

@@ -455,18 +453,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                  CDataType* p_c_grid,
                  const std::array<void*, NumD0Tensor> p_acc0_biases,
                  const std::array<void*, NumD1Tensor> p_acc1_biases,
-                 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
-                 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
-                 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths,
-                 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides,
-                 const std::array<index_t, NumDimG + NumDimM + NumDimN>&
-                     b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
-                 const std::array<index_t, NumDimG + NumDimM + NumDimN>&
-                     b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
-                 const std::array<index_t, NumDimG + NumDimM + NumDimN>&
-                     c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
-                 const std::array<index_t, NumDimG + NumDimM + NumDimN>&
-                     c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
+                 const std::vector<index_t>& a_gs_ms_ks_lengths,
+                 const std::vector<index_t>& a_gs_ms_ks_strides,
+                 const std::vector<index_t>& b_gs_ns_ks_lengths,
+                 const std::vector<index_t>& b_gs_ns_ks_strides,
+                 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
+                 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
+                 const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
+                 const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
                  const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
                  const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides,
                  const std::array<std::vector<ck::index_t>, NumD1Tensor>&

@@ -730,7 +724,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
 #endif

-        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-             ck::get_device_name() == "gfx940"))
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }

@@ -791,12 +786,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
             if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 &&
                arg.d0s_nl_ns_lengths_strides_[i][0] % D0sTransferSrcScalarPerVector != 0)
             {
-                std::cout << "first" << std::endl;
                 return false;
             }
             if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && D0sTransferSrcScalarPerVector != 1)
             {
-                std::cout << "second" << std::endl;
                 return false;
             }
         }

@@ -841,56 +834,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                              B1ElementwiseOperation b1_element_op,
                              C1DEElementwiseOperation c1de_element_op)
     {
-        constexpr auto dimension = NumDimG + NumDimM + NumDimN;
-        std::array<index_t, dimension> a_gs_ms_ks_lengths_{};
-        std::array<index_t, dimension> a_gs_ms_ks_strides_{};
-        std::array<index_t, dimension> b_gs_ns_ks_lengths_{};
-        std::array<index_t, dimension> b_gs_ns_ks_strides_{};
-        std::array<index_t, dimension> b1_gs_gemm1ns_gemm1ks_lengths_{}; // b1_gs_os_ns_lengths
-        std::array<index_t, dimension> b1_gs_gemm1ns_gemm1ks_strides_{}; // b1_gs_os_ns_strides
-        std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{};       // c_gs_ms_os_lengths
-        std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{};       // c_gs_ms_os_strides
-
-        std::copy(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.begin() + dimension, a_gs_ms_ks_lengths_.begin());
-        std::copy(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.begin() + dimension, a_gs_ms_ks_strides_.begin());
-        std::copy(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.begin() + dimension, b_gs_ns_ks_lengths_.begin());
-        std::copy(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.begin() + dimension, b_gs_ns_ks_strides_.begin());
-        std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(), b1_gs_gemm1ns_gemm1ks_lengths.begin() + dimension, b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths
-        std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(), b1_gs_gemm1ns_gemm1ks_strides.begin() + dimension, b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides
-        std::copy(c_gs_ms_gemm1ns_lengths.begin(), c_gs_ms_gemm1ns_lengths.begin() + dimension, c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths
-        std::copy(c_gs_ms_gemm1ns_strides.begin(), c_gs_ms_gemm1ns_strides.begin() + dimension, c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides
-
         return Argument{p_a,
                         p_b,
                         p_b1,
                         p_c,
                         p_acc0_biases,
                         p_acc1_biases,
-                        a_gs_ms_ks_lengths_,
-                        a_gs_ms_ks_strides_,
-                        b_gs_ns_ks_lengths_,
-                        b_gs_ns_ks_strides_,
-                        b1_gs_gemm1ns_gemm1ks_lengths_, // b1_gs_os_ns_lengths
-                        b1_gs_gemm1ns_gemm1ks_strides_, // b1_gs_os_ns_strides
-                        c_gs_ms_gemm1ns_lengths_,       // c_gs_ms_os_lengths
-                        c_gs_ms_gemm1ns_strides_,       // c_gs_ms_os_strides
+                        a_gs_ms_ks_lengths,
+                        a_gs_ms_ks_strides,
+                        b_gs_ns_ks_lengths,
+                        b_gs_ns_ks_strides,
+                        b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
+                        b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
+                        c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
+                        c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
                         acc0_biases_gs_ms_ns_lengths,
                         acc0_biases_gs_ms_ns_strides,
                         acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths

@@ -933,56 +890,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                              B1ElementwiseOperation b1_element_op,
                              C1DEElementwiseOperation c1de_element_op) override
     {
-        constexpr auto dimension = NumDimG + NumDimM + NumDimN;
-        std::array<index_t, dimension> a_gs_ms_ks_lengths_{};
-        std::array<index_t, dimension> a_gs_ms_ks_strides_{};
-        std::array<index_t, dimension> b_gs_ns_ks_lengths_{};
-        std::array<index_t, dimension> b_gs_ns_ks_strides_{};
-        std::array<index_t, dimension> b1_gs_gemm1ns_gemm1ks_lengths_{}; // b1_gs_os_ns_lengths
-        std::array<index_t, dimension> b1_gs_gemm1ns_gemm1ks_strides_{}; // b1_gs_os_ns_strides
-        std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{};       // c_gs_ms_os_lengths
-        std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{};       // c_gs_ms_os_strides
-
-        std::copy(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.begin() + dimension, a_gs_ms_ks_lengths_.begin());
-        std::copy(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.begin() + dimension, a_gs_ms_ks_strides_.begin());
-        std::copy(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.begin() + dimension, b_gs_ns_ks_lengths_.begin());
-        std::copy(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.begin() + dimension, b_gs_ns_ks_strides_.begin());
-        std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(), b1_gs_gemm1ns_gemm1ks_lengths.begin() + dimension, b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths
-        std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(), b1_gs_gemm1ns_gemm1ks_strides.begin() + dimension, b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides
-        std::copy(c_gs_ms_gemm1ns_lengths.begin(), c_gs_ms_gemm1ns_lengths.begin() + dimension, c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths
-        std::copy(c_gs_ms_gemm1ns_strides.begin(), c_gs_ms_gemm1ns_strides.begin() + dimension, c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides
-
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
                                           static_cast<const B1DataType*>(p_b1),
                                           static_cast<CDataType*>(p_c),
                                           p_acc0_biases, // cast in struct Argument
                                           p_acc1_biases, // cast in struct Argument
-                                          a_gs_ms_ks_lengths_,
-                                          a_gs_ms_ks_strides_,
-                                          b_gs_ns_ks_lengths_,
-                                          b_gs_ns_ks_strides_,
-                                          b1_gs_gemm1ns_gemm1ks_lengths_, // b1_gs_os_ns_lengths
-                                          b1_gs_gemm1ns_gemm1ks_strides_, // b1_gs_os_ns_strides
-                                          c_gs_ms_gemm1ns_lengths_,       // c_gs_ms_os_lengths
-                                          c_gs_ms_gemm1ns_strides_,       // c_gs_ms_os_strides
+                                          a_gs_ms_ks_lengths,
+                                          a_gs_ms_ks_strides,
+                                          b_gs_ns_ks_lengths,
+                                          b_gs_ns_ks_strides,
+                                          b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
+                                          b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
+                                          c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
+                                          c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
                                           acc0_biases_gs_ms_ns_lengths,
                                           acc0_biases_gs_ms_ns_strides,
                                           acc1_biases_gs_ms_gemm1ns_lengths,
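The two MakeArgument hunks above delete the bridging code that copied the caller's std::vector lengths/strides into fixed-size std::array temporaries before constructing the Argument; with the XDL path back on the vector-base transform, the vectors are passed straight through. A standalone sketch of the pattern that was removed (names shortened and the dimension value arbitrary; this is not CK code):

#include <algorithm>
#include <array>
#include <cassert>
#include <iostream>
#include <vector>

using index_t = int;
constexpr std::size_t dimension = 4; // stands in for NumDimG + NumDimM + NumDimN

// The removed pattern: copy a runtime-sized vector into a compile-time-sized array
// so it can be handed to an array-base interface.
std::array<index_t, dimension> to_array(const std::vector<index_t>& v)
{
    assert(v.size() >= dimension);
    std::array<index_t, dimension> a{};
    std::copy(v.begin(), v.begin() + dimension, a.begin());
    return a;
}

int main()
{
    const std::vector<index_t> lengths{2, 8, 64, 32};
    const auto lengths_ = to_array(lengths); // trailing underscore mirrors the old locals
    std::cout << lengths_[3] << '\n';        // 32
}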
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp

@@ -16,15 +16,14 @@ template <index_t NumDimG,
           index_t NumDimM,
           index_t NumDimN,
           device::TensorSpecialization TensorSpec>
-__host__ __device__ static auto
-MakeGridDescriptorPair(const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_lengths_vec,
-                       const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_strides_vec)
+static auto MakeGridDescriptorPair(const std::vector<index_t>& gs_ms_ns_lengths_vec,
+                                   const std::vector<index_t>& gs_ms_ns_strides_vec)
 {
-    // if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
-    //      gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
-    // {
-    //     throw std::runtime_error("wrong! dimension must match input lengths");
-    // }
+    if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
+         gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
+    {
+        throw std::runtime_error("wrong! dimension must match input lengths");
+    }

     const auto to_tuple = [&](auto& vec, auto start, auto end) {
         return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});

@@ -144,24 +143,21 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
     //
     // A
     //
-    __host__ __device__ static auto MakeAGridDescriptorPair(
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
+    static auto MakeAGridDescriptorPair(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
+                                        const std::vector<index_t>& a_gs_ms_ks_strides_vec)
     {
         return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
                                                                         a_gs_ms_ks_strides_vec);
     }

     // TODO: rename to G_MRaw_KRaw
-    __host__ __device__ static auto MakeAGridDescriptor_G_M_K(
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
+    static auto MakeAGridDescriptor_G_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
+                                          const std::vector<index_t>& a_gs_ms_ks_strides_vec)
     {
         return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
     }

-    __host__ __device__ static auto MakeAGridDescriptor_M_K(
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
+    static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
+                                        const std::vector<index_t>& a_gs_ms_ks_strides_vec)
     {
         return matrix_padder.PadADescriptor_M_K(
             MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);

@@ -183,57 +179,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }

-    template <typename AGridDesc_M_K,
-              typename WmmaK,
-              typename MRepeat,
-              typename MWaves,
-              typename MPerWmma,
-              typename AK1>
-    __host__ __device__ static constexpr auto
-    MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
-        const AGridDesc_M_K& a_grid_desc_m_k,
-        const WmmaK&,
-        const MRepeat&,
-        const MWaves&,
-        const MPerWmma&,
-        const AK1&)
-    {
-        const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
-        const auto K  = a_grid_desc_m_k.GetLength(I1);
-
-        const auto AKWmma         = K / WmmaK{};
-        constexpr auto AKRow      = 2;
-        constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{};
-
-        return transform_tensor_descriptor(
-            a_grid_desc_m_k,
-            make_tuple(make_unmerge_transform(
-                           make_tuple(AKWmma, Number<AK0PerWmma>{}, Number<AKRow>{}, AK1{})),
-                       make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
-            make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
-    }
-
     //
     // B (alias of B0)
     //
-    __host__ __device__ static auto MakeB0GridDescriptorPair(
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
+    static auto MakeB0GridDescriptorPair(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
+                                         const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
     {
         return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
                                                                          b0_gs_ns_ks_strides_vec);
     }

     // TODO: rename to G_MRaw_NRaw
-    __host__ __device__ static auto MakeB0GridDescriptor_G_N_K(
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
+    static auto MakeB0GridDescriptor_G_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
+                                           const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
     {
         return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
     }

-    __host__ __device__ static auto MakeB0GridDescriptor_N_K(
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
+    static auto MakeB0GridDescriptor_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
+                                         const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
     {
         // alias of matrix_padder.PadB0Descriptor_N_K
         return matrix_padder.PadBDescriptor_N_K(

@@ -256,57 +219,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }

-    template <typename BGridDesc_L_K,
-              typename WmmaK,
-              typename LRepeat,
-              typename LWaves,
-              typename LPerWmma,
-              typename BK1>
-    __host__ __device__ static constexpr auto
-    MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
-        const BGridDesc_L_K& b_grid_desc_l_k,
-        const WmmaK&,
-        const LRepeat&,
-        const LWaves&,
-        const LPerWmma&,
-        const BK1&)
-    {
-        const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
-        const auto K  = b_grid_desc_l_k.GetLength(I1);
-
-        const auto BKWmma         = K / WmmaK{};
-        constexpr auto BKRow      = 2;
-        constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{};
-
-        return transform_tensor_descriptor(
-            b_grid_desc_l_k,
-            make_tuple(make_unmerge_transform(
-                           make_tuple(BKWmma, Number<BK0PerWmma>{}, Number<BKRow>{}, BK1{})),
-                       make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
-            make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
-    }
-
     //
     // B1
     //
-    __host__ __device__ static auto MakeB1GridDescriptorPair(
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
+    static auto MakeB1GridDescriptorPair(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
+                                         const std::vector<index_t>& b1_gs_os_ns_strides_vec)
     {
         return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
                                                                          b1_gs_os_ns_strides_vec);
     }

     // TODO: rename to G_NRaw_KRaw
-    __host__ __device__ static auto MakeB1GridDescriptor_G_N_K(
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
+    static auto MakeB1GridDescriptor_G_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
+                                           const std::vector<index_t>& b1_gs_os_ns_strides_vec)
    {
         return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
     }

-    __host__ __device__ static auto MakeB1GridDescriptor_N_K(
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
+    static auto MakeB1GridDescriptor_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
+                                         const std::vector<index_t>& b1_gs_os_ns_strides_vec)
     {
         // alias of matrix_padder.PadB1Descriptor_O_N
         return matrix_padder.PadB1Descriptor_N_K(

@@ -330,57 +260,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }

-    template <typename BGridDesc_N_L,
-              typename WmmaL,
-              typename NRepeat,
-              typename NWaves,
-              typename NPerWmma,
-              typename BL1>
-    __host__ __device__ static constexpr auto
-    MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
-        const BGridDesc_N_L& b_grid_desc_n_l,
-        const WmmaL&,
-        const NRepeat&,
-        const NWaves&,
-        const NPerWmma&,
-        const BL1&)
-    {
-        const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
-        const auto L  = b_grid_desc_n_l.GetLength(I1);
-
-        const auto BLWmma         = L / WmmaL{};
-        constexpr auto BLRow      = 2;
-        constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{};
-
-        return transform_tensor_descriptor(
-            b_grid_desc_n_l,
-            make_tuple(make_unmerge_transform(
-                           make_tuple(BLWmma, Number<BL0PerWmma>{}, Number<BLRow>{}, BL1{})),
-                       make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
-            make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
-    }
-
     //
     // C
     //
-    __host__ __device__ static auto MakeCGridDescriptorPair(
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
+    static auto MakeCGridDescriptorPair(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
+                                        const std::vector<index_t>& c_gs_ms_os_strides_vec)
     {
         return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
                                                                         c_gs_ms_os_strides_vec);
     }

     // TODO: rename to G_MRaw_NRaw
-    __host__ __device__ static auto MakeCGridDescriptor_G_M_N(
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
+    static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
+                                          const std::vector<index_t>& c_gs_ms_os_strides_vec)
     {
         return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
     }

-    __host__ __device__ static auto MakeCGridDescriptor_M_N(
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
-        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
+    static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
+                                        const std::vector<index_t>& c_gs_ms_os_strides_vec)
     {
         return matrix_padder.PadCDescriptor_M_N(
             MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
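Both the vector-base and array-base MakeGridDescriptorPair keep the same to_tuple lambda: it reads a run-time container element by element but produces a tuple whose arity is a compile-time Number, via generate_tuple. A simplified standalone version of that idea using std::index_sequence in place of CK's generate_tuple/Number; illustrative only, with arbitrary example numbers:

#include <array>
#include <cstddef>
#include <iostream>
#include <tuple>
#include <utility>

using index_t = int;

// Build a std::tuple of End - Start elements, each read from the container at run time,
// while the tuple size itself is fixed at compile time.
template <std::size_t Start, typename Vec, std::size_t... Is>
auto to_tuple_impl(const Vec& vec, std::index_sequence<Is...>)
{
    return std::make_tuple(vec[Start + Is]...);
}

template <std::size_t Start, std::size_t End, typename Vec>
auto to_tuple(const Vec& vec)
{
    return to_tuple_impl<Start>(vec, std::make_index_sequence<End - Start>{});
}

int main()
{
    const std::array<index_t, 5> gs_ms_ns_lengths{3, 4, 16, 8, 32}; // example numbers
    const auto ms_ns = to_tuple<2, 5>(gs_ms_ns_lengths);            // drop the two G dims
    std::cout << std::get<0>(ms_ns) << ' ' << std::get<2>(ms_ns) << '\n'; // 16 32
}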
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"

namespace ck {
namespace tensor_operation {

// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
template <index_t NumDimG, index_t NumDimM, index_t NumDimN, device::TensorSpecialization TensorSpec>
__host__ __device__ static auto
MakeGridDescriptorPair(const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_lengths_vec,
                       const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_strides_vec)
{
    // if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
    //      gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
    // {
    //     throw std::runtime_error("wrong! dimension must match input lengths");
    // }

    const auto to_tuple = [&](auto& vec, auto start, auto end) {
        return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
    };

    const auto gs_ms_ns_lengths =
        to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
    const auto gs_ms_ns_strides =
        to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});

    // dimension Ids for G0, G1, ...
    constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};

    // dimension Ids for M0, M1, ...
    constexpr auto mDimIds = typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};

    // dimension Ids for N0, N1, ...
    constexpr auto nDimIds =
        typename arithmetic_sequence_gen<NumDimG + NumDimM, NumDimG + NumDimM + NumDimN, 1>::type{};

    // lengths for G0, G1, ...
    const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds);

    // lengths for M0, M1, ...
    const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds);

    // lengths for N0, N1, ...
    const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds);

    if constexpr(TensorSpec == device::TensorSpecialization::Packed)
    {
        auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
        auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
        auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
        const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
            make_tuple(G, M, N),
            make_tuple(gs_ms_ns_strides[Number<NumDimG - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
        const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor(
            make_tuple(M, N),
            make_tuple(gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));

        return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
    }
    else
    {
        // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
        const auto grid_desc_gs_ms_ns =
            make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides);

        // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
        // N2 * ...]
        // Note: This does not require padding as it only provides G offset calculation. Technically
        // descriptor for only G is needed. Here we opt for backward compatibility purpose to return
        // G_M_N
        const auto grid_desc_g_mraw_nraw = transform_tensor_descriptor(
            grid_desc_gs_ms_ns,
            make_tuple(make_merge_transform(gLengths),
                       make_merge_transform(mLengths),
                       make_merge_transform(nLengths)),
            make_tuple(gDimIds, mDimIds, nDimIds),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        const auto c_ms_ns_lengths =
            to_tuple(gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto c_ms_ns_strides =
            to_tuple(gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
        // N2 * ...]
        const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);

        const auto grid_desc_mraw_nraw = transform_tensor_descriptor(
            grid_desc_ms_ns,
            make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
            make_tuple(mDimIds - Number<NumDimG>{}, nDimIds - Number<NumDimG>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
    }
}

template <typename NumDims_G_M_N_K_O, // Sequence<>
          typename PerBlock_M_N_K_O,  // Sequence<>
          device::GemmSpecialization GemmSpec,
          device::TensorSpecialization ASpec,
          device::TensorSpecialization B0Spec,
          device::TensorSpecialization B1Spec,
          device::TensorSpecialization CSpec>
struct TransformBatchedContractionContractionToBatchedGemmGemm_Wmma
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};

    static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0);
    static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1);
    static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2);
    static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3);
    static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4);

    static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0);
    static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1);
    static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2);
    static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3);

    static constexpr auto matrix_padder =
        device::GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
            MPerBlock, NPerBlock, KPerBlock, OPerBlock};

    //
    // A
    //
    __host__ __device__ static auto MakeAGridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
                                                                        a_gs_ms_ks_strides_vec);
    }

    // TODO: rename to G_MRaw_KRaw
    __host__ __device__ static auto MakeAGridDescriptor_G_M_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
    {
        return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
    }

    __host__ __device__ static auto MakeAGridDescriptor_M_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
    {
        return matrix_padder.PadADescriptor_M_K(
            MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
    }

    template <typename AGridDesc_M_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK0 = K / AK1;

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                       make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    template <typename AGridDesc_M_K,
              typename WmmaK,
              typename MRepeat,
              typename MWaves,
              typename MPerWmma,
              typename AK1>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
        const AGridDesc_M_K& a_grid_desc_m_k,
        const WmmaK&,
        const MRepeat&,
        const MWaves&,
        const MPerWmma&,
        const AK1&)
    {
        const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
        const auto K  = a_grid_desc_m_k.GetLength(I1);

        const auto AKWmma         = K / WmmaK{};
        constexpr auto AKRow      = 2;
        constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{};

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(
                           make_tuple(AKWmma, Number<AK0PerWmma>{}, Number<AKRow>{}, AK1{})),
                       make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
    }

    //
    // B (alias of B0)
    //
    __host__ __device__ static auto MakeB0GridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
                                                                         b0_gs_ns_ks_strides_vec);
    }

    // TODO: rename to G_MRaw_NRaw
    __host__ __device__ static auto MakeB0GridDescriptor_G_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
    {
        return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
    }

    __host__ __device__ static auto MakeB0GridDescriptor_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
    {
        // alias of matrix_padder.PadB0Descriptor_N_K
        return matrix_padder.PadBDescriptor_N_K(
            MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second);
    }

    template <typename BGridDesc_N_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1)
    {
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

        const auto BK0 = K / BK1;

        return transform_tensor_descriptor(
            b_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    template <typename BGridDesc_L_K,
              typename WmmaK,
              typename LRepeat,
              typename LWaves,
              typename LPerWmma,
              typename BK1>
    __host__ __device__ static constexpr auto
    MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
        const BGridDesc_L_K& b_grid_desc_l_k,
        const WmmaK&,
        const LRepeat&,
        const LWaves&,
        const LPerWmma&,
        const BK1&)
    {
        const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
        const auto K  = b_grid_desc_l_k.GetLength(I1);

        const auto BKWmma         = K / WmmaK{};
        constexpr auto BKRow      = 2;
        constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{};

        return transform_tensor_descriptor(
            b_grid_desc_l_k,
            make_tuple(make_unmerge_transform(
                           make_tuple(BKWmma, Number<BK0PerWmma>{}, Number<BKRow>{}, BK1{})),
                       make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
    }

    //
    // B1
    //
    __host__ __device__ static auto MakeB1GridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
                                                                         b1_gs_os_ns_strides_vec);
    }

    // TODO: rename to G_NRaw_KRaw
    __host__ __device__ static auto MakeB1GridDescriptor_G_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
    {
        return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
    }

    __host__ __device__ static auto MakeB1GridDescriptor_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
    {
        // alias of matrix_padder.PadB1Descriptor_O_N
        return matrix_padder.PadB1Descriptor_N_K(
            MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second);
    }

    template <typename B1GridDesc_N_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1)
    {
        const auto N = b1_grid_desc_n_k.GetLength(I0);
        const auto K = b1_grid_desc_n_k.GetLength(I1);

        const auto B1K0 = K / B1K1;

        return transform_tensor_descriptor(
            b1_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    template <typename BGridDesc_N_L,
              typename WmmaL,
              typename NRepeat,
              typename NWaves,
              typename NPerWmma,
              typename BL1>
    __host__ __device__ static constexpr auto
    MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
        const BGridDesc_N_L& b_grid_desc_n_l,
        const WmmaL&,
        const NRepeat&,
        const NWaves&,
        const NPerWmma&,
        const BL1&)
    {
        const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
        const auto L  = b_grid_desc_n_l.GetLength(I1);

        const auto BLWmma         = L / WmmaL{};
        constexpr auto BLRow      = 2;
        constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{};

        return transform_tensor_descriptor(
            b_grid_desc_n_l,
            make_tuple(make_unmerge_transform(
                           make_tuple(BLWmma, Number<BL0PerWmma>{}, Number<BLRow>{}, BL1{})),
                       make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
    }

    //
    // C
    //
    __host__ __device__ static auto MakeCGridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
                                                                        c_gs_ms_os_strides_vec);
    }

    // TODO: rename to G_MRaw_NRaw
    __host__ __device__ static auto MakeCGridDescriptor_G_M_N(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
    {
        return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
    }

    __host__ __device__ static auto MakeCGridDescriptor_M_N(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
    {
        return matrix_padder.PadCDescriptor_M_N(
            MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
    }
};

} // namespace tensor_operation
} // namespace ck
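For TensorSpecialization::Packed the pair-maker collapses each dimension group into a single length (the product of the group's lengths) and takes the stride of the group's innermost dimension, instead of building the full multi-dimensional descriptor. A small standalone illustration of that collapse for a packed [G0, G1, M0, M1, N0] tensor; all values are arbitrary and the code is only a sketch of the indexing used above:

#include <array>
#include <cstddef>
#include <iostream>

using index_t = long;

int main()
{
    // lengths/strides for a packed tensor with NumDimG = 2, NumDimM = 2, NumDimN = 1
    const std::array<index_t, 5> lengths{4, 3, 16, 8, 32};
    std::array<index_t, 5> strides{};
    index_t s = 1;
    for(int i = 4; i >= 0; --i) { strides[i] = s; s *= lengths[i]; } // packed layout

    // Collapse each group into one length; the group's stride is the innermost dim's stride,
    // i.e. the indices NumDimG-1, NumDimG+NumDimM-1, NumDimG+NumDimM+NumDimN-1 used above.
    const index_t G = lengths[0] * lengths[1];
    const index_t M = lengths[2] * lengths[3];
    const index_t N = lengths[4];
    std::cout << "G=" << G << " strideG=" << strides[1] << "\n"
              << "M=" << M << " strideM=" << strides[3] << "\n"
              << "N=" << N << " strideN=" << strides[4] << "\n";
}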