gaoqiong / composable_kernel

Commit 13587ab3 (unverified)
Authored Oct 23, 2022 by arai713; committed by GitHub on Oct 23, 2022

    Merge branch 'develop' into gridwise_2d

Parents: 7e44fd84, 685860c2
Changes: 338
Showing 20 changed files with 2866 additions and 66 deletions (+2866, -66)
example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp                                           +1     -1
example/42_groupnorm/groupnorm_sigmoid_fp16.cpp                                                               +25    -25
example/43_splitk_gemm_bias_e_permute/CMakeLists.txt                                                          +2     -0
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp                                 +407   -0
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp                                 +407   -0
example/44_elementwise_permute/CMakeLists.txt                                                                 +1     -0
example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp                                                +115   -0
include/ck/tensor_operation/gpu/device/device_normalization.hpp                                              +9     -36
include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp                              +65    -0
include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp                 +1147  -0
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp           +0     -0
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp                            +683   -0
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp                        +0     -0
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp                              +3     -3
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp  +0     -0
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp                      +0     -0
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp        +1     -1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp                +0     -0
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp                                      +0     -0
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp                              +0     -0
example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp

@@ -8,7 +8,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
example/42_groupnorm/groupnorm_sigmoid_fp16.cpp

@@ -9,7 +9,7 @@
 #include "ck/ck.hpp"
 #include "ck/utility/reduction_enums.hpp"
-#include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
 #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
 #include "ck/library/utility/fill.hpp"

@@ -47,7 +47,7 @@ struct YElementOp
 };

-using DeviceInstance = ck::tensor_operation::device::DeviceLayernormImpl<XDataType,
+using DeviceInstance = ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
                                                                          GammaDataType,
                                                                          BetaDataType,
                                                                          AccDataType,

@@ -55,26 +55,26 @@ using DeviceInstance =
                                                                          YElementOp,
                                                                          Rank,
                                                                          NumReduceDim,
-                                                                         256,   // BlockSize
-                                                                         8,     // ClusterM
-                                                                         32,    // ClusterK
+                                                                         1024,  // BlockSize
+                                                                         1,     // ClusterM
+                                                                         1024,  // ClusterK
                                                                          1,     // SliceM
-                                                                         8,     // SliceK
+                                                                         32,    // SliceK
                                                                          1,     // SrcVecDim (0=M, 1=K)
-                                                                         8,     // SrcScalarPerVector
+                                                                         2,     // SrcScalarPerVector
                                                                          1,     // GammaVecDim (0=M, 1=K)
-                                                                         8,     // GammaScalarPerVector
+                                                                         2,     // GammaScalarPerVector
                                                                          1,     // BetaVecDim (0=M, 1=K)
-                                                                         8,     // BetaScalarPerVector
-                                                                         8>;    // OutScalarPerVector
+                                                                         2,     // BetaScalarPerVector
+                                                                         2>;    // OutScalarPerVector

 int main(int argc, char* argv[])
 {
-    ck::index_t N = 128;
-    ck::index_t H = 16;
-    ck::index_t W = 16;
+    ck::index_t N = 2;
+    ck::index_t H = 32;
+    ck::index_t W = 32;
     ck::index_t G = 32;
-    ck::index_t C = 40;
+    ck::index_t C = 30;

     if(argc == 1)
     {
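For orientation (a reader's note, not part of the commit): this example feeds group normalization into a sigmoid (the YElementOp in the file). Using the usual group-norm definition, with mean and variance taken per sample n and group g over the spatial and per-group channel dimensions, the output is roughly

    y = \sigma\!\left(\gamma \cdot \frac{x - \mu_{n,g}}{\sqrt{\mathrm{Var}_{n,g}[x] + \varepsilon}} + \beta\right),
    \qquad \sigma(t) = \frac{1}{1 + e^{-t}}

The epsilon term and exact index layout ([N, H, W, G, C], matching the N/H/W/G/C problem sizes changed above) are the conventional choices and not shown in this hunk.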
example/43_splitk_gemm_bias_e_permute/CMakeLists.txt  (new file, mode 100644)

add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp)
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp)
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add         = ck::tensor_operation::element_wise::Add;

using ADataType        = F16;
using BDataType        = F16;
using AccDataType      = F32;
using CShuffleDataType = F16;
using DDataType        = F16;
using DsDataType       = ck::Tuple<DDataType>;
using EDataType        = F16;

static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 1;

using AElementOp   = ck::tensor_operation::element_wise::PassThrough;
using BElementOp   = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto ABSpec   = ck::tensor_operation::device::TensorSpecialization::Packed;
static constexpr auto DESpec   = ck::tensor_operation::device::TensorSpecialization::Default;

// clang-format off
//############################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| Gemm| A| B| DE| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle|CBlockTransferClusterLengths| CBlockTransfer|
//############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
using DeviceOpInstanceKKNN = ck::tensor_operation::device::DeviceSplitKContractionMultipleD_Xdl_CShuffle
    <NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16,
     AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec,
     1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
     S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
     S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
     1, 1, S<1, 32, 1, 4>, 8>;
// clang-format on

using DeviceOpInstance = DeviceOpInstanceKKNN;

// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimG, ck::index_t NumDimM, ck::index_t NumDimN, ck::index_t NumDimK,
          typename ADataType, typename BDataType, typename EDataType, typename AccDataType,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          ck::enable_if_t<NumDimG == 2 && NumDimM == 2 && NumDimN == 2 && NumDimK == 1, bool> = false>
struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator
{
    // Argument
    struct Argument : public ck::tensor_operation::device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_gs_ms_ks,
                 const Tensor<BDataType>& b_gs_ns_ks,
                 Tensor<EDataType>& e_gs_ms_ns,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : a_gs_ms_ks_{a_gs_ms_ks},
              b_gs_ns_ks_{b_gs_ns_ks},
              e_gs_ms_ns_{e_gs_ms_ns},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
        }

        const Tensor<ADataType>& a_gs_ms_ks_;
        const Tensor<BDataType>& b_gs_ns_ks_;
        Tensor<EDataType>& e_gs_ms_ns_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };

    // Invoker
    struct Invoker : public ck::tensor_operation::device::BaseInvoker
    {
        using Argument = ReferenceContraction_G2_M2_N2_K1::Argument;

        float Run(const Argument& arg)
        {
            auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) {
                const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];

                AccDataType v_acc = 0;

                for(int k0 = 0; k0 < K0; ++k0)
                {
                    AccDataType v_a;
                    AccDataType v_b;

                    arg.a_element_op_(
                        v_a, ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0)));
                    arg.b_element_op_(
                        v_b, ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0)));

                    v_acc += v_a * v_b;
                }

                AccDataType v_c;

                arg.cde_element_op_(v_c, v_acc);

                arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c;
            };

            make_ParallelTensorFunctor(f_ms_ns,
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
                std::thread::hardware_concurrency());

            return 0;
        }

        float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
    {
        return true;
    }

    static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
                             const Tensor<BDataType>& b_gs_ns_ks,
                             Tensor<EDataType>& e_gs_ms_ns,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceContraction_G2_M2_N2_K1" << std::endl;
        // clang-format on

        return str.str();
    }
};

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    int split_k = 1;

    ck::index_t G0 = 1;
    ck::index_t G1 = 2;

    ck::index_t M0 = 4;
    ck::index_t M1 = 256;

    ck::index_t N0 = 16;
    ck::index_t N1 = 128;

    ck::index_t K0 = 64 * 2;

    // A[G0, G1, M0, M1, K0]
    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
    std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
    // B[G0, G1, N0, N1, K0]
    std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
    std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
    // D[G0, G1, M0, N0, M1, N1]
    std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
    std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
    // E[G0, G1, M0, N0, M1, N1]
    std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
    std::vector<ck::index_t> e_gs_ms_ns_strides{
        G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 5)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
        split_k         = std::stoi(argv[4]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        exit(0);
    }

    Tensor<ADataType> a_gs_ms_ks(
        std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()),
        std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()));
    Tensor<BDataType> b_gs_ns_ks(
        std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()),
        std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
    Tensor<DDataType> d_gs_ms_ns(
        std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
        std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
    Tensor<EDataType> e_gs_ms_ns_host_result(
        std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
        std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
    Tensor<EDataType> e_gs_ms_ns_device_result(
        std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
        std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));

    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
    std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
    std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
    std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        break;
    case 2:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        break;
    default:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
        break;
    }

    DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf(sizeof(EDataType) * e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
    b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
    d_device_buf.ToDevice(d_gs_ms_ns.mData.data());

    // set zero
    e_device_buf.SetZero();

    auto a_element_op   = AElementOp{};
    auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{};

    // device operation
    auto op       = DeviceOpInstance{};
    auto invoker  = op.MakeInvoker();
    auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                    b_device_buf.GetDeviceBuffer(),
                                    std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
                                    e_device_buf.GetDeviceBuffer(),
                                    a_gs_ms_ks_lengths,
                                    a_gs_ms_ks_strides,
                                    b_gs_ns_ks_lengths,
                                    b_gs_ns_ks_strides,
                                    std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
                                    std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
                                    e_gs_ms_ns_lengths,
                                    e_gs_ms_ns_strides,
                                    a_element_op,
                                    b_element_op,
                                    cde_element_op,
                                    split_k);

    if(!op.IsSupportedArgument(argument))
    {
        std::cout << op.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    ck::index_t G = std::accumulate(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.begin() + NumDimG,
                                    ck::index_t{1}, std::multiplies<ck::index_t>{});
    ck::index_t M = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG,
                                    e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
                                    ck::index_t{1}, std::multiplies<ck::index_t>{});
    ck::index_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
                                    e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM + NumDimN,
                                    ck::index_t{1}, std::multiplies<ck::index_t>{});
    ck::index_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
                                    a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM + NumDimK,
                                    ck::index_t{1}, std::multiplies<ck::index_t>{});

    std::size_t flop      = std::size_t(2) * G * M * N * K;
    std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
                            sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << op.GetTypeString() << std::endl;

    e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());

    if(do_verification)
    {
        Tensor<CShuffleDataType> c_ms_ns_host_result(
            std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
            std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));

        using ReferenceOpInstance =
            ReferenceContraction_G2_M2_N2_K1<NumDimG, NumDimM, NumDimN, NumDimK,
                                             ADataType, BDataType, CShuffleDataType, AccDataType,
                                             AElementOp, BElementOp, PassThrough>;

        auto ref_gemm    = ReferenceOpInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});

        ref_invoker.Run(ref_argument);

        e_gs_ms_ns_host_result.ForEach([&](auto&, auto idx) {
            cde_element_op(e_gs_ms_ns_host_result(idx), c_ms_ns_host_result(idx), d_gs_ms_ns(idx));
        });

        return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData)
                   ? 0
                   : 1;
    }

    return 0;
}
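A compact summary (a reader's gloss, not part of the commit) of what this example computes and verifies: D carries a bias with zero strides on the M dimensions, so it is broadcast over m0 and m1, and the E strides lay the output out in the permuted [G0, G1, M0, N0, M1, N1] order that gives the example its "e_permute" name. The contraction itself, and the FLOP count printed by the example, reduce to a batched GEMM over collapsed dimensions:

    E_{g_0 g_1 m_0 m_1 n_0 n_1} = \mathrm{Add}\Big(\sum_{k_0} A_{g_0 g_1 m_0 m_1 k_0}\, B_{g_0 g_1 n_0 n_1 k_0},\; D_{g_0 g_1 n_0 n_1}\Big)

    \mathrm{flop} = 2\,G M N K, \qquad G = G_0 G_1,\; M = M_0 M_1,\; N = N_0 N_1,\; K = K_0

The split_k command-line argument only changes how the K dimension is partitioned across workgroups; it does not change this result.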
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add         = ck::tensor_operation::element_wise::Add;

using ADataType        = F32;
using BDataType        = F32;
using AccDataType      = F32;
using CShuffleDataType = F32;
using DDataType        = F32;
using DsDataType       = ck::Tuple<DDataType>;
using EDataType        = F32;

static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 1;

using AElementOp   = ck::tensor_operation::element_wise::PassThrough;
using BElementOp   = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto ABSpec   = ck::tensor_operation::device::TensorSpecialization::Packed;
static constexpr auto DESpec   = ck::tensor_operation::device::TensorSpecialization::Default;

// clang-format off
//############################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| Gemm| A| B| DE| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle|CBlockTransferClusterLengths| CBlockTransfer|
//############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
using DeviceOpInstanceKKNN = ck::tensor_operation::device::DeviceSplitKContractionMultipleD_Xdl_CShuffle
    <NumDimG, NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType,
     AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec,
     1, 256, 256, 128, 32, 4, 4, 32, 32, 4, 2,
     S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1,
     S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1,
     1, 1, S<1, 32, 1, 4>, 4>;
// clang-format on

using DeviceOpInstance = DeviceOpInstanceKKNN;

// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimG, ck::index_t NumDimM, ck::index_t NumDimN, ck::index_t NumDimK,
          typename ADataType, typename BDataType, typename EDataType, typename AccDataType,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          ck::enable_if_t<NumDimG == 2 && NumDimM == 2 && NumDimN == 2 && NumDimK == 1, bool> = false>
struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator
{
    // Argument
    struct Argument : public ck::tensor_operation::device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_gs_ms_ks,
                 const Tensor<BDataType>& b_gs_ns_ks,
                 Tensor<EDataType>& e_gs_ms_ns,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : a_gs_ms_ks_{a_gs_ms_ks},
              b_gs_ns_ks_{b_gs_ns_ks},
              e_gs_ms_ns_{e_gs_ms_ns},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
        }

        const Tensor<ADataType>& a_gs_ms_ks_;
        const Tensor<BDataType>& b_gs_ns_ks_;
        Tensor<EDataType>& e_gs_ms_ns_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };

    // Invoker
    struct Invoker : public ck::tensor_operation::device::BaseInvoker
    {
        using Argument = ReferenceContraction_G2_M2_N2_K1::Argument;

        float Run(const Argument& arg)
        {
            auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) {
                const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];

                AccDataType v_acc = 0;

                for(int k0 = 0; k0 < K0; ++k0)
                {
                    AccDataType v_a;
                    AccDataType v_b;

                    arg.a_element_op_(
                        v_a, ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0)));
                    arg.b_element_op_(
                        v_b, ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0)));

                    v_acc += v_a * v_b;
                }

                AccDataType v_c;

                arg.cde_element_op_(v_c, v_acc);

                arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c;
            };

            make_ParallelTensorFunctor(f_ms_ns,
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
                std::thread::hardware_concurrency());

            return 0;
        }

        float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
    {
        return true;
    }

    static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
                             const Tensor<BDataType>& b_gs_ns_ks,
                             Tensor<EDataType>& e_gs_ms_ns,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceContraction_G2_M2_N2_K1" << std::endl;
        // clang-format on

        return str.str();
    }
};

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    int split_k = 1;

    ck::index_t G0 = 1;
    ck::index_t G1 = 2;

    ck::index_t M0 = 4;
    ck::index_t M1 = 256;

    ck::index_t N0 = 16;
    ck::index_t N1 = 128;

    ck::index_t K0 = 64 * 2;

    // A[G0, G1, M0, M1, K0]
    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
    std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
    // B[G0, G1, N0, N1, K0]
    std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
    std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
    // D[G0, G1, M0, N0, M1, N1]
    std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
    std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
    // E[G0, G1, M0, N0, M1, N1]
    std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
    std::vector<ck::index_t> e_gs_ms_ns_strides{
        G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 5)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
        split_k         = std::stoi(argv[4]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        exit(0);
    }

    Tensor<ADataType> a_gs_ms_ks(
        std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()),
        std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()));
    Tensor<BDataType> b_gs_ns_ks(
        std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()),
        std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
    Tensor<DDataType> d_gs_ms_ns(
        std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
        std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
    Tensor<EDataType> e_gs_ms_ns_host_result(
        std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
        std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
    Tensor<EDataType> e_gs_ms_ns_device_result(
        std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
        std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));

    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
    std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
    std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
    std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        break;
    case 2:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        break;
    default:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
        break;
    }

    DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf(sizeof(EDataType) * e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
    b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
    d_device_buf.ToDevice(d_gs_ms_ns.mData.data());

    // set zero
    e_device_buf.SetZero();

    auto a_element_op   = AElementOp{};
    auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{};

    // device operation
    auto op       = DeviceOpInstance{};
    auto invoker  = op.MakeInvoker();
    auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                    b_device_buf.GetDeviceBuffer(),
                                    std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
                                    e_device_buf.GetDeviceBuffer(),
                                    a_gs_ms_ks_lengths,
                                    a_gs_ms_ks_strides,
                                    b_gs_ns_ks_lengths,
                                    b_gs_ns_ks_strides,
                                    std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
                                    std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
                                    e_gs_ms_ns_lengths,
                                    e_gs_ms_ns_strides,
                                    a_element_op,
                                    b_element_op,
                                    cde_element_op,
                                    split_k);

    if(!op.IsSupportedArgument(argument))
    {
        std::cout << op.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    ck::index_t G = std::accumulate(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.begin() + NumDimG,
                                    ck::index_t{1}, std::multiplies<ck::index_t>{});
    ck::index_t M = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG,
                                    e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
                                    ck::index_t{1}, std::multiplies<ck::index_t>{});
    ck::index_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
                                    e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM + NumDimN,
                                    ck::index_t{1}, std::multiplies<ck::index_t>{});
    ck::index_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
                                    a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM + NumDimK,
                                    ck::index_t{1}, std::multiplies<ck::index_t>{});

    std::size_t flop      = std::size_t(2) * G * M * N * K;
    std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
                            sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << op.GetTypeString() << std::endl;

    e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());

    if(do_verification)
    {
        Tensor<CShuffleDataType> c_ms_ns_host_result(
            std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
            std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));

        using ReferenceOpInstance =
            ReferenceContraction_G2_M2_N2_K1<NumDimG, NumDimM, NumDimN, NumDimK,
                                             ADataType, BDataType, CShuffleDataType, AccDataType,
                                             AElementOp, BElementOp, PassThrough>;

        auto ref_gemm    = ReferenceOpInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});

        ref_invoker.Run(ref_argument);

        e_gs_ms_ns_host_result.ForEach([&](auto&, auto idx) {
            cde_element_op(e_gs_ms_ns_host_result(idx), c_ms_ns_host_result(idx), d_gs_ms_ns(idx));
        });

        return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData)
                   ? 0
                   : 1;
    }

    return 0;
}
example/44_elementwise_permute/CMakeLists.txt  (new file, mode 100644)

add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp)
example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp  (new file, mode 100644)

#include <iostream>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

using F16 = ck::half_t;
using F32 = float;

using ADataType = F16;
using BDataType = F16;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using DeviceElementwisePermuteInstance =
    ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ADataType>,
                                                    ck::Tuple<BDataType>,
                                                    PassThrough,
                                                    4,
                                                    8,
                                                    ck::Sequence<8>,
                                                    ck::Sequence<1>>;

template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
    for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
        for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
            for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
                for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
                {
                    auto a_val = A_nchw(n, c, h, w);
                    functor(B_nhwc(n, h, w, c), a_val);
                }
}

int main()
{
    bool do_verification = true;
    bool time_kernel     = true;

    std::vector<std::size_t> nchw = {16, 128, 32, 64};
    std::vector<std::size_t> nhwc = {16, 32, 64, 128};
    Tensor<ADataType> a(nchw);
    Tensor<BDataType> b(nhwc);

    a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

    DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a.mData.data());

    std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
    std::array<void*, 1> output      = {b_device_buf.GetDeviceBuffer()};

    std::array<ck::index_t, 4> ab_lengths;
    std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
                                            static_cast<int>(nchw[2] * nchw[3]),
                                            static_cast<int>(nchw[3]),
                                            1};
    std::array<ck::index_t, 4> b_strides = {static_cast<int>(nhwc[1] * nhwc[2] * nhwc[3]),
                                            1,
                                            static_cast<int>(nhwc[2] * nhwc[3]),
                                            static_cast<int>(nhwc[3])};
    std::copy(nchw.begin(), nchw.end(), ab_lengths.begin());

    auto broadcastPermute = DeviceElementwisePermuteInstance{};
    auto argument         = broadcastPermute.MakeArgumentPointer(
        ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});

    if(!broadcastPermute.IsSupportedArgument(argument.get()))
    {
        throw std::runtime_error(
            "The runtime parameters seems not supported by the device instance, exiting!");
    };

    std::cout << "A (nchw): " << a.mDesc << std::endl;
    std::cout << "B (nhwc): " << b.mDesc << std::endl;

    auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
    float ave_time =
        broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});

    std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
    std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
                            sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);

    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
              << std::endl;

    bool pass = true;

    if(do_verification)
    {
        b_device_buf.FromDevice(b.mData.data());
        Tensor<BDataType> host_b(nhwc);
        host_elementwise4D(host_b, a, PassThrough{});

        pass &=
            ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
    }

    return pass ? 0 : 1;
}
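One detail worth spelling out (a reader's gloss, not part of the commit): both tensors are described to the device op with the same NCHW index order (ab_lengths = nchw), and the NCHW-to-NHWC permutation comes entirely from b_strides, which express an NHWC memory layout against those NCHW indices. A minimal standalone sketch of that stride computation, mirroring the b_strides initializer above (the helper name is illustrative, not a composable_kernel API):

    #include <array>
    #include <cstddef>

    // Strides of an NHWC buffer, written in NCHW index order (n, c, h, w).
    // Element (n, c, h, w) then lands at offset n*H*W*C + c*1 + h*W*C + w*C,
    // which is exactly the NHWC layout the example writes into.
    std::array<std::size_t, 4> nhwc_strides_in_nchw_order(std::size_t C, std::size_t H, std::size_t W)
    {
        return {H * W * C, // stride of n
                1,         // stride of c (fastest-varying in NHWC)
                W * C,     // stride of h
                C};        // stride of w
    }

With the example's sizes (C = 128, H = 32, W = 64) this reproduces the {H*W*C, 1, W*C, C} array passed as b_strides.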
include/ck/tensor_operation/gpu/device/device_normalization.hpp

@@ -11,33 +11,6 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {

-struct DeviceNormalization : public BaseOperator
-{
-    // inLengths: input tensor extent(s) from high to low dimension
-    // inStrides: input tensor stride(s) from high to low dimension
-    // reduceDims: the dimension(s) the normalization operation is applied
-    // alpha: typeless pointer in host memory storing the alpha scaling value of type AccDataType
-    // beta: typeless pointer in host memory storing the beta scaling value of type AccDataType
-    // in_dev: typeless const pointer in device memory storing the input tensor
-    // out_dev: typeless pointer in device memory storing the output tensor
-    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
-                                                              const std::vector<index_t> inStrides,
-                                                              const std::vector<int> reduceDims,
-                                                              const void* alpha,
-                                                              const void* beta,
-                                                              const void* in_dev,
-                                                              void* out_dev) = 0;
-
-    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
-
-    virtual index_t GetRank() const = 0;
-
-    virtual index_t GetNumReduceDim() const = 0;
-};
-
-using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization>;
-
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,

@@ -46,7 +19,7 @@ template <typename XDataType,
           typename AccElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim>
-struct DeviceLayernorm : public BaseOperator
+struct DeviceNormalization : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> lengths,

@@ -73,7 +46,7 @@ template <typename XDataType,
           typename AccElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim>
-using DeviceLayernormPtr = std::unique_ptr<DeviceLayernorm<XDataType,
+using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
                                                            GammaDataType,
                                                            BetaDataType,
                                                            AccDataType,
include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// Tensor Contraction:
//   input : A
//   input : B
//   input : D0, D1, ...
//   output : E
//   C = a_op(A) * b_op(B)
//   E = cde_op(C, D0, D1, ...)
// Assume:
//   A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
//   B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
//   D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
//   E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation>
struct DeviceSplitKContractionMultipleD : public BaseOperator
{
    static constexpr index_t NumDTensor = DsDataType::Size();

    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const void* p_a,
        const void* p_b,
        std::array<const void*, NumDTensor> p_ds,
        void* p_e,
        const std::vector<index_t>& a_gs_ms_ns_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
        const std::vector<index_t>& b_gs_ns_ks_lengths,
        const std::vector<index_t>& b_gs_ns_ks_strides,
        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
        const std::vector<index_t>& e_gs_ms_ns_lengths,
        const std::vector<index_t>& e_gs_ms_ns_strides,
        AElementwiseOperation a_element_op,
        BElementwiseOperation b_element_op,
        CDEElementwiseOperation cde_element_op,
        index_t split_k) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
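In index form (a paraphrase of the comment block in the header above, writing a single batch index g and multi-indices for the M, N, K groups; not an addition to the interface itself):

    C_{g,\vec m,\vec n} = \sum_{\vec k} a\_\mathrm{op}\!\big(A_{g,\vec m,\vec k}\big)\, b\_\mathrm{op}\!\big(B_{g,\vec n,\vec k}\big),
    \qquad
    E_{g,\vec m,\vec n} = cde\_\mathrm{op}\!\big(C_{g,\vec m,\vec n},\, D^{(0)}_{g,\vec m,\vec n},\, D^{(1)}_{g,\vec m,\vec n},\, \dots\big)

The split_k argument of MakeArgumentPointer partitions the \vec k reduction across workgroups; it does not appear in the result, only in how the sum is scheduled.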
include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatDsPointer,
          typename FloatE,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename AGridDesc_AKB_AK0_M_AK1,
          typename BGridDesc_BKB_BK0_N_BK1,
          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename ComputePtrOffsetOfBatch,
          typename Block2ETileMap,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_contraction_multiple_d_xdl_cshuffle(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatDsPointer p_ds_grid,
            FloatE* __restrict__ p_e_grid,
            const index_t batch_count,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1,
            const BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
            const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

    const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

    FloatDsPointer p_ds_grid_grp;

    static constexpr index_t NumDTensor =
        DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();

    static_for<0, NumDTensor, 1>{}(
        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                  p_b_grid + b_batch_offset,
                                                  p_ds_grid_grp,
                                                  p_e_grid + e_batch_offset,
                                                  p_shared,
                                                  a_element_op,
                                                  b_element_op,
                                                  cde_element_op,
                                                  a_grid_desc_akb_ak0_m_ak1,
                                                  b_grid_desc_bkb_bk0_n_bk1,
                                                  ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  e_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  block_2_etile_map);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_ds_grid;
    ignore = p_e_grid;
    ignore = batch_count;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
    ignore = a_grid_desc_akb_ak0_m_ak1;
    ignore = b_grid_desc_bkb_bk0_n_bk1;
    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = block_2_etile_map;
    ignore = compute_ptr_offset_of_batch;
#endif
}

} // namespace ck
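The batch handling in the kernel above is plain integer arithmetic on the 1-D grid; a minimal host-side sketch of the same mapping (hypothetical helper name, shown only to make the division explicit):

    #include <cstdint>

    // Given a 1-D grid of grid_size workgroups covering batch_count batches,
    // each batch owns a contiguous slice of blocks; a block recovers its batch
    // index by integer division, as the kernel does before applying readfirstlane.
    std::int64_t batch_index_of_block(std::int64_t block_id,
                                      std::int64_t grid_size,
                                      std::int64_t batch_count)
    {
        const std::int64_t num_blocks_per_batch = grid_size / batch_count;
        return block_id / num_blocks_per_batch;
    }

The per-batch pointer offsets for A, B, E and the D tensors are then looked up from compute_ptr_offset_of_batch with that batch index.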
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
GemmSpecialization
GemmSpec
,
TensorSpecialization
ASpec
,
TensorSpecialization
BSpec
,
TensorSpecialization
DESpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceSplitKContractionMultipleD_Xdl_CShuffle
:
public
DeviceSplitKContractionMultipleD
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
using
DeviceOp
=
DeviceSplitKContractionMultipleD_Xdl_CShuffle
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
// Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
{
assert
(
a_gs_ms_ks_lengths_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimK
&&
a_gs_ms_ks_strides_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimK
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
start
,
auto
end
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
start
+
i
];
},
Number
<
end
-
start
>
{});
};
const
auto
a_ms_ks_lengths
=
to_tuple
(
a_gs_ms_ks_lengths_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimK
>
{});
const
auto
a_ms_ks_strides
=
to_tuple
(
a_gs_ms_ks_strides_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimK
>
{});
// dimension Ids for M0, M1, ...
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimM
,
1
>::
type
{};
// dimension Ids for K0, K1, ...
constexpr
auto
kDimIds
=
typename
arithmetic_sequence_gen
<
NumDimM
,
NumDimM
+
NumDimK
,
1
>::
type
{};
// lengths for M0, M1, ...
const
auto
mLengths
=
get_container_subset
(
a_ms_ks_lengths
,
mDimIds
);
// lengths for K0, K1, ...
const
auto
kLengths
=
get_container_subset
(
a_ms_ks_lengths
,
kDimIds
);
if
constexpr
(
ASpec
==
TensorSpecialization
::
Packed
)
{
auto
M
=
container_reduce
(
mLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
K
=
container_reduce
(
kLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
const
auto
a_grid_desc_mraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
a_ms_ks_strides
[
Number
<
NumDimM
-
1
>
{}],
a_ms_ks_strides
[
Number
<
NumDimM
+
NumDimK
-
1
>
{}]));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
else
{
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const
auto
a_grid_desc_ms_ks
=
make_naive_tensor_descriptor
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const
auto
a_grid_desc_mraw_kraw
=
transform_tensor_descriptor
(
a_grid_desc_ms_ks
,
make_tuple
(
make_merge_transform
(
mLengths
),
make_merge_transform
(
kLengths
)),
make_tuple
(
mDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
}
    // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
    static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
                                        const std::vector<index_t>& b_gs_ns_ks_strides_vec)
    {
        assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
               b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto b_ns_ks_lengths = to_tuple(
            b_gs_ns_ks_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimN + NumDimK>{});
        const auto b_ns_ks_strides = to_tuple(
            b_gs_ns_ks_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimN + NumDimK>{});

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};

        // dimension Ids for K0, K1, ...
        constexpr auto kDimIds =
            typename arithmetic_sequence_gen<NumDimN, NumDimN + NumDimK, 1>::type{};

        // lengths for K0, K1, ...
        const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);

        if constexpr(BSpec == TensorSpecialization::Packed)
        {
            auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
            auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});

            const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
                make_tuple(N, K),
                make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
                           b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));

            return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        }
        else
        {
            // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
            const auto b_grid_desc_ns_ks =
                make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);

            // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
            const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
                b_grid_desc_ns_ks,
                make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
                make_tuple(nDimIds, kDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        }
    }
    // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
    static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_gs_ms_ns_lengths_vec,
                                        const std::vector<index_t>& e_gs_ms_ns_strides_vec)
    {
        assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
               e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto e_ms_ns_lengths = to_tuple(
            e_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto e_ms_ns_strides = to_tuple(
            e_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds =
            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);

        if constexpr(DESpec == TensorSpecialization::Packed)
        {
            auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
            auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});

            const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor(
                make_tuple(M, N),
                make_tuple(e_ms_ns_strides[Number<NumDimM - 1>{}],
                           e_ms_ns_strides[Number<NumDimM + NumDimN - 1>{}]));

            return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
        }
        else
        {
            // naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
            const auto e_grid_desc_ms_ns =
                make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);

            // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
            const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
                e_grid_desc_ms_ns,
                make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
                make_tuple(mDimIds, nDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
        }
    }
    // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
    static auto MakeEGridDescriptor_G_M_N(const std::vector<index_t>& e_gs_ms_ns_lengths_vec,
                                          const std::vector<index_t>& e_gs_ms_ns_strides_vec)
    {
        assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
               e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto e_gs_ms_ns_lengths =
            to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto e_gs_ms_ns_strides =
            to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // dimension Ids for G0, G1, ...
        constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds =
            typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
                                                                  NumDimG + NumDimM + NumDimN,
                                                                  1>::type{};

        // lengths for G0, G1, ...
        const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds);

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds);

        if constexpr(DESpec == TensorSpecialization::Packed)
        {
            auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
            auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
            auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});

            const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
                make_tuple(G, M, N),
                make_tuple(e_gs_ms_ns_strides[Number<NumDimG - 1>{}],
                           e_gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
                           e_gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));

            // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
            return e_grid_desc_g_mraw_nraw;
        }
        else
        {
            // naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
            const auto e_grid_desc_gs_ms_ns =
                make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);

            // transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... ,
            // NRaw = N0 * N1 * N2 * ...]
            const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor(
                e_grid_desc_gs_ms_ns,
                make_tuple(make_merge_transform(gLengths),
                           make_merge_transform(mLengths),
                           make_merge_transform(nLengths)),
                make_tuple(gDimIds, mDimIds, nDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
            return e_grid_desc_g_mraw_nraw;
        }
    }
    static auto MakeDsGridDescriptor_M_N(
        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths_vec,
        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides_vec)
    {
        return generate_tuple(
            [&](auto i) {
                return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i],
                                                         ds_gs_ms_ns_strides_vec[i]);
            },
            Number<NumDTensor>{});
    }

    static auto MakeDsGridDescriptor_G_M_N(
        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths_vec,
        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides_vec)
    {
        return generate_tuple(
            [&](auto i) {
                return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i],
                                                           ds_gs_ms_ns_strides_vec[i]);
            },
            Number<NumDTensor>{});
    }
    using AGridDesc_M_K    = decltype(MakeAGridDescriptor_M_K({}, {}));
    using BGridDesc_N_K    = decltype(MakeBGridDescriptor_N_K({}, {}));
    using DsGridDesc_M_N   = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
    using EGridDesc_M_N    = decltype(MakeEGridDescriptor_M_N({}, {}));
    using DsGridDesc_G_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_G_M_N({}, {}))>;
    using EGridDesc_G_M_N  = decltype(MakeEGridDescriptor_G_M_N({}, {}));
    struct ComputePtrOffsetOfStridedBatch
    {
        ComputePtrOffsetOfStridedBatch(index_t batch_stride_A,
                                       index_t batch_stride_B,
                                       DsGridDesc_G_M_N ds_grid_desc_g_m_n,
                                       EGridDesc_G_M_N e_grid_desc_g_m_n)
            : batch_stride_A_(batch_stride_A),
              batch_stride_B_(batch_stride_B),
              ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n),
              e_grid_desc_g_m_n_(e_grid_desc_g_m_n)
        {
        }

        __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(batch_stride_A_);
        }

        __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(batch_stride_B_);
        }

        __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
        {
            std::array<long_index_t, NumDTensor> ds_offset;

            static_for<0, NumDTensor, 1>{}([&](auto i) {
                ds_offset[i] = static_cast<long_index_t>(g_idx) *
                               ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0));
            });

            return ds_offset;
        }

        __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
        {
            return static_cast<long_index_t>(g_idx) *
                   e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0));
        }

        private:
        index_t batch_stride_A_;
        index_t batch_stride_B_;
        DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
        EGridDesc_G_M_N e_grid_desc_g_m_n_;
    };
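    // Editorial note (not in the original source): for evenly strided batches the offsets above
    // are plain multiples of the batch stride, e.g. with a hypothetical batch_stride_A_ = 1024,
    // GetAPtrOffset(g) returns 0, 1024, 2048, ... for g = 0, 1, 2, ... For Ds/E the per-batch
    // stride is recovered from the G-major descriptor as CalculateOffset(make_multi_index(1, 0, 0)),
    // i.e. the element distance between consecutive batches.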
    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmSplitKMultipleD_xdl_cshuffle<
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CShuffleDataType,
        DsDataType,
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        AGridDesc_M_K,
        BGridDesc_N_K,
        DsGridDesc_M_N,
        EGridDesc_M_N,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;

    // GridwiseGemm
    using GridwiseGemmAtomicAdd = GridwiseGemmSplitKMultipleD_xdl_cshuffle<
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CShuffleDataType,
        DsDataType,
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        InMemoryDataOperationEnum::AtomicAdd,
        AGridDesc_M_K,
        BGridDesc_N_K,
        DsGridDesc_M_N,
        EGridDesc_M_N,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;
    using AGridDesc_AKB_AK0_M_AK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(AGridDesc_M_K{}, 1))>;
    using BGridDesc_BKB_BK0_N_BK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(BGridDesc_N_K{}, 1))>;

    using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const void* p_a_grid,
                 const void* p_b_grid,
                 std::array<const void*, NumDTensor> p_ds_grid,
                 void* p_e_grid,
                 const std::vector<index_t>& a_gs_ms_ns_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,
                 const std::vector<index_t>& b_gs_ns_ks_strides,
                 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
                 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
                 const std::vector<index_t>& e_gs_ms_ns_lengths,
                 const std::vector<index_t>& e_gs_ms_ns_strides,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op,
                 index_t split_k)
            : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
              p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
              p_ds_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e_grid)},
              a_grid_desc_m_k_{
                  DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ns_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_n_k_{
                  DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              ds_grid_desc_m_n_{},
              e_grid_desc_m_n_{
                  DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
              ds_grid_desc_g_m_n_{
                  DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)},
              e_grid_desc_g_m_n_{
                  DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
              a_grid_desc_akb_ak0_m_ak1_{
                  GridwiseGemm::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(a_grid_desc_m_k_,
                                                                         split_k)},
              b_grid_desc_bkb_bk0_n_bk1_{
                  GridwiseGemm::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(b_grid_desc_n_k_,
                                                                         split_k)},
              ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
              e_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_etile_map_{
                  GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, split_k)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op},
              a_mz_stride_{},
              a_kz_stride_{},
              b_nz_stride_{},
              b_kz_stride_{},
              ds_nz_stride_{},
              e_nz_stride_{},
              a_batch_stride_{a_gs_ms_ks_strides[NumDimG - 1]},
              b_batch_stride_{b_gs_ns_ks_strides[NumDimG - 1]},
              compute_ptr_offset_of_batch_{
                  a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_},
              split_k_{split_k}
        {
            static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0, "");

            // populate pointer, batch stride, desc for Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                // D pointer
                p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);

                // D desc
                ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths[i],
                                                                         ds_gs_ms_ns_strides[i]);
            });

            // populate desc for Ds/E
            if(GridwiseGemm::CheckValidity(a_grid_desc_akb_ak0_m_ak1_,
                                           b_grid_desc_bkb_bk0_n_bk1_,
                                           ds_grid_desc_m_n_,
                                           e_grid_desc_m_n_,
                                           block_2_etile_map_))
            {
                e_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        e_grid_desc_m_n_);

                ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        ds_grid_desc_m_n_);
            }

            // for sanity check of vector memory access
            a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1];
            a_kz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1];

            b_nz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN - 1];
            b_kz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1];

            for(index_t i = 0; i < NumDTensor; ++i)
            {
                ds_nz_stride_[i] = ds_gs_ms_ns_strides[i][NumDimG + NumDimM + NumDimN - 1];
            }

            e_nz_stride_ = e_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1];

            Print();
        }

        void Print() const
        {
            std::cout << "A[M, K]: " << a_grid_desc_m_k_.GetLength(I0) << ", "
                      << a_grid_desc_m_k_.GetLength(I1) << std::endl;
            std::cout << "B[N, K]: " << b_grid_desc_n_k_.GetLength(I0) << ", "
                      << b_grid_desc_n_k_.GetLength(I1) << std::endl;
            std::cout << "A[akb, ak0, m, ak1]: " << a_grid_desc_akb_ak0_m_ak1_.GetLength(I0)
                      << ", " << a_grid_desc_akb_ak0_m_ak1_.GetLength(I1) << ", "
                      << a_grid_desc_akb_ak0_m_ak1_.GetLength(I2) << ", "
                      << a_grid_desc_akb_ak0_m_ak1_.GetLength(I3) << std::endl;
            std::cout << "B[bkb, bk0, n, bk1]: " << b_grid_desc_bkb_bk0_n_bk1_.GetLength(I0)
                      << ", " << b_grid_desc_bkb_bk0_n_bk1_.GetLength(I1) << ", "
                      << b_grid_desc_bkb_bk0_n_bk1_.GetLength(I2) << ", "
                      << b_grid_desc_bkb_bk0_n_bk1_.GetLength(I3) << std::endl;
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i].GetLength(I0) << ", "
                          << ds_grid_desc_m_n_[i].GetLength(I1) << std::endl;
            });
            std::cout << "E[M, N]: " << e_grid_desc_m_n_.GetLength(I0) << ", "
                      << e_grid_desc_m_n_.GetLength(I1) << std::endl;
        }

        //  private:
        // pointers
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        typename GridwiseGemm::DsGridPointer p_ds_grid_;
        EDataType* p_e_grid_;

        // tensor descriptors for problem definition
        AGridDesc_M_K a_grid_desc_m_k_;
        BGridDesc_N_K b_grid_desc_n_k_;
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;

        DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
        EGridDesc_G_M_N e_grid_desc_g_m_n_;

        // tensor descriptors for block/thread-wise copy
        AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1_;
        BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1_;
        typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            ds_grid_desc_mblock_mperblock_nblock_nperblock_;
        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            e_grid_desc_mblock_mperblock_nblock_nperblock_;

        // block-to-e-tile map
        Block2ETileMap block_2_etile_map_;

        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;

        // Strides for the last M/N/K dimensions of A/B/Ds/E
        // for sanity check of vector load/store
        index_t a_mz_stride_;
        index_t a_kz_stride_;
        index_t b_nz_stride_;
        index_t b_kz_stride_;
        std::array<index_t, NumDTensor> ds_nz_stride_;
        index_t e_mz_stride_;
        index_t e_nz_stride_;

        index_t a_batch_stride_;
        index_t b_batch_stride_;

        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;

        index_t split_k_;
    };
    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_akb_ak0_m_ak1_,
                                            arg.b_grid_desc_bkb_bk0_n_bk1_,
                                            arg.ds_grid_desc_m_n_,
                                            arg.e_grid_desc_m_n_,
                                            arg.block_2_etile_map_))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
            }

            const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0);

            const index_t grid_size =
                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;

            const auto K = arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I1) *
                           arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I3);

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

                const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle<
                    GridwiseGemm,
                    ADataType, // TODO: distinguish A/B datatype
                    typename GridwiseGemm::DsGridPointer,
                    EDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    DeviceOp::AGridDesc_AKB_AK0_M_AK1,
                    DeviceOp::BGridDesc_BKB_BK0_N_BK1,
                    typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    ComputePtrOffsetOfStridedBatch,
                    typename GridwiseGemm::DefaultBlock2ETileMap,
                    has_main_loop>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_ds_grid_,
                                              arg.p_e_grid_,
                                              G,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.cde_element_op_,
                                              arg.a_grid_desc_akb_ak0_m_ak1_,
                                              arg.b_grid_desc_bkb_bk0_n_bk1_,
                                              arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.compute_ptr_offset_of_batch_,
                                              arg.block_2_etile_map_);
            };

            auto launch_kernel_atomic_add = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

                const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle<
                    GridwiseGemmAtomicAdd,
                    ADataType, // TODO: distinguish A/B datatype
                    typename GridwiseGemmAtomicAdd::DsGridPointer,
                    EDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    DeviceOp::AGridDesc_AKB_AK0_M_AK1,
                    DeviceOp::BGridDesc_BKB_BK0_N_BK1,
                    typename GridwiseGemmAtomicAdd::
                        DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemmAtomicAdd::
                        EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    ComputePtrOffsetOfStridedBatch,
                    typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap,
                    has_main_loop>;

                // split-K partial sums are accumulated into E via atomic add,
                // so E must be zero-initialized before launch
                hipGetErrorString(hipMemset(
                    arg.p_e_grid_,
                    0,
                    arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
                        sizeof(EDataType)));

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_ds_grid_,
                                              arg.p_e_grid_,
                                              G,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.cde_element_op_,
                                              arg.a_grid_desc_akb_ak0_m_ak1_,
                                              arg.b_grid_desc_bkb_bk0_n_bk1_,
                                              arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.compute_ptr_offset_of_batch_,
                                              arg.block_2_etile_map_);
            };

            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                if(arg.split_k_ <= 1)
                    return launch_kernel(integral_constant<bool, true>{});
                else
                    return launch_kernel_atomic_add(integral_constant<bool, true>{});
            }
            else
            {
                if(arg.split_k_ <= 1)
                    return launch_kernel(integral_constant<bool, false>{});
                else
                    return launch_kernel_atomic_add(integral_constant<bool, false>{});
            }
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
        {
            return false;
        }

        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_akb_ak0_m_ak1_,
                                        arg.b_grid_desc_bkb_bk0_n_bk1_,
                                        arg.ds_grid_desc_m_n_,
                                        arg.e_grid_desc_m_n_,
                                        arg.block_2_etile_map_))
        {
            return false;
        }

        // check vector access
        static_assert((ABlockTransferSrcVectorDim == 2 || ABlockTransferSrcVectorDim == 3) &&
                          (BBlockTransferSrcVectorDim == 2 || BBlockTransferSrcVectorDim == 3),
                      "wrong!");

        // vector memory access of A: could be on M or AK1 dimension
        if constexpr(ABlockTransferSrcVectorDim == 2)
        {
            if(!(arg.a_mz_stride_ == 1 &&
                 arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector ==
                     0))
            {
                return false;
            }
        }
        else
        {
            if(!(arg.a_kz_stride_ == 1 &&
                 arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I3) % ABlockTransferSrcScalarPerVector ==
                     0))
            {
                return false;
            }
        }

        // vector memory access of B: could be on N or BK1 dimension
        if constexpr(BBlockTransferSrcVectorDim == 2)
        {
            if(!(arg.b_nz_stride_ == 1 &&
                 arg.b_grid_desc_bkb_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector ==
                     0))
            {
                return false;
            }
        }
        else
        {
            if(!(arg.b_kz_stride_ == 1 &&
                 arg.b_grid_desc_bkb_bk0_n_bk1_.GetLength(I3) % BBlockTransferSrcScalarPerVector ==
                     0))
            {
                return false;
            }
        }

        // vector memory access of Ds: always on NPerBlock dimension
        bool valid_d_access = true;

        static_for<0, NumDTensor, 1>{}([&](auto i) {
            if(!(arg.ds_nz_stride_[i] == 1 &&
                 arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
                         CDEBlockTransferScalarPerVector_NPerBlock ==
                     0))
            {
                valid_d_access = false;
            }
        });

        if(valid_d_access == false)
        {
            return false;
        }

        // vector memory access of E: always on NPerBlock dimension
        if(!((arg.e_nz_stride_ == 1 &&
              arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
                      CDEBlockTransferScalarPerVector_NPerBlock ==
                  0) ||
             CDEBlockTransferScalarPerVector_NPerBlock == 1))
        {
            return false;
        }

        return true;
    }
    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto
    MakeArgument(const void* p_a,
                 const void* p_b,
                 std::array<const void*, NumDTensor> p_ds,
                 void* p_e,
                 const std::vector<index_t>& a_gs_ms_ns_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,
                 const std::vector<index_t>& b_gs_ns_ks_strides,
                 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
                 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
                 const std::vector<index_t>& e_gs_ms_ns_lengths,
                 const std::vector<index_t>& e_gs_ms_ns_strides,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op,
                 index_t split_k)
    {
        return Argument{p_a,
                        p_b,
                        p_ds,
                        p_e,
                        a_gs_ms_ns_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,
                        b_gs_ns_ks_strides,
                        ds_gs_ms_ns_lengths,
                        ds_gs_ms_ns_strides,
                        e_gs_ms_ns_lengths,
                        e_gs_ms_ns_strides,
                        a_element_op,
                        b_element_op,
                        cde_element_op,
                        split_k};
    }

    static auto MakeInvoker() { return Invoker{}; }
    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
                        const std::vector<index_t>& a_gs_ms_ns_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
                        const std::vector<index_t>& b_gs_ns_ks_strides,
                        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
                        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
                        const std::vector<index_t>& e_gs_ms_ns_lengths,
                        const std::vector<index_t>& e_gs_ms_ns_strides,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op,
                        index_t split_k) override
    {
        return std::make_unique<Argument>(p_a,
                                          p_b,
                                          p_ds,
                                          p_e,
                                          a_gs_ms_ns_lengths,
                                          a_gs_ms_ks_strides,
                                          b_gs_ns_ks_lengths,
                                          b_gs_ns_ks_strides,
                                          ds_gs_ms_ns_lengths,
                                          ds_gs_ms_ns_strides,
                                          e_gs_ms_ns_lengths,
                                          e_gs_ms_ns_strides,
                                          a_element_op,
                                          b_element_op,
                                          cde_element_op,
                                          split_k);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceSplitKContractionMultipleD_Xdl_CShuffle"
            << "<"
            << NumDimG << ", "
            << NumDimM << ", "
            << NumDimN << ", "
            << NumDimK << ", "
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << ABlockTransferSrcVectorDim << ", "
            << BBlockTransferSrcVectorDim
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
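For intuition, here is a small self-contained host-side sketch (editorial, not part of the commit) of why the split-K path above pairs InMemoryDataOperationEnum::AtomicAdd with a prior hipMemset of E: each of the split_k slices computes a partial sum over its own share of K and accumulates into E, so E must start from zero. All names and sizes below are hypothetical stand-ins.

#include <cassert>
#include <vector>

// CPU analogue of split-K accumulation: split_k partial GEMMs over disjoint K-slices
// add into a zero-initialized output, reproducing the single-pass result.
int main()
{
    const int M = 2, N = 2, K = 8, split_k = 4;
    std::vector<int> a(M * K, 1), b(K * N, 2);
    std::vector<int> e(M * N, 0); // the hipMemset(0) analogue: start from zero

    const int k_per_slice = K / split_k;
    for(int slice = 0; slice < split_k; ++slice) // one "kernel launch" per K-slice
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int k = slice * k_per_slice; k < (slice + 1) * k_per_slice; ++k)
                    e[m * N + n] += a[m * K + k] * b[k * N + n]; // AtomicAdd analogue

    assert(e[0] == K * 1 * 2); // matches the unsplit GEMM: sum over all of K
    return 0;
}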
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
File moved
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
0 → 100644 (new file)
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/*
 * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
 *
 * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
 * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
 * strided batches, but we can easily extend to other layouts. The returned offset can be either \p
 * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
 * limitations.
 *
 * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
 * returns the 2D index of the tile that it computes. \see
 * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
 *
 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
 * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
 * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
 * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
 * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
 * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
 *
 * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it
 * computes. Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm
 * fusion) to realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion).
 *
 */
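// Editorial sketch of the contract described above (hypothetical numbers, not in the original
// source): with a grid of 32 workgroups over batch_count = 4, each batch owns
// get_grid_size() / batch_count = 8 blocks, so block 13 derives g_idx = 13 / 8 = 1 and shifts
// p_a_grid/p_b_grid/p_e_grid by GetAPtrOffset(1)/GetBPtrOffset(1)/GetCPtrOffset(1) before
// running the shared GridwiseGemm, exactly as the kernel below does.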
template <typename GridwiseGemm,
          typename ABDataType,
          typename EDataType,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename ComputePtrOffsetOfBatch,
          typename Block2ETileMap,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_e_permute_xdl(
            const ABDataType* __restrict__ p_a_grid,
            const ABDataType* __restrict__ p_b_grid,
            EDataType* __restrict__ p_e_grid,
            const index_t batch_count,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
            const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));

    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                  p_b_grid + b_batch_offset,
                                                  ck::Tuple<>{},
                                                  p_e_grid + e_batch_offset,
                                                  p_shared,
                                                  a_element_op,
                                                  b_element_op,
                                                  cde_element_op,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  ck::Tuple<>{},
                                                  e_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  block_2_etile_map);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_e_grid;
    ignore = batch_count;
    ignore = a_grid_desc_ak0_m_ak1;
    ignore = b_grid_desc_bk0_n_bk1;
    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
    ignore = compute_ptr_offset_of_batch;
    ignore = block_2_etile_map;
#endif
}
template <typename ALayout,
          typename BLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t NumPrefetch,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
                                                                       BLayout,
                                                                       ELayout,
                                                                       ADataType,
                                                                       BDataType,
                                                                       EDataType,
                                                                       AElementwiseOperation,
                                                                       BElementwiseOperation,
                                                                       CDEElementwiseOperation>
{
    using DeviceOp = DeviceBatchedGemmEPermuteXdl;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
    static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(I1, StrideA));
            }
        }();

        return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
    }
    static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
            }
        }();

        return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
    }
    static auto
    MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
    {
        const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor(
            make_tuple(MRaw, NRaw), make_tuple(stride_M, stride_N));

        return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
    }
    static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0,
                                              index_t G1,
                                              index_t MRaw,
                                              index_t NRaw,
                                              index_t stride_G0,
                                              index_t stride_G1,
                                              index_t stride_M,
                                              index_t stride_N)
    {
        const auto e_grid_desc_g0_g1_mraw_nraw = [&]() {
            return make_naive_tensor_descriptor(
                make_tuple(G0, G1, MRaw, NRaw),
                make_tuple(stride_G0, stride_G1, stride_M, stride_N));
        }();

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;

        const auto MPad = M - MRaw;
        const auto NPad = N - NRaw;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            return transform_tensor_descriptor(
                e_grid_desc_g0_g1_mraw_nraw,
                make_tuple(make_pass_through_transform(G0),
                           make_pass_through_transform(G1),
                           make_right_pad_transform(MRaw, MPad),
                           make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M, but not N
            return transform_tensor_descriptor(
                e_grid_desc_g0_g1_mraw_nraw,
                make_tuple(make_pass_through_transform(G0),
                           make_pass_through_transform(G1),
                           make_right_pad_transform(MRaw, MPad),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N, but not M
            return transform_tensor_descriptor(
                e_grid_desc_g0_g1_mraw_nraw,
                make_tuple(make_pass_through_transform(G0),
                           make_pass_through_transform(G1),
                           make_pass_through_transform(MRaw),
                           make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
        }
        else
        {
            // not pad M or N
            return e_grid_desc_g0_g1_mraw_nraw;
        }
    }
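    // Editorial note (not in the original source): the padding above is a round-up to the tile
    // grid, e.g. with hypothetical MRaw = 100 and MPerBlock = 64:
    // M = integer_divide_ceil(100, 64) * 64 = 128, so make_right_pad_transform appends
    // MPad = 28 rows.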
    using AGridDesc_M_K       = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
    using BGridDesc_N_K       = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
    using EGridDesc_M_N       = decltype(MakeEGridDescriptor_M_N(1, 1, 1, 1));
    using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1));
    struct ComputePtrOffsetOfStridedBatch
    {
        ComputePtrOffsetOfStridedBatch(index_t Batchstride_A,
                                       index_t Batchstride_B,
                                       EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n)
            : Batchstride_A_(Batchstride_A),
              Batchstride_B_(Batchstride_B),
              e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n)
        {
        }

        __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(Batchstride_A_);
        }

        __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(Batchstride_B_);
        }

        __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
        {
            const index_t G1 = e_grid_desc_g0_g1_m_n_.GetLength(I1);
            index_t b0       = g_idx / G1;
            index_t b1       = g_idx - b0 * G1; // g_idx % G1
            return e_grid_desc_g0_g1_m_n_.CalculateOffset(make_multi_index(b0, b1, 0, 0));
        }

        private:
        index_t Batchstride_A_;
        index_t Batchstride_B_;
        EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
    };
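    // Editorial note (not in the original source): GetCPtrOffset folds the flat batch index back
    // into (G0, G1) coordinates, e.g. with hypothetical G1 = 3 and g_idx = 7: b0 = 7 / 3 = 2,
    // b1 = 7 - 2 * 3 = 1, and the returned offset is b0 * stride_G0 + b1 * stride_G1; this is
    // what lets E keep arbitrary (permuted) G0/G1 strides.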
    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CShuffleDataType,
        ck::Tuple<>, // DsDataType
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        AGridDesc_M_K,
        BGridDesc_N_K,
        Tuple<>,
        EGridDesc_M_N,
        NumPrefetch,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;
    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));

    using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a_grid,
                 const BDataType* p_b_grid,
                 EDataType* p_e_grid,
                 index_t M,
                 index_t N,
                 index_t K,
                 index_t stride_A,
                 index_t stride_B,
                 index_t batch_stride_A,
                 index_t batch_stride_B,
                 BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
                 index_t BatchCount,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_e_grid_{p_e_grid},
              BatchCount_(BatchCount),
              a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(M, K, stride_A)},
              b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(K, N, stride_B)},
              e_grid_desc_m_n_{
                  DeviceOp::MakeEGridDescriptor_M_N(batched_gemm_e_permute_desc.M_,
                                                    batched_gemm_e_permute_desc.N_,
                                                    batched_gemm_e_permute_desc.stride_M_,
                                                    batched_gemm_e_permute_desc.stride_N_)},
              a_grid_desc_ak0_m_ak1_{
                  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
              b_grid_desc_bk0_n_bk1_{
                  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
              e_grid_desc_mblock_mperblock_nblock_nperblock{},
              e_grid_desc_g0_g1_m_n_{
                  DeviceOp::MakeEGridDescriptor_G0_G1_M_N(batched_gemm_e_permute_desc.G0_,
                                                          batched_gemm_e_permute_desc.G1_,
                                                          batched_gemm_e_permute_desc.M_,
                                                          batched_gemm_e_permute_desc.N_,
                                                          batched_gemm_e_permute_desc.stride_G0_,
                                                          batched_gemm_e_permute_desc.stride_G1_,
                                                          batched_gemm_e_permute_desc.stride_M_,
                                                          batched_gemm_e_permute_desc.stride_N_)},
              compute_ptr_offset_of_batch_{batch_stride_A, batch_stride_B, e_grid_desc_g0_g1_m_n_},
              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
            if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
                                           b_grid_desc_n_k_,
                                           ck::Tuple<>{},
                                           e_grid_desc_m_n_,
                                           block_2_etile_map_))
            {
                e_grid_desc_mblock_mperblock_nblock_nperblock =
                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        e_grid_desc_m_n_);
            }
        }

        void Print() const
        {
            std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
            std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
            std::cout << "C[M, N]: " << e_grid_desc_m_n_ << std::endl;
        }

        //  private:
        // pointers
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        EDataType* p_e_grid_;

        // batch count
        index_t BatchCount_;

        // tensor descriptors for problem definition
        AGridDesc_M_K a_grid_desc_m_k_;
        BGridDesc_N_K b_grid_desc_n_k_;
        EGridDesc_M_N e_grid_desc_m_n_;

        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;

        EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;

        // for calculating Batch offset
        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;

        // block-to-e-tile map
        Block2ETileMap block_2_etile_map_;

        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };
    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                            arg.b_grid_desc_n_k_,
                                            ck::Tuple<>{},
                                            arg.e_grid_desc_m_n_,
                                            arg.block_2_etile_map_))
            {
                throw std::runtime_error(
                    "wrong! GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid "
                    "setting");
            }

            const index_t grid_size =
                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.BatchCount_;

            const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
                           arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

            auto launch_kernel = [&](auto has_main_k_block_loop_) {
                const auto kernel = kernel_batched_gemm_e_permute_xdl<
                    GridwiseGemm,
                    ADataType, // TODO: distinguish A/B datatype
                    EDataType,
                    remove_reference_t<DeviceOp::AGridDesc_AK0_M_AK1>,
                    remove_reference_t<DeviceOp::BGridDesc_BK0_N_BK1>,
                    typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    ComputePtrOffsetOfStridedBatch,
                    remove_reference_t<Block2ETileMap>,
                    has_main_k_block_loop_>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_e_grid_,
                                              arg.BatchCount_,
                                              arg.a_grid_desc_ak0_m_ak1_,
                                              arg.b_grid_desc_bk0_n_bk1_,
                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.cde_element_op_,
                                              arg.compute_ptr_offset_of_batch_,
                                              arg.block_2_etile_map_);
            };

            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                           arg.b_grid_desc_n_k_,
                                           ck::Tuple<>{},
                                           arg.e_grid_desc_m_n_,
                                           arg.block_2_etile_map_);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }
    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             EDataType* p_e,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t stride_A,
                             index_t stride_B,
                             index_t batch_stride_A,
                             index_t batch_stride_B,
                             BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
                             index_t BatchCount,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_e,
                        M,
                        N,
                        K,
                        stride_A,
                        stride_B,
                        batch_stride_A,
                        batch_stride_B,
                        batched_gemm_e_permute_desc,
                        BatchCount,
                        a_element_op,
                        b_element_op,
                        cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }
    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        void* p_e,
                        index_t M,
                        index_t N,
                        index_t K,
                        index_t stride_A,
                        index_t stride_B,
                        index_t batch_stride_A,
                        index_t batch_stride_B,
                        BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
                        index_t BatchCount,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<EDataType*>(p_e),
                                          M,
                                          N,
                                          K,
                                          stride_A,
                                          stride_B,
                                          batch_stride_A,
                                          batch_stride_B,
                                          batched_gemm_e_permute_desc,
                                          BatchCount,
                                          a_element_op,
                                          b_element_op,
                                          cde_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceBatchedGemmEPermuteXdl"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
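As a closing illustration (editorial, not part of the commit), the batched launch above can be mimicked on the host to sanity-check the index arithmetic: a flat block id is split into a batch index and an in-batch tile id, and the batch index is folded into (G0, G1) for the permuted E offset. All values below are hypothetical.

#include <cassert>
#include <cstdint>

int main()
{
    // Hypothetical launch: 4 batches (G0 = 2, G1 = 2), 6 tile-blocks per batch.
    const int batch_count = 4, num_blocks_per_batch = 6, G1 = 2;
    const std::int64_t stride_G0 = 9000, stride_G1 = 3000;

    for(int block_id = 0; block_id < batch_count * num_blocks_per_batch; ++block_id)
    {
        const int g_idx = block_id / num_blocks_per_batch; // batch owning this block
        const int b0    = g_idx / G1;                      // G0 coordinate
        const int b1    = g_idx - b0 * G1;                 // G1 coordinate (g_idx % G1)

        // E's per-batch base offset, as GetCPtrOffset computes via CalculateOffset.
        const std::int64_t e_offset = b0 * stride_G0 + b1 * stride_G1;
        assert(e_offset == (g_idx / G1) * stride_G0 + (g_idx % G1) * stride_G1);
    }
    return 0;
}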
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
File moved
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp → include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp

@@ -38,9 +38,9 @@ namespace device {
  * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
  * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
  * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
- * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
- * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
- * pointer offset into \p ComputePtrOffsetOfStridedBatch.
+ * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
+ * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
+ * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
  *
  * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
  * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
File moved
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
File moved
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp

@@ -9,10 +9,10 @@
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
File moved
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp → include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
File moved
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
File moved