Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2b27d5fc
Commit
2b27d5fc
authored
Jul 01, 2022
by
Chao Liu
Browse files
Merge remote-tracking branch 'origin/develop' into rosenrodt/gemm-layernorm
parents
f689a155
fa9a0a5c
Changes
137
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2125 additions
and
637 deletions
+2125
-637
example/25_gemm_bias_c_permute/gemm_bias_c_permute_xdl_fp16.cpp
...e/25_gemm_bias_c_permute/gemm_bias_c_permute_xdl_fp16.cpp
+284
-0
example/CMakeLists.txt
example/CMakeLists.txt
+1
-0
include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp
...k/tensor_operation/gpu/device/device_5ary_elementwise.hpp
+56
-42
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
+45
-0
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
...on/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
+195
-125
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
+12
-12
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
...tensor_operation/gpu/device/device_binary_elementwise.hpp
+24
-16
include/ck/tensor_operation/gpu/device/device_elementwise.hpp
...ude/ck/tensor_operation/gpu/device/device_elementwise.hpp
+40
-0
include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
...n/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
+193
-138
include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp
...ensor_operation/gpu/device/device_gemm_bias_c_permute.hpp
+57
-0
include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp
...r_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp
+761
-0
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
...ude/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
+11
-68
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
..._operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
+166
-95
include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp
...ude/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp
+44
-0
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
...ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
+2
-2
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
...operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
+18
-17
include/ck/tensor_operation/gpu/device/device_normalization.hpp
...e/ck/tensor_operation/gpu/device/device_normalization.hpp
+43
-0
include/ck/tensor_operation/gpu/device/device_softmax.hpp
include/ck/tensor_operation/gpu/device/device_softmax.hpp
+65
-21
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+9
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
...pu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
+99
-99
No files found.
example/25_gemm_bias_c_permute/gemm_bias_c_permute_xdl_fp16.cpp
0 → 100644
View file @
2b27d5fc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
DDataType
=
F16
;
using
EDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
DLayout
=
Row
;
using
ELayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
Add
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmBiasCPermute_Xdl
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
1
>
;
// clang-format on
int
main
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
ck
::
index_t
M0
=
4
;
ck
::
index_t
M1
=
32
;
ck
::
index_t
M2
=
128
;
ck
::
index_t
N0
=
16
;
ck
::
index_t
N1
=
256
;
// GEMM shape
ck
::
index_t
M
=
M0
*
M1
*
M2
;
ck
::
index_t
N
=
N0
*
N1
;
ck
::
index_t
K
=
128
;
ck
::
index_t
stride_A
=
K
;
ck
::
index_t
stride_B
=
K
;
#if 1
// E = [M0, N0, M1, N1, M2]
ck
::
index_t
stride_E_M0
=
N0
*
M1
*
N1
*
M2
;
ck
::
index_t
stride_E_M1
=
N1
*
M2
;
ck
::
index_t
stride_E_M2
=
1
;
ck
::
index_t
stride_E_N0
=
M1
*
N1
*
M2
;
ck
::
index_t
stride_E_N1
=
M2
;
// D = [0, N0, 0, N1, 0]
ck
::
index_t
stride_D_M0
=
0
;
ck
::
index_t
stride_D_M1
=
0
;
ck
::
index_t
stride_D_M2
=
0
;
ck
::
index_t
stride_D_N0
=
N1
;
ck
::
index_t
stride_D_N1
=
1
;
#else
// D = [0, 0, 0, N0, N1]
ck
::
index_t
stride_D_M0
=
0
;
ck
::
index_t
stride_D_M1
=
0
;
ck
::
index_t
stride_D_M2
=
0
;
ck
::
index_t
stride_D_N0
=
N1
;
ck
::
index_t
stride_D_N1
=
1
;
// E = [M0, M1, M2, N0, N1]
ck
::
index_t
stride_E_M0
=
M1
*
M2
*
N0
*
N1
;
ck
::
index_t
stride_E_M1
=
M2
*
N0
*
N1
;
ck
::
index_t
stride_E_M2
=
N0
*
N1
;
ck
::
index_t
stride_E_N0
=
N1
;
ck
::
index_t
stride_E_N1
=
1
;
#endif
const
ck
::
tensor_operation
::
device
::
DEGridDesc_M0_M1_M2_N0_N1
d_grid_desc
{
M0
,
M1
,
M2
,
N0
,
N1
,
stride_D_M0
,
stride_D_M1
,
stride_D_M2
,
stride_D_N0
,
stride_D_N1
};
const
ck
::
tensor_operation
::
device
::
DEGridDesc_M0_M1_M2_N0_N1
e_grid_desc
{
M0
,
M1
,
M2
,
N0
,
N1
,
stride_E_M0
,
stride_E_M1
,
stride_E_M2
,
stride_E_N0
,
stride_E_N1
};
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
exit
(
0
);
}
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
};
auto
f_host_de_tensor_descriptor
=
[](
ck
::
tensor_operation
::
device
::
DEGridDesc_M0_M1_M2_N0_N1
de_grid_desc
)
{
std
::
size_t
m0
=
de_grid_desc
.
M0_
;
std
::
size_t
m1
=
de_grid_desc
.
M1_
;
std
::
size_t
m2
=
de_grid_desc
.
M2_
;
std
::
size_t
n0
=
de_grid_desc
.
N0_
;
std
::
size_t
n1
=
de_grid_desc
.
N1_
;
std
::
size_t
stride_m0
=
de_grid_desc
.
stride_M0_
;
std
::
size_t
stride_m1
=
de_grid_desc
.
stride_M1_
;
std
::
size_t
stride_m2
=
de_grid_desc
.
stride_M2_
;
std
::
size_t
stride_n0
=
de_grid_desc
.
stride_N0_
;
std
::
size_t
stride_n1
=
de_grid_desc
.
stride_N1_
;
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
m0
,
m1
,
m2
,
n0
,
n1
}),
std
::
vector
<
std
::
size_t
>
({
stride_m0
,
stride_m1
,
stride_m2
,
stride_n0
,
stride_n1
}));
};
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
stride_A
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
stride_B
,
BLayout
{}));
Tensor
<
DDataType
>
d_m0_m1_m2_n0_n1
(
f_host_de_tensor_descriptor
(
d_grid_desc
));
Tensor
<
EDataType
>
e_m0_m1_m2_n0_n1_host_result
(
f_host_de_tensor_descriptor
(
e_grid_desc
));
Tensor
<
EDataType
>
e_m0_m1_m2_n0_n1_device_result
(
f_host_de_tensor_descriptor
(
e_grid_desc
));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_m0_m1_m2_n0_n1: "
<<
d_m0_m1_m2_n0_n1
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_m0_m1_m2_n0_n1: "
<<
e_m0_m1_m2_n0_n1_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
d_m0_m1_m2_n0_n1
.
GenerateTensorValue
(
GeneratorTensor_2
<
DDataType
>
{
-
5
,
5
});
break
;
default:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
d_m0_m1_m2_n0_n1
.
GenerateTensorValue
(
GeneratorTensor_3
<
DDataType
>
{
0.0
,
1.0
});
}
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d_m0_m1_m2_n0_n1_device_buf
(
sizeof
(
DDataType
)
*
d_m0_m1_m2_n0_n1
.
mDesc
.
GetElementSpace
());
DeviceMem
e_m0_m1_m2_n0_n1_device_buf
(
sizeof
(
EDataType
)
*
e_m0_m1_m2_n0_n1_device_result
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
d_m0_m1_m2_n0_n1_device_buf
.
ToDevice
(
d_m0_m1_m2_n0_n1
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{};
// do GEMM
auto
device_op
=
DeviceOpInstance
{};
auto
invoker
=
device_op
.
MakeInvoker
();
auto
argument
=
device_op
.
MakeArgument
(
a_m_k_device_buf
.
GetDeviceBuffer
(),
b_k_n_device_buf
.
GetDeviceBuffer
(),
d_m0_m1_m2_n0_n1_device_buf
.
GetDeviceBuffer
(),
e_m0_m1_m2_n0_n1_device_buf
.
GetDeviceBuffer
(),
M
,
N
,
K
,
stride_A
,
stride_B
,
d_grid_desc
,
e_grid_desc
,
a_element_op
,
b_element_op
,
cde_element_op
);
if
(
!
device_op
.
IsSupportedArgument
(
argument
))
{
throw
std
::
runtime_error
(
"wrong! this device_op instance does not support this problem"
);
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
DDataType
)
*
N
+
sizeof
(
EDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
device_op
.
GetTypeString
()
<<
std
::
endl
;
if
(
do_verification
)
{
Tensor
<
AccDataType
>
c_m_n
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
M
),
static_cast
<
std
::
size_t
>
(
N
)}));
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
AccDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
PassThrough
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n
,
a_element_op
,
b_element_op
,
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
for
(
int
m0
=
0
;
m0
<
M0
;
++
m0
)
for
(
int
m1
=
0
;
m1
<
M1
;
++
m1
)
for
(
int
m2
=
0
;
m2
<
M2
;
++
m2
)
for
(
int
n0
=
0
;
n0
<
N0
;
++
n0
)
for
(
int
n1
=
0
;
n1
<
N1
;
++
n1
)
{
int
m
=
m0
*
M1
*
M2
+
m1
*
M2
+
m2
;
int
n
=
n0
*
N1
+
n1
;
cde_element_op
(
e_m0_m1_m2_n0_n1_host_result
(
m0
,
m1
,
m2
,
n0
,
n1
),
ck
::
type_convert
<
EDataType
>
(
c_m_n
(
m
,
n
)),
d_m0_m1_m2_n0_n1
(
m0
,
m1
,
m2
,
n0
,
n1
));
}
e_m0_m1_m2_n0_n1_device_buf
.
FromDevice
(
e_m0_m1_m2_n0_n1_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
e_m0_m1_m2_n0_n1_device_result
.
mData
,
e_m0_m1_m2_n0_n1_host_result
.
mData
)
?
0
:
1
;
}
return
0
;
}
example/CMakeLists.txt
View file @
2b27d5fc
...
...
@@ -42,3 +42,4 @@ add_subdirectory(20_convnd_bwd_weight_xdl)
add_subdirectory
(
21_gemm_layernorm
)
add_subdirectory
(
22_cgemm
)
add_subdirectory
(
23_softmax
)
add_subdirectory
(
25_gemm_bias_c_permute
)
include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp
View file @
2b27d5fc
...
...
@@ -10,7 +10,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_
ba
se.hpp"
#include "ck/tensor_operation/gpu/device/device_
elementwi
se.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
...
...
@@ -35,7 +35,7 @@ template <typename ADataType,
index_t
DScalarPerVector
,
index_t
EScalarPerVector
,
index_t
FScalarPerVector
>
struct
Device5AryElementwise
:
public
BaseOpera
tor
struct
Device5AryElementwise
:
public
DeviceElementwise
<
5
,
1
,
NDim
,
ElementwiseFunc
tor
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -268,12 +268,8 @@ struct Device5AryElementwise : public BaseOperator
return
true
;
};
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
const
CDataType
*
p_c
,
const
DDataType
*
p_d
,
const
EDataType
*
p_e
,
FDataType
*
p_f
,
static
auto
MakeArgument
(
std
::
array
<
const
void
*
,
5
>
p_inputs
,
std
::
array
<
void
*
,
1
>
p_outputs
,
std
::
vector
<
index_t
>
lengths
,
std
::
vector
<
index_t
>
a_strides
,
std
::
vector
<
index_t
>
b_strides
,
...
...
@@ -283,12 +279,12 @@ struct Device5AryElementwise : public BaseOperator
std
::
vector
<
index_t
>
f_strides
,
ElementwiseFunctor
functor
)
{
return
Argument
{
p_a
,
p_b
,
p_c
,
p_d
,
p_e
,
p_f
,
return
Argument
{
static_cast
<
const
ADataType
*>
(
p_inputs
[
0
])
,
static_cast
<
const
BDataType
*>
(
p_inputs
[
1
])
,
static_cast
<
const
CDataType
*>
(
p_inputs
[
2
])
,
static_cast
<
const
DDataType
*>
(
p_inputs
[
3
])
,
static_cast
<
const
EDataType
*>
(
p_inputs
[
4
])
,
static_cast
<
FDataType
*>
(
p_outputs
[
0
])
,
lengths
,
a_strides
,
b_strides
,
...
...
@@ -299,40 +295,58 @@ struct Device5AryElementwise : public BaseOperator
functor
};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_c
,
const
void
*
p_d
,
const
void
*
p_e
,
void
*
p_f
,
std
::
vector
<
index_t
>
lengths
,
std
::
vector
<
index_t
>
a_strides
,
std
::
vector
<
index_t
>
b_strides
,
std
::
vector
<
index_t
>
c_strides
,
std
::
vector
<
index_t
>
d_strides
,
std
::
vector
<
index_t
>
e_strides
,
std
::
vector
<
index_t
>
f_strides
,
ElementwiseFunctor
functor
)
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
array
<
const
void
*
,
5
>
p_inputs
,
std
::
array
<
void
*
,
1
>
p_outputs
,
std
::
vector
<
index_t
>
lengths
,
std
::
vector
<
std
::
vector
<
index_t
>>
input_strides
,
std
::
vector
<
std
::
vector
<
index_t
>>
output_strides
,
ElementwiseFunctor
functor
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_
a
),
static_cast
<
const
BDataType
*>
(
p_
b
),
static_cast
<
const
CDataType
*>
(
p_
c
),
static_cast
<
const
DDataType
*>
(
p_
d
),
static_cast
<
const
EDataType
*>
(
p_
e
),
static_cast
<
FDataType
*>
(
p_
f
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_
inputs
[
0
]
),
static_cast
<
const
BDataType
*>
(
p_
inputs
[
1
]
),
static_cast
<
const
CDataType
*>
(
p_
inputs
[
2
]
),
static_cast
<
const
DDataType
*>
(
p_
inputs
[
3
]
),
static_cast
<
const
EDataType
*>
(
p_
inputs
[
4
]
),
static_cast
<
FDataType
*>
(
p_
outputs
[
0
]
),
lengths
,
a
_strides
,
b
_strides
,
c
_strides
,
d
_strides
,
e
_strides
,
f
_strides
,
input
_strides
[
0
]
,
input
_strides
[
1
]
,
input
_strides
[
2
]
,
input
_strides
[
3
]
,
input
_strides
[
4
]
,
output
_strides
[
0
]
,
functor
);
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
();
}
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"Device5aryElementwise"
<<
"<"
<<
"NDim = "
<<
NDim
<<
"MPerThread = "
<<
MPerThread
<<
"AScalarPerVector = "
<<
AScalarPerVector
<<
"BScalarPerVector = "
<<
BScalarPerVector
<<
"CScalarPerVector = "
<<
CScalarPerVector
<<
"DScalarPerVector = "
<<
DScalarPerVector
<<
"EScalarPerVector = "
<<
EScalarPerVector
<<
"FScalarPerVector = "
<<
FScalarPerVector
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
// namespace device
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
0 → 100644
View file @
2b27d5fc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceBatchedGemm
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
Batch
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceBatchedGemmPtr
=
std
::
unique_ptr
<
DeviceBatchedGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
View file @
2b27d5fc
...
...
@@ -23,16 +23,16 @@ namespace device {
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
D
PtrsGlobal
,
typename
Reduce
PtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Dxs
InElementwiseOperation
,
typename
Dxs
ReduceAccElementwiseOperation
,
typename
Reduce
InElementwiseOperation
s
,
typename
ReduceAccElementwiseOperation
s
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
D
GridDescriptor_MBlock_MPerBlock
,
typename
Reduce
GridDescriptor_MBlock_MPerBlock
,
typename
ComputeBasePrtOfBatch
,
typename
Block2CTileMap
,
bool
HasMainK0BlockLoop
>
...
...
@@ -44,18 +44,18 @@ __global__ void
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
D
PtrsGlobal
p_
d
s_grid
,
Reduce
PtrsGlobal
p_
reduce
s_grid
,
const
index_t
batch_count
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
Dxs
InElementwiseOperation
dxs
_in_element_op
,
const
Dxs
ReduceAccElementwiseOperation
dxs
_out_element_op
,
const
Reduce
InElementwiseOperation
s
reduce
_in_element_op
s
,
const
ReduceAccElementwiseOperation
s
reduce
_out_element_op
s
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
D
GridDescriptor_MBlock_MPerBlock
d
_grid_desc_mblock_mperblock
,
const
Reduce
GridDescriptor_MBlock_MPerBlock
reduce
_grid_desc_mblock_mperblock
,
const
ComputeBasePrtOfBatch
compute_base_ptr_of_batch_
,
const
Block2CTileMap
block_2_ctile_map
)
{
...
...
@@ -71,10 +71,10 @@ __global__ void
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch_
.
GetCBasePtr
(
g_idx
)));
static_for
<
0
,
p_
d
s_grid
.
Size
(),
1
>
{}([
&
](
auto
In
)
{
static_for
<
0
,
p_
reduce
s_grid
.
Size
(),
1
>
{}([
&
](
auto
In
)
{
const
long_index_t
d_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch_
.
GetDBasePtr
(
g_idx
,
In
)));
p_
d
s_grid
(
In
)
=
p_
d
s_grid
(
In
)
+
d_batch_offset
;
p_
reduce
s_grid
(
In
)
=
p_
reduce
s_grid
(
In
)
+
d_batch_offset
;
});
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -82,36 +82,36 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainK0BlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_
d
s_grid
,
p_
reduce
s_grid
,
p_shared
,
a_element_op
,
b_element_op
,
c_element_op
,
dxs
_in_element_op
,
dxs
_out_element_op
,
reduce
_in_element_op
s
,
reduce
_out_element_op
s
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
d
_grid_desc_mblock_mperblock
,
reduce
_grid_desc_mblock_mperblock
,
block_2_ctile_map
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
p_
d
s_grid
;
ignore
=
p_
reduce
s_grid
;
ignore
=
batch_count
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
dxs
_in_element_op
;
ignore
=
dxs
_out_element_op
;
ignore
=
reduce
_in_element_op
s
;
ignore
=
reduce
_out_element_op
s
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
d
_grid_desc_mblock_mperblock
;
ignore
=
reduce
_grid_desc_mblock_mperblock
;
ignore
=
compute_base_ptr_of_batch_
;
ignore
=
block_2_ctile_map
;
#endif
// end of if defined (defined(__gfx908__) || defined(__gfx90a__))
#endif
}
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
...
...
@@ -126,14 +126,14 @@ template <typename ALayout,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
ReduceAccDataType
,
typename
D
PtrsGlobal
,
typename
Reduce
PtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Dxs
ReduceOperation
,
typename
Dxs
InElementwiseOperation
,
typename
Dxs
ReduceAccElementwiseOperation
,
typename
D
GlobalMemoryDataOperation
,
typename
ReduceOperation
s
,
typename
Reduce
InElementwiseOperation
s
,
typename
ReduceAccElementwiseOperation
s
,
typename
Reduce
GlobalMemoryDataOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
...
...
@@ -168,12 +168,7 @@ template <typename ALayout,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceBatchedGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsReduceAccElementwiseOperation
>
struct
DeviceBatchedGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
0
,
ReduceOperations
::
Size
()
>
{
using
DeviceOp
=
DeviceBatchedGemmReduce_Xdl_CShuffle
;
...
...
@@ -446,7 +441,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
}
// assume D is packed tensor
static
auto
Make
D
GridDescriptor_M
(
index_t
MRaw
)
static
auto
Make
Reduce
GridDescriptor_M
(
index_t
MRaw
)
{
const
auto
d_grid_desc_mraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MRaw
));
...
...
@@ -474,7 +469,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
D
GridDesc_M
=
decltype
(
Make
D
GridDescriptor_M
(
1
));
using
Reduce
GridDesc_M
=
decltype
(
Make
Reduce
GridDescriptor_M
(
1
));
struct
ComputeBasePtrOfStridedBatch
{
...
...
@@ -527,19 +522,19 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
CShuffleDataType
,
CDataType
,
ReduceAccDataType
,
D
PtrsGlobal
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
Dxs
ReduceOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
ReduceOperation
s
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
InMemoryDataOperationEnum
::
Set
,
D
GlobalMemoryDataOperation
,
Reduce
GlobalMemoryDataOperation
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
D
GridDesc_M
,
Reduce
GridDesc_M
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -582,7 +577,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
D
PtrsGlobal
p_
d
s_grid
,
Reduce
PtrsGlobal
p_
reduce
s_grid
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
...
...
@@ -592,31 +587,31 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
Dxs
InElementwiseOperation
dxs
_in_element_op
,
Dxs
ReduceAccElementwiseOperation
dxs
_out_element_op
,
index_t
Batch
Count
)
Reduce
InElementwiseOperation
s
reduce
_in_element_op
s
,
ReduceAccElementwiseOperation
s
reduce
_out_element_op
s
,
index_t
Batch
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_
d
s_grid_
{
p_
d
s_grid
},
Batch
Count
_
(
Batch
Count
),
p_
reduce
s_grid_
{
p_
reduce
s_grid
},
Batch_
(
Batch
),
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
)},
d
_grid_desc_m_
{
DeviceOp
::
Make
D
GridDescriptor_M
(
MRaw
)},
reduce
_grid_desc_m_
{
DeviceOp
::
Make
Reduce
GridDescriptor_M
(
MRaw
)},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
d
_grid_desc_mblock_mperblock_
{},
reduce
_grid_desc_mblock_mperblock_
{},
compute_base_ptr_of_batch_
{
type_convert
<
index_t
>
(
a_grid_desc_ak0_m_ak1_
.
GetElementSpaceSize
()),
type_convert
<
index_t
>
(
b_grid_desc_bk0_n_bk1_
.
GetElementSpaceSize
()),
type_convert
<
index_t
>
(
c_grid_desc_m_n_
.
GetElementSpaceSize
()),
type_convert
<
index_t
>
(
d
_grid_desc_m_
.
GetElementSpaceSize
())},
type_convert
<
index_t
>
(
reduce
_grid_desc_m_
.
GetElementSpaceSize
())},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
},
dxs
_in_element_op_
{
dxs
_in_element_op
},
dxs
_out_element_op_
{
dxs
_out_element_op
}
reduce
_in_element_op
s
_
{
reduce
_in_element_op
s
},
reduce
_out_element_op
s
_
{
reduce
_out_element_op
s
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
...
...
@@ -627,8 +622,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
d
_grid_desc_mblock_mperblock_
=
GridwiseGemm
::
Make
D
GridDescriptor_MBlock_MPerBlock
(
d
_grid_desc_m_
);
reduce
_grid_desc_mblock_mperblock_
=
GridwiseGemm
::
Make
Reduce
GridDescriptor_MBlock_MPerBlock
(
reduce
_grid_desc_m_
);
}
}
...
...
@@ -636,22 +631,23 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
D
PtrsGlobal
p_
d
s_grid_
;
index_t
Batch
Count
_
;
Reduce
PtrsGlobal
p_
reduce
s_grid_
;
index_t
Batch_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
D
GridDesc_M
d
_grid_desc_m_
;
Reduce
GridDesc_M
reduce
_grid_desc_m_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
DGridDescriptor_MBlock_MPerBlock
d_grid_desc_mblock_mperblock_
;
typename
GridwiseGemm
::
ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
Dxs
InElementwiseOperation
dxs
_in_element_op_
;
Dxs
ReduceAccElementwiseOperation
dxs
_out_element_op_
;
Reduce
InElementwiseOperation
s
reduce
_in_element_op
s
_
;
ReduceAccElementwiseOperation
s
reduce
_out_element_op
s
_
;
};
// Invoker
...
...
@@ -663,7 +659,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
{
#if 0
{
std::cout << "arg.Batch
Count
_ = " << arg.Batch
Count
_ << std::endl;
std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
...
...
@@ -678,7 +674,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.
d
_grid_desc_m_{ " << arg.
d
_grid_desc_m_.GetLength(I0) << "}"
std::cout << "arg.
reduce
_grid_desc_m_{ " << arg.
reduce
_grid_desc_m_.GetLength(I0) << "}"
<< std::endl;
}
#endif
...
...
@@ -692,7 +688,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
Batch
Count
_
;
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
Batch_
;
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
...
...
@@ -704,16 +700,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
D
PtrsGlobal
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
D
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
Reduce
GridDescriptor_MBlock_MPerBlock
,
ComputeBasePtrOfStridedBatch
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
true
>
;
...
...
@@ -727,17 +723,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_
d
s_grid_
,
arg
.
Batch
Count
_
,
arg
.
p_
reduce
s_grid_
,
arg
.
Batch_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
dxs
_in_element_op_
,
arg
.
dxs
_out_element_op_
,
arg
.
reduce
_in_element_op
s
_
,
arg
.
reduce
_out_element_op
s
_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
d
_grid_desc_mblock_mperblock_
,
arg
.
reduce
_grid_desc_mblock_mperblock_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
block_2_ctile_map_
);
}
...
...
@@ -747,16 +743,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
D
PtrsGlobal
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
D
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
Reduce
GridDescriptor_MBlock_MPerBlock
,
ComputeBasePtrOfStridedBatch
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
false
>
;
...
...
@@ -770,17 +766,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_
d
s_grid_
,
arg
.
Batch
Count
_
,
arg
.
p_
reduce
s_grid_
,
arg
.
Batch_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
dxs
_in_element_op_
,
arg
.
dxs
_out_element_op_
,
arg
.
reduce
_in_element_op
s
_
,
arg
.
reduce
_out_element_op
s
_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
d
_grid_desc_mblock_mperblock_
,
arg
.
reduce
_grid_desc_mblock_mperblock_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
block_2_ctile_map_
);
}
...
...
@@ -824,39 +820,77 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
}
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
DPtrsGlobal
p_dxs
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsReduceAccElementwiseOperation
dxs_out_element_op
,
index_t
BatchCount
)
static
constexpr
int
NumReduce
=
ReduceOperations
::
Size
();
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_bias
,
std
::
array
<
const
void
*
,
0
>
p_ds
,
void
*
p_c
,
std
::
array
<
void
*
,
NumReduce
>
p_reduces
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
std
::
array
<
ck
::
index_t
,
0
>
StrideDs
,
std
::
array
<
void
*
,
3
>
gemm_element_ops
,
std
::
array
<
void
*
,
0
>
d_element_ops
,
std
::
array
<
void
*
,
NumReduce
>
reduce_in_element_op
,
std
::
array
<
void
*
,
NumReduce
>
reduce_out_element_op
,
index_t
Batch
)
{
return
Argument
{
p_a
,
p_b
,
p_c
,
p_dxs
,
MRaw
,
NRaw
,
KRaw
,
(
void
)
p_bias
;
(
void
)
p_ds
;
(
void
)
StrideDs
;
(
void
)
d_element_ops
;
ReducePtrsGlobal
reduce_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReducePtrsGlobal
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
static_cast
<
T
*>
(
p_reduces
[
I
]);
},
Number
<
NumReduce
>
{});
ReduceInElementwiseOperations
reduce_in_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceInElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_in_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
ReduceAccElementwiseOperations
reduce_out_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceAccElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_out_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
AElementwiseOperation
a_element_op
=
*
(
static_cast
<
AElementwiseOperation
*>
(
gemm_element_ops
[
0
]));
BElementwiseOperation
b_element_op
=
*
(
static_cast
<
BElementwiseOperation
*>
(
gemm_element_ops
[
1
]));
CElementwiseOperation
c_element_op
=
*
(
static_cast
<
CElementwiseOperation
*>
(
gemm_element_ops
[
2
]));
return
Argument
{
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
reduce_tuple
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
,
dxs
_in_element_op
,
dxs
_out_element_op
,
Batch
Count
};
reduce
_in_element_op
s
,
reduce
_out_element_op
s
,
Batch
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -865,38 +899,74 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_bias
,
std
::
array
<
const
void
*
,
0
>
p_ds
,
void
*
p_c
,
void
*
p_dx
s
,
index_t
M
Raw
,
index_t
N
Raw
,
index_t
K
Raw
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b
_element_op
,
CElementwiseOperation
c
_element_op
,
DxsInElementwiseOperation
dxs
_in_element_op
,
DxsReduceAccElementwiseOperation
dxs
_out_element_op
,
index_t
Batch
Count
)
override
std
::
array
<
void
*
,
NumReduce
>
p_reduce
s
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
std
::
array
<
ck
::
index_t
,
0
>
StrideDs
,
std
::
array
<
void
*
,
3
>
gemm
_element_op
s
,
std
::
array
<
void
*
,
0
>
d
_element_op
s
,
std
::
array
<
void
*
,
NumReduce
>
reduce
_in_element_op
,
std
::
array
<
void
*
,
NumReduce
>
reduce
_out_element_op
,
index_t
Batch
=
1
)
override
{
DPtrsGlobal
dxs_tuple
=
*
(
static_cast
<
DPtrsGlobal
*>
(
p_dxs
));
(
void
)
p_bias
;
(
void
)
p_ds
;
(
void
)
StrideDs
;
(
void
)
d_element_ops
;
ReducePtrsGlobal
reduce_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReducePtrsGlobal
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
static_cast
<
T
*>
(
p_reduces
[
I
]);
},
Number
<
NumReduce
>
{});
ReduceInElementwiseOperations
reduce_in_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceInElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_in_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
ReduceAccElementwiseOperations
reduce_out_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceAccElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_out_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
AElementwiseOperation
a_element_op
=
*
(
static_cast
<
AElementwiseOperation
*>
(
gemm_element_ops
[
0
]));
BElementwiseOperation
b_element_op
=
*
(
static_cast
<
BElementwiseOperation
*>
(
gemm_element_ops
[
1
]));
CElementwiseOperation
c_element_op
=
*
(
static_cast
<
CElementwiseOperation
*>
(
gemm_element_ops
[
2
]));
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
dxs
_tuple
,
M
Raw
,
N
Raw
,
K
Raw
,
reduce
_tuple
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
,
dxs
_in_element_op
,
dxs
_out_element_op
,
Batch
Count
);
reduce
_in_element_op
s
,
reduce
_out_element_op
s
,
Batch
);
}
// polymorphic
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
View file @
2b27d5fc
...
...
@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_
batched_
gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/device_utility/device_prop.hpp"
...
...
@@ -152,7 +152,7 @@ template <typename ADataType,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
>
struct
DeviceBatchedGemmXdl
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
:
public
Device
Batched
Gemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -339,11 +339,11 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
Batch
Count
)
index_t
Batch
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
Batch
Count
_
(
Batch
Count
),
Batch_
(
Batch
),
a_grid_desc_k0_m_k1_
{
DeviceBatchedGemmXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
)},
b_grid_desc_k0_n_k1_
{
...
...
@@ -376,7 +376,7 @@ struct DeviceBatchedGemmXdl
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
index_t
Batch
Count
_
;
index_t
Batch_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
...
...
@@ -420,7 +420,7 @@ struct DeviceBatchedGemmXdl
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
Batch
Count
_
;
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
Batch_
;
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
...
...
@@ -451,7 +451,7 @@ struct DeviceBatchedGemmXdl
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
Batch
Count
_
,
arg
.
Batch_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
...
...
@@ -485,7 +485,7 @@ struct DeviceBatchedGemmXdl
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
Batch
Count
_
,
arg
.
Batch_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
...
...
@@ -539,7 +539,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
Batch
Count
)
index_t
Batch
)
{
return
Argument
{
p_a
,
p_b
,
...
...
@@ -555,7 +555,7 @@ struct DeviceBatchedGemmXdl
a_element_op
,
b_element_op
,
c_element_op
,
Batch
Count
};
Batch
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -573,7 +573,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
Batch
Count
)
override
index_t
Batch
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
...
...
@@ -589,7 +589,7 @@ struct DeviceBatchedGemmXdl
a_element_op
,
b_element_op
,
c_element_op
,
Batch
Count
);
Batch
);
}
// polymorphic
...
...
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
View file @
2b27d5fc
...
...
@@ -9,6 +9,7 @@
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
namespace
ck
{
...
...
@@ -25,7 +26,7 @@ template <typename ADataType,
index_t
AScalarPerVector
,
index_t
BScalarPerVector
,
index_t
CScalarPerVector
>
struct
DeviceBinaryElementwise
:
public
BaseOpera
tor
struct
DeviceBinaryElementwise
:
public
DeviceElementwise
<
2
,
1
,
NDim
,
ElementwiseFunc
tor
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -198,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator
return
true
;
};
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
std
::
vector
<
index_t
>
lengths
,
std
::
vector
<
index_t
>
a_strides
,
std
::
vector
<
index_t
>
b_strides
,
std
::
vector
<
index_t
>
c_strides
,
ElementwiseFunctor
functor
)
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
array
<
const
void
*
,
2
>
p_inputs
,
std
::
array
<
void
*
,
1
>
p_outputs
,
std
::
vector
<
index_t
>
lengths
,
std
::
vector
<
std
::
vector
<
index_t
>>
input_strides
,
std
::
vector
<
std
::
vector
<
index_t
>>
output_strides
,
ElementwiseFunctor
functor
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_
a
),
static_cast
<
const
BDataType
*>
(
p_
b
),
static_cast
<
CDataType
*>
(
p_
c
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_
inputs
[
0
]
),
static_cast
<
const
BDataType
*>
(
p_
inputs
[
1
]
),
static_cast
<
CDataType
*>
(
p_
outputs
[
0
]
),
lengths
,
a
_strides
,
b
_strides
,
c
_strides
,
input
_strides
[
0
]
,
input
_strides
[
1
]
,
output
_strides
[
0
]
,
functor
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
();
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
...
...
@@ -226,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator
// clang-format off
str
<<
"DeviceBinaryElementwise"
<<
"<"
<<
"NDim = "
<<
NDim
<<
"MPerThread = "
<<
MPerThread
<<
"AScalarPerVector = "
<<
AScalarPerVector
<<
"BScalarPerVector = "
<<
BScalarPerVector
<<
"CScalarPerVector = "
<<
CScalarPerVector
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/device_elementwise.hpp
0 → 100644
View file @
2b27d5fc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
ck
::
index_t
NumInputTensor
,
ck
::
index_t
NumOutputTensor
,
index_t
NDim
,
typename
ElementwiseFunctor
>
struct
DeviceElementwise
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
array
<
const
void
*
,
NumInputTensor
>
p_inputs
,
std
::
array
<
void
*
,
NumOutputTensor
>
p_outputs
,
std
::
vector
<
index_t
>
lengths
,
std
::
vector
<
std
::
vector
<
index_t
>>
input_strides
,
std
::
vector
<
std
::
vector
<
index_t
>>
output_strides
,
ElementwiseFunctor
functor
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
ck
::
index_t
NumInputTensor
,
ck
::
index_t
NumOutputTensor
,
index_t
NDim
,
typename
ElementwiseFunctor
>
using
DeviceElementwisePtr
=
std
::
unique_ptr
<
DeviceElementwise
<
NumInputTensor
,
NumOutputTensor
,
NDim
,
ElementwiseFunctor
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
View file @
2b27d5fc
...
...
@@ -29,20 +29,20 @@ template <typename ALayout,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
C0
DataType
,
typename
C1
DataType
,
typename
Bias
DataType
,
typename
D0
DataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
ReduceAccDataType
,
typename
D
PtrsGlobal
,
typename
Reduce
PtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C1
ElementwiseOperation
,
typename
Dxs
ReduceOperation
,
typename
Dxs
InElementwiseOperation
,
typename
Dxs
ReduceAccElementwiseOperation
,
typename
D
GlobalMemoryDataOperation
,
typename
D0
ElementwiseOperation
,
typename
ReduceOperation
s
,
typename
Reduce
InElementwiseOperation
s
,
typename
ReduceAccElementwiseOperation
s
,
typename
Reduce
GlobalMemoryDataOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
...
...
@@ -77,13 +77,7 @@ template <typename ALayout,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGemmBiasAddReduce_Xdl_CShuffle
:
public
DeviceGemmBiasAddReduce
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
C1ElementwiseOperation
,
DxsInElementwiseOperation
,
DxsReduceAccElementwiseOperation
>
struct
DeviceGemmBiasAddReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
1
,
ReduceOperations
::
Size
()
>
{
using
DeviceOp
=
DeviceGemmBiasAddReduce_Xdl_CShuffle
;
...
...
@@ -356,7 +350,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
}
// assume D is packed tensor
static
auto
Make
D
GridDescriptor_M
(
index_t
MRaw
)
static
auto
Make
Reduce
GridDescriptor_M
(
index_t
MRaw
)
{
const
auto
d_grid_desc_mraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MRaw
));
...
...
@@ -386,7 +380,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
C0GridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
0
));
using
C1GridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
D
GridDesc_M
=
decltype
(
Make
D
GridDescriptor_M
(
1
));
using
Reduce
GridDesc_M
=
decltype
(
Make
Reduce
GridDescriptor_M
(
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
...
...
@@ -394,25 +388,25 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GemmAccDataType
,
CShuffleDataType
,
CDataType
,
C0
DataType
,
C1
DataType
,
Bias
DataType
,
D0
DataType
,
ReduceAccDataType
,
D
PtrsGlobal
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
C1
ElementwiseOperation
,
Dxs
ReduceOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
D0
ElementwiseOperation
,
ReduceOperation
s
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
InMemoryDataOperationEnum
::
Set
,
D
GlobalMemoryDataOperation
,
Reduce
GlobalMemoryDataOperation
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
C0GridDesc_M_N
,
C1GridDesc_M_N
,
D
GridDesc_M
,
Reduce
GridDesc_M
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -455,9 +449,9 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
const
C0
DataType
*
p_
c0
_grid
,
const
C1
DataType
*
p_
c1
_grid
,
D
PtrsGlobal
p_
d
s_grid
,
const
Bias
DataType
*
p_
bias
_grid
,
const
D0
DataType
*
p_
d0
_grid
,
Reduce
PtrsGlobal
p_
reduce
s_grid
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
...
...
@@ -468,32 +462,32 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
C1
ElementwiseOperation
c1
_element_op
,
Dxs
InElementwiseOperation
dxs
_in_element_op
,
Dxs
ReduceAccElementwiseOperation
dxs
_out_element_op
)
D0
ElementwiseOperation
d0
_element_op
,
Reduce
InElementwiseOperation
s
reduce
_in_element_op
s
,
ReduceAccElementwiseOperation
s
reduce
_out_element_op
s
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_
c0
_grid_
{
p_
c0
_grid
},
p_
c1
_grid_
{
p_
c1
_grid
},
p_
d
s_grid_
{
p_
d
s_grid
},
p_
bias
_grid_
{
p_
bias
_grid
},
p_
d0
_grid_
{
p_
d0
_grid
},
p_
reduce
s_grid_
{
p_
reduce
s_grid
},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
)},
c0_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
0
)},
c1_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC1
)},
d
_grid_desc_m_
{
DeviceOp
::
Make
D
GridDescriptor_M
(
MRaw
)},
reduce
_grid_desc_m_
{
DeviceOp
::
Make
Reduce
GridDescriptor_M
(
MRaw
)},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
c0_grid_desc_mblock_mperblock_nblock_nperblock_
{},
c1_grid_desc_mblock_mperblock_nblock_nperblock_
{},
d
_grid_desc_mblock_mperblock_
{},
reduce
_grid_desc_mblock_mperblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
},
c1
_element_op_
{
c1
_element_op
},
dxs
_in_element_op_
{
dxs
_in_element_op
},
dxs
_out_element_op_
{
dxs
_out_element_op
}
d0
_element_op_
{
d0
_element_op
},
reduce
_in_element_op
s
_
{
reduce
_in_element_op
s
},
reduce
_out_element_op
s
_
{
reduce
_out_element_op
s
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
...
...
@@ -512,8 +506,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c1_grid_desc_m_n_
);
d
_grid_desc_mblock_mperblock_
=
GridwiseGemm
::
Make
D
GridDescriptor_MBlock_MPerBlock
(
d
_grid_desc_m_
);
reduce
_grid_desc_mblock_mperblock_
=
GridwiseGemm
::
Make
Reduce
GridDescriptor_MBlock_MPerBlock
(
reduce
_grid_desc_m_
);
}
}
...
...
@@ -521,29 +515,30 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
const
C0
DataType
*
p_
c0
_grid_
;
const
C1
DataType
*
p_
c1
_grid_
;
D
PtrsGlobal
p_
d
s_grid_
;
const
Bias
DataType
*
p_
bias
_grid_
;
const
D0
DataType
*
p_
d0
_grid_
;
Reduce
PtrsGlobal
p_
reduce
s_grid_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
C0GridDesc_M_N
c0_grid_desc_m_n_
;
C1GridDesc_M_N
c1_grid_desc_m_n_
;
D
GridDesc_M
d
_grid_desc_m_
;
Reduce
GridDesc_M
reduce
_grid_desc_m_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c0_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
DGridDescriptor_MBlock_MPerBlock
d_grid_desc_mblock_mperblock_
;
typename
GridwiseGemm
::
ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
C1
ElementwiseOperation
c1
_element_op_
;
Dxs
InElementwiseOperation
dxs
_in_element_op_
;
Dxs
ReduceAccElementwiseOperation
dxs
_out_element_op_
;
D0
ElementwiseOperation
d0
_element_op_
;
Reduce
InElementwiseOperation
s
reduce
_in_element_op
s
_
;
ReduceAccElementwiseOperation
s
reduce
_out_element_op
s
_
;
};
// Invoker
...
...
@@ -574,21 +569,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
C0
DataType
,
C1
DataType
,
D
PtrsGlobal
,
Bias
DataType
,
D0
DataType
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
C1
ElementwiseOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
D0
ElementwiseOperation
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
D
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
Reduce
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
true
>
;
...
...
@@ -601,21 +596,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_
c0
_grid_
,
arg
.
p_
c1
_grid_
,
arg
.
p_
d
s_grid_
,
arg
.
p_
bias
_grid_
,
arg
.
p_
d0
_grid_
,
arg
.
p_
reduce
s_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c1
_element_op_
,
arg
.
dxs
_in_element_op_
,
arg
.
dxs
_out_element_op_
,
arg
.
d0
_element_op_
,
arg
.
reduce
_in_element_op
s
_
,
arg
.
reduce
_out_element_op
s
_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c0_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c1_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
d
_grid_desc_mblock_mperblock_
,
arg
.
reduce
_grid_desc_mblock_mperblock_
,
arg
.
block_2_ctile_map_
);
}
else
...
...
@@ -624,21 +619,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
C0
DataType
,
C1
DataType
,
D
PtrsGlobal
,
Bias
DataType
,
D0
DataType
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
C1
ElementwiseOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
D0
ElementwiseOperation
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
D
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
Reduce
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
false
>
;
...
...
@@ -651,21 +646,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_
c0
_grid_
,
arg
.
p_
c1
_grid_
,
arg
.
p_
d
s_grid_
,
arg
.
p_
bias
_grid_
,
arg
.
p_
d0
_grid_
,
arg
.
p_
reduce
s_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c1
_element_op_
,
arg
.
dxs
_in_element_op_
,
arg
.
dxs
_out_element_op_
,
arg
.
d0
_element_op_
,
arg
.
reduce
_in_element_op
s
_
,
arg
.
reduce
_out_element_op
s
_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c0_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c1_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
d
_grid_desc_mblock_mperblock_
,
arg
.
reduce
_grid_desc_mblock_mperblock_
,
arg
.
block_2_ctile_map_
);
}
...
...
@@ -700,45 +695,76 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
const
C0DataType
*
p_c0
,
const
C1DataType
*
p_c1
,
DPtrsGlobal
p_dxs
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC1
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
C1ElementwiseOperation
c1_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsReduceAccElementwiseOperation
dxs_out_element_op
)
static
constexpr
int
NumReduce
=
ReduceOperations
::
Size
();
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_bias
,
std
::
array
<
const
void
*
,
1
>
p_ds
,
void
*
p_c
,
std
::
array
<
void
*
,
NumReduce
>
p_reduces
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
std
::
array
<
ck
::
index_t
,
1
>
StrideDs
,
std
::
array
<
void
*
,
3
>
gemm_element_ops
,
std
::
array
<
void
*
,
1
>
d_element_ops
,
std
::
array
<
void
*
,
NumReduce
>
reduce_in_element_op
,
std
::
array
<
void
*
,
NumReduce
>
reduce_out_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_c
,
p_c0
,
p_c1
,
p_dxs
,
MRaw
,
NRaw
,
KRaw
,
ReducePtrsGlobal
reduce_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReducePtrsGlobal
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
static_cast
<
T
*>
(
p_reduces
[
I
]);
},
Number
<
NumReduce
>
{});
ReduceInElementwiseOperations
reduce_in_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceInElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_in_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
ReduceAccElementwiseOperations
reduce_out_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceAccElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_out_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
AElementwiseOperation
a_element_op
=
*
(
static_cast
<
AElementwiseOperation
*>
(
gemm_element_ops
[
0
]));
BElementwiseOperation
b_element_op
=
*
(
static_cast
<
BElementwiseOperation
*>
(
gemm_element_ops
[
1
]));
CElementwiseOperation
c_element_op
=
*
(
static_cast
<
CElementwiseOperation
*>
(
gemm_element_ops
[
2
]));
D0ElementwiseOperation
d_element_op
=
*
(
static_cast
<
D0ElementwiseOperation
*>
(
d_element_ops
[
0
]));
return
Argument
{
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
const
BiasDataType
*>
(
p_bias
),
static_cast
<
const
D0DataType
*>
(
p_ds
[
0
]),
reduce_tuple
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
Stride
C1
,
Stride
Ds
[
0
]
,
a_element_op
,
b_element_op
,
c_element_op
,
c1
_element_op
,
dxs
_in_element_op
,
dxs
_out_element_op
};
d
_element_op
,
reduce
_in_element_op
s
,
reduce
_out_element_op
s
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -747,45 +773,74 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_bias
,
std
::
array
<
const
void
*
,
1
>
p_ds
,
void
*
p_c
,
const
void
*
p_c0
,
const
void
*
p_c1
,
void
*
p_dxs
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC1
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
C1ElementwiseOperation
c1_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsReduceAccElementwiseOperation
dxs_out_element_op
,
std
::
array
<
void
*
,
NumReduce
>
p_reduces
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
std
::
array
<
ck
::
index_t
,
1
>
StrideDs
,
std
::
array
<
void
*
,
3
>
gemm_element_ops
,
std
::
array
<
void
*
,
1
>
d_element_ops
,
std
::
array
<
void
*
,
NumReduce
>
reduce_in_element_op
,
std
::
array
<
void
*
,
NumReduce
>
reduce_out_element_op
,
index_t
/* KBatch */
=
1
)
override
{
DPtrsGlobal
dxs_tuple
=
*
(
static_cast
<
DPtrsGlobal
*>
(
p_dxs
));
ReducePtrsGlobal
reduce_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReducePtrsGlobal
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
static_cast
<
T
*>
(
p_reduces
[
I
]);
},
Number
<
NumReduce
>
{});
ReduceInElementwiseOperations
reduce_in_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceInElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_in_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
ReduceAccElementwiseOperations
reduce_out_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceAccElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_out_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
AElementwiseOperation
a_element_op
=
*
(
static_cast
<
AElementwiseOperation
*>
(
gemm_element_ops
[
0
]));
BElementwiseOperation
b_element_op
=
*
(
static_cast
<
BElementwiseOperation
*>
(
gemm_element_ops
[
1
]));
CElementwiseOperation
c_element_op
=
*
(
static_cast
<
CElementwiseOperation
*>
(
gemm_element_ops
[
2
]));
D0ElementwiseOperation
d_element_op
=
*
(
static_cast
<
D0ElementwiseOperation
*>
(
d_element_ops
[
0
]));
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
const
C0
DataType
*>
(
p_
c0
),
static_cast
<
const
C1
DataType
*>
(
p_
c1
),
dxs
_tuple
,
M
Raw
,
N
Raw
,
K
Raw
,
static_cast
<
const
Bias
DataType
*>
(
p_
bias
),
static_cast
<
const
D0
DataType
*>
(
p_
ds
[
0
]
),
reduce
_tuple
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
Stride
C1
,
Stride
Ds
[
0
]
,
a_element_op
,
b_element_op
,
c_element_op
,
c1
_element_op
,
dxs
_in_element_op
,
dxs
_out_element_op
);
d
_element_op
,
reduce
_in_element_op
s
,
reduce
_out_element_op
s
);
}
// polymorphic
...
...
@@ -800,7 +855,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGemmReduce_Xdl_CShuffle"
str
<<
"DeviceGemm
BiasAdd
Reduce_Xdl_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp
0 → 100644
View file @
2b27d5fc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
struct
DEGridDesc_M0_M1_M2_N0_N1
{
ck
::
index_t
M0_
,
M1_
,
M2_
,
N0_
,
N1_
;
ck
::
index_t
stride_M0_
,
stride_M1_
,
stride_M2_
,
stride_N0_
,
stride_N1_
;
};
// input : A[M, K], B[K, N],
// input : D[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D)
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
struct
DeviceGemmBiasCPermute
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_d
,
void
*
p_e
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
DEGridDesc_M0_M1_M2_N0_N1
d_gride_desc
,
DEGridDesc_M0_M1_M2_N0_N1
e_gride_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGemmBiasCPermutePtr
=
std
::
unique_ptr
<
DeviceGemmBiasCPermute
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp
0 → 100644
View file @
2b27d5fc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace
ck
{
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatDsPointer
,
typename
FloatE
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2ETileMap
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_bias_c_permute
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatDsPointer
p_ds_grid
,
FloatE
*
__restrict__
p_e_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_ds_grid
,
p_e_grid
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_etile_map
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
block_2_etile_map
;
#endif
}
}
// namespace ck
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// input : A[M, K], or A[K, N]
// input : B[K, N], or A[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template
<
typename
ALayout
,
typename
BLayout
,
typename
CDELayout
,
typename
ADataType
,
typename
BDataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
DDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGemmBiasCPermute_Xdl
:
public
DeviceGemmBiasCPermute
<
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
using
DeviceOp
=
DeviceGemmBiasCPermute_Xdl
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
index_t
NumDTensor
=
I1
;
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
I1
,
StrideA
));
}
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
{
// not pad M or K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
}
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
I1
,
StrideB
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
StrideB
,
I1
));
}
}();
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both N and K
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad N, but not K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
NRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// not pad N or K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
}
static
auto
MakeEGridDescriptor_M_N
(
DEGridDesc_M0_M1_M2_N0_N1
d_e_grid_desc
)
{
index_t
M0
=
d_e_grid_desc
.
M0_
;
index_t
M1
=
d_e_grid_desc
.
M1_
;
index_t
M2
=
d_e_grid_desc
.
M2_
;
index_t
N0
=
d_e_grid_desc
.
N0_
;
index_t
N1
=
d_e_grid_desc
.
N1_
;
index_t
stride_M0
=
d_e_grid_desc
.
stride_M0_
;
index_t
stride_M1
=
d_e_grid_desc
.
stride_M1_
;
index_t
stride_M2
=
d_e_grid_desc
.
stride_M2_
;
index_t
stride_N0
=
d_e_grid_desc
.
stride_N0_
;
index_t
stride_N1
=
d_e_grid_desc
.
stride_N1_
;
const
auto
MRaw
=
M0
*
M1
*
M2
;
const
auto
NRaw
=
N0
*
N1
;
const
auto
c_grid_desc_mraw_nraw
=
[
&
]()
{
const
auto
c_grid_desc_m0_m1_m2_n0_n1
=
make_naive_tensor_descriptor
(
make_tuple
(
M0
,
M1
,
M2
,
N0
,
N1
),
make_tuple
(
stride_M0
,
stride_M1
,
stride_M2
,
stride_N0
,
stride_N1
));
return
transform_tensor_descriptor
(
c_grid_desc_m0_m1_m2_n0_n1
,
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
)),
make_merge_transform
(
make_tuple
(
N0
,
N1
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
NPad
=
N
-
NRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad M, but not N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad N, but not M
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or N
return
c_grid_desc_mraw_nraw
;
}
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
(
DEGridDesc_M0_M1_M2_N0_N1
{}));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
CShuffleDataType
,
ck
::
Tuple
<
DDataType
>
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
EGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
false
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
false
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
void
*
p_a_grid
,
const
void
*
p_b_grid
,
const
void
*
p_d_grid
,
void
*
p_e_grid
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
DEGridDesc_M0_M1_M2_N0_N1
d_grid_desc
,
DEGridDesc_M0_M1_M2_N0_N1
e_grid_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
// FIXME
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
(
e_grid_desc
)},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
{
if
(
MRaw
!=
d_grid_desc
.
M0_
*
d_grid_desc
.
M1_
*
d_grid_desc
.
M2_
)
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
if
(
NRaw
!=
d_grid_desc
.
N0_
*
d_grid_desc
.
N1_
)
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
p_ds_grid_
(
I0
)
=
static_cast
<
const
DDataType
*>
(
p_d_grid
);
const
auto
d_grid_desc_m_n
=
DeviceOp
::
MakeEGridDescriptor_M_N
(
d_grid_desc
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
(
I0
)
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
d_grid_desc_m_n
);
}
}
// private:
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
StaticallyIndexedArray
<
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
// FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N
e_grid_desc_m_n_
;
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
DefaultBlock2ETileMap
block_2_etile_map_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_gemm_bias_c_permute
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
ck
::
StaticallyIndexedArray
<
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2ETileMap
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_etile_map_
);
};
float
ave_time
=
0
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
return
ave_time
;
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_d
,
void
*
p_e
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
DEGridDesc_M0_M1_M2_N0_N1
d_grid_desc
,
DEGridDesc_M0_M1_M2_N0_N1
e_grid_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_d
,
p_e
,
MRaw
,
NRaw
,
KRaw
,
StrideA
,
StrideB
,
d_grid_desc
,
e_grid_desc
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_d
,
void
*
p_e
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
DEGridDesc_M0_M1_M2_N0_N1
d_grid_desc
,
DEGridDesc_M0_M1_M2_N0_N1
e_grid_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_d
,
p_e
,
MRaw
,
NRaw
,
KRaw
,
StrideA
,
StrideB
,
d_grid_desc
,
e_grid_desc
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGemmBiasCPermute_Xdl"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
View file @
2b27d5fc
...
...
@@ -9,91 +9,34 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsReduceAccElementwiseOperation
>
template
<
ck
::
index_t
NumDTensor
,
ck
::
index_t
NumReduce
>
struct
DeviceGemmReduce
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_bias
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_c
,
void
*
p_dx
s
,
std
::
array
<
void
*
,
NumReduce
>
p_reduce
s
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b
_element_op
,
CElementwiseOperation
c
_element_op
,
DxsInElementwiseOperation
dxs
_in_element_op
,
DxsReduceAccElementwiseOperation
dxs
_out_element_op
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
std
::
array
<
void
*
,
3
>
gemm
_element_op
s
,
std
::
array
<
void
*
,
NumDTensor
>
d
_element_op
s
,
std
::
array
<
void
*
,
NumReduce
>
reduce
_in_element_op
s
,
std
::
array
<
void
*
,
NumReduce
>
reduce
_out_element_op
s
,
ck
::
index_t
BatchCount
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsReduceAccElementwiseOperation
>
using
DeviceGemmReducePtr
=
std
::
unique_ptr
<
DeviceGemmReduce
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsReduceAccElementwiseOperation
>>
;
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsReduceAccElementwiseOperation
>
struct
DeviceGemmBiasAddReduce
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
const
void
*
p_c0
,
const
void
*
p_c1
,
void
*
p_dxs
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
ck
::
index_t
StrideC1
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
C1ElementwiseOperation
c1_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsReduceAccElementwiseOperation
dxs_out_element_op
,
ck
::
index_t
BatchCount
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsReduceAccElementwiseOperation
>
using
DeviceGemmBiasAddReducePtr
=
std
::
unique_ptr
<
DeviceGemmBiasAddReduce
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
C1ElementwiseOperation
,
DxsInElementwiseOperation
,
DxsReduceAccElementwiseOperation
>>
;
template
<
ck
::
index_t
NumDTensor
,
ck
::
index_t
NumReduce
>
using
DeviceGemmReducePtr
=
std
::
unique_ptr
<
DeviceGemmReduce
<
NumDTensor
,
NumReduce
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
View file @
2b27d5fc
...
...
@@ -32,14 +32,14 @@ template <typename ALayout,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
ReduceAccDataType
,
typename
D
PtrsGlobal
,
typename
Reduce
PtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Dxs
ReduceOperation
,
typename
Dxs
InElementwiseOperation
,
typename
Dxs
ReduceAccElementwiseOperation
,
typename
D
GlobalMemoryDataOperation
,
typename
ReduceOperation
s
,
typename
Reduce
InElementwiseOperation
s
,
typename
ReduceAccElementwiseOperation
s
,
typename
Reduce
GlobalMemoryDataOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
...
...
@@ -74,11 +74,7 @@ template <typename ALayout,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsReduceAccElementwiseOperation
>
struct
DeviceGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
0
,
ReduceOperations
::
Size
()
>
{
using
DeviceOp
=
DeviceGemmReduce_Xdl_CShuffle
;
...
...
@@ -350,8 +346,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
}
}
// assume
D
is packed tensor
static
auto
Make
D
GridDescriptor_M
(
index_t
MRaw
)
// assume
Reduce
is packed tensor
static
auto
Make
Reduce
GridDescriptor_M
(
index_t
MRaw
)
{
const
auto
d_grid_desc_mraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MRaw
));
...
...
@@ -379,7 +375,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
D
GridDesc_M
=
decltype
(
Make
D
GridDescriptor_M
(
1
));
using
Reduce
GridDesc_M
=
decltype
(
Make
Reduce
GridDescriptor_M
(
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
...
...
@@ -388,19 +384,19 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CShuffleDataType
,
CDataType
,
ReduceAccDataType
,
D
PtrsGlobal
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
Dxs
ReduceOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
ReduceOperation
s
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
InMemoryDataOperationEnum
::
Set
,
D
GlobalMemoryDataOperation
,
Reduce
GlobalMemoryDataOperation
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
D
GridDesc_M
,
Reduce
GridDesc_M
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -443,7 +439,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
D
PtrsGlobal
p_
d
s_grid
,
Reduce
PtrsGlobal
p_
reduce
s_grid
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
...
...
@@ -453,24 +449,24 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
Dxs
InElementwiseOperation
dxs
_in_element_op
,
Dxs
ReduceAccElementwiseOperation
dxs
_out_element_op
)
Reduce
InElementwiseOperation
s
reduce
_in_element_op
s
,
ReduceAccElementwiseOperation
s
reduce
_out_element_op
s
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_
d
s_grid_
{
p_
d
s_grid
},
p_
reduce
s_grid_
{
p_
reduce
s_grid
},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
)},
d
_grid_desc_m_
{
DeviceOp
::
Make
D
GridDescriptor_M
(
MRaw
)},
reduce
_grid_desc_m_
{
DeviceOp
::
Make
Reduce
GridDescriptor_M
(
MRaw
)},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
d
_grid_desc_mblock_mperblock_
{},
reduce
_grid_desc_mblock_mperblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
},
dxs
_in_element_op_
{
dxs
_in_element_op
},
dxs
_out_element_op_
{
dxs
_out_element_op
}
reduce
_in_element_op
s
_
{
reduce
_in_element_op
s
},
reduce
_out_element_op
s
_
{
reduce
_out_element_op
s
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
...
...
@@ -481,8 +477,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
d
_grid_desc_mblock_mperblock_
=
GridwiseGemm
::
Make
D
GridDescriptor_MBlock_MPerBlock
(
d
_grid_desc_m_
);
reduce
_grid_desc_mblock_mperblock_
=
GridwiseGemm
::
Make
Reduce
GridDescriptor_MBlock_MPerBlock
(
reduce
_grid_desc_m_
);
}
}
...
...
@@ -490,20 +486,21 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
D
PtrsGlobal
p_
d
s_grid_
;
Reduce
PtrsGlobal
p_
reduce
s_grid_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
D
GridDesc_M
d
_grid_desc_m_
;
Reduce
GridDesc_M
reduce
_grid_desc_m_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
DGridDescriptor_MBlock_MPerBlock
d_grid_desc_mblock_mperblock_
;
typename
GridwiseGemm
::
ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
Dxs
InElementwiseOperation
dxs
_in_element_op_
;
Dxs
ReduceAccElementwiseOperation
dxs
_out_element_op_
;
Reduce
InElementwiseOperation
s
reduce
_in_element_op
s
_
;
ReduceAccElementwiseOperation
s
reduce
_out_element_op
s
_
;
};
// Invoker
...
...
@@ -528,7 +525,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.
d
_grid_desc_m_{ " << arg.
d
_grid_desc_m_.GetLength(I0) << "}"
std::cout << "arg.
reduce
_grid_desc_m_{ " << arg.
reduce
_grid_desc_m_.GetLength(I0) << "}"
<< std::endl;
}
#endif
...
...
@@ -554,16 +551,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
D
PtrsGlobal
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
D
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
Reduce
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
true
>
;
...
...
@@ -576,16 +573,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_
d
s_grid_
,
arg
.
p_
reduce
s_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
dxs
_in_element_op_
,
arg
.
dxs
_out_element_op_
,
arg
.
reduce
_in_element_op
s
_
,
arg
.
reduce
_out_element_op
s
_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
d
_grid_desc_mblock_mperblock_
,
arg
.
reduce
_grid_desc_mblock_mperblock_
,
arg
.
block_2_ctile_map_
);
}
else
...
...
@@ -594,16 +591,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
D
PtrsGlobal
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
D
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
Reduce
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
false
>
;
...
...
@@ -616,16 +613,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_
d
s_grid_
,
arg
.
p_
reduce
s_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
dxs
_in_element_op_
,
arg
.
dxs
_out_element_op_
,
arg
.
reduce
_in_element_op
s
_
,
arg
.
reduce
_out_element_op
s
_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
d
_grid_desc_mblock_mperblock_
,
arg
.
reduce
_grid_desc_mblock_mperblock_
,
arg
.
block_2_ctile_map_
);
}
...
...
@@ -660,37 +657,75 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
DPtrsGlobal
p_dxs
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsReduceAccElementwiseOperation
dxs_out_element_op
)
static
constexpr
int
NumReduce
=
ReduceOperations
::
Size
();
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_bias
,
std
::
array
<
const
void
*
,
0
>
p_ds
,
void
*
p_c
,
std
::
array
<
void
*
,
NumReduce
>
p_reduces
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
std
::
array
<
ck
::
index_t
,
0
>
StrideDs
,
std
::
array
<
void
*
,
3
>
gemm_element_ops
,
std
::
array
<
void
*
,
0
>
d_element_ops
,
std
::
array
<
void
*
,
NumReduce
>
reduce_in_element_op
,
std
::
array
<
void
*
,
NumReduce
>
reduce_out_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_c
,
p_dxs
,
MRaw
,
NRaw
,
KRaw
,
(
void
)
p_bias
;
(
void
)
p_ds
;
(
void
)
StrideDs
;
(
void
)
d_element_ops
;
ReducePtrsGlobal
reduce_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReducePtrsGlobal
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
static_cast
<
T
*>
(
p_reduces
[
I
]);
},
Number
<
NumReduce
>
{});
ReduceInElementwiseOperations
reduce_in_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceInElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_in_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
ReduceAccElementwiseOperations
reduce_out_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceAccElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_out_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
AElementwiseOperation
a_element_op
=
*
(
static_cast
<
AElementwiseOperation
*>
(
gemm_element_ops
[
0
]));
BElementwiseOperation
b_element_op
=
*
(
static_cast
<
BElementwiseOperation
*>
(
gemm_element_ops
[
1
]));
CElementwiseOperation
c_element_op
=
*
(
static_cast
<
CElementwiseOperation
*>
(
gemm_element_ops
[
2
]));
return
Argument
{
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
reduce_tuple
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
,
dxs
_in_element_op
,
dxs
_out_element_op
};
reduce
_in_element_op
s
,
reduce
_out_element_op
s
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -699,37 +734,73 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_bias
,
std
::
array
<
const
void
*
,
0
>
p_ds
,
void
*
p_c
,
void
*
p_dx
s
,
index_t
M
Raw
,
index_t
N
Raw
,
index_t
K
Raw
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b
_element_op
,
CElementwiseOperation
c
_element_op
,
DxsInElementwiseOperation
dxs
_in_element_op
,
DxsReduceAccElementwiseOperation
dxs
_out_element_op
,
index_t
/* KBatch */
=
1
)
override
std
::
array
<
void
*
,
NumReduce
>
p_reduce
s
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
std
::
array
<
ck
::
index_t
,
0
>
StrideDs
,
std
::
array
<
void
*
,
3
>
gemm
_element_op
s
,
std
::
array
<
void
*
,
0
>
d
_element_op
s
,
std
::
array
<
void
*
,
NumReduce
>
reduce
_in_element_op
,
std
::
array
<
void
*
,
NumReduce
>
reduce
_out_element_op
,
ck
::
index_t
=
1
)
override
{
DPtrsGlobal
dxs_tuple
=
*
(
static_cast
<
DPtrsGlobal
*>
(
p_dxs
));
(
void
)
p_bias
;
(
void
)
p_ds
;
(
void
)
StrideDs
;
(
void
)
d_element_ops
;
ReducePtrsGlobal
reduce_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReducePtrsGlobal
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
static_cast
<
T
*>
(
p_reduces
[
I
]);
},
Number
<
NumReduce
>
{});
ReduceInElementwiseOperations
reduce_in_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceInElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_in_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
ReduceAccElementwiseOperations
reduce_out_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceAccElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_out_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
AElementwiseOperation
a_element_op
=
*
(
static_cast
<
AElementwiseOperation
*>
(
gemm_element_ops
[
0
]));
BElementwiseOperation
b_element_op
=
*
(
static_cast
<
BElementwiseOperation
*>
(
gemm_element_ops
[
1
]));
CElementwiseOperation
c_element_op
=
*
(
static_cast
<
CElementwiseOperation
*>
(
gemm_element_ops
[
2
]));
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
dxs
_tuple
,
M
Raw
,
N
Raw
,
K
Raw
,
reduce
_tuple
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
,
dxs
_in_element_op
,
dxs
_out_element_op
);
reduce
_in_element_op
s
,
reduce
_out_element_op
s
);
}
// polymorphic
...
...
include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp
0 → 100644
View file @
2b27d5fc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGemmSplitK
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
KBatch
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGemmSplitKPtr
=
std
::
unique_ptr
<
DeviceGemmSplitK
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
View file @
2b27d5fc
...
...
@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm
_splitk
.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp"
#include "ck/device_utility/device_prop.hpp"
...
...
@@ -57,7 +57,7 @@ template <typename ADataType,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
>
struct
DeviceGemmXdlSplitK
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
:
public
DeviceGemm
SplitK
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
2b27d5fc
...
...
@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm
_splitk
.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp"
#include "ck/device_utility/device_prop.hpp"
...
...
@@ -59,7 +59,7 @@ template <typename ADataType,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXDL
>
struct
DeviceGemmXdlSplitKCShuffle
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
:
public
DeviceGemm
SplitK
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -420,21 +420,22 @@ struct DeviceGemmXdlSplitKCShuffle
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
.
GetElementSpaceSize
()
*
sizeof
(
CDataType
)));
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
};
if
(
has_main_k0_block_loop
)
...
...
include/ck/tensor_operation/gpu/device/device_normalization.hpp
0 → 100644
View file @
2b27d5fc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
struct
DeviceNormalization
:
public
BaseOperator
{
// inLengths: input tensor extent(s) from high to low dimension
// inStrides: input tensor stride(s) from high to low dimension
// reduceDims: the dimension(s) the normalization operation is applied
// alpha: typeless pointer in host memory storing the alpha scaling value of type AccDataType
// beta: typeless pointer in host memory storing the beta scaling value of type AccDataType
// in_dev: typeless const pointer in device memory storing the input tensor
// out_dev: typeless pointer in device memory storing the output tensor
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
int
>
reduceDims
,
const
void
*
alpha
,
const
void
*
beta
,
const
void
*
in_dev
,
void
*
out_dev
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
index_t
GetRank
()
const
=
0
;
virtual
index_t
GetNumReduceDim
()
const
=
0
;
};
using
DeviceNormalizationPtr
=
std
::
unique_ptr
<
DeviceNormalization
>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_softmax.hpp
View file @
2b27d5fc
...
...
@@ -9,6 +9,7 @@
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
...
...
@@ -33,8 +34,15 @@ template <typename InDataType,
index_t
InSrcVectorDim
,
index_t
InSrcVectorSize
,
index_t
OutDstVectorSize
>
struct
DeviceSoftmax
:
public
BaseOperator
struct
DeviceSoftmax
:
public
DeviceNormalization
{
static
constexpr
index_t
kRank
=
Rank
;
static
constexpr
index_t
kNumReduceDim
=
NumReduceDim
;
virtual
index_t
GetRank
()
const
override
{
return
kRank
;
}
virtual
index_t
GetNumReduceDim
()
const
override
{
return
kNumReduceDim
;
}
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
// Used for freeloading of some handy functions from DeviceReduceMultiBlock
...
...
@@ -61,18 +69,33 @@ struct DeviceSoftmax : public BaseOperator
using
GridDesc_M_K
=
decltype
(
Reduction
::
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
GridwiseReduce
=
GridwiseSoftmax_mk_to_mk
<
InDataType
,
OutDataType
,
AccDataType
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
InSrcVectorDim
,
InSrcVectorSize
,
OutDstVectorSize
>
;
using
GridwiseSoftmaxGeneric
=
GridwiseSoftmax_mk_to_mk
<
InDataType
,
OutDataType
,
AccDataType
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
InSrcVectorDim
,
InSrcVectorSize
,
OutDstVectorSize
,
false
>
;
using
GridwiseSoftmaxSweepOnce
=
GridwiseSoftmax_mk_to_mk
<
InDataType
,
OutDataType
,
AccDataType
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
InSrcVectorDim
,
InSrcVectorSize
,
OutDstVectorSize
,
true
>
;
struct
Argument
:
public
Reduction
::
Argument
{
...
...
@@ -121,8 +144,19 @@ struct DeviceSoftmax : public BaseOperator
const
auto
out_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
arg
.
inLengths_
,
arg
.
inStrides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
const
auto
kernel_main
=
kernel_softmax
<
GridwiseReduce
,
InDataType
,
OutDataType
,
AccDataType
,
GridDesc_M_K
>
;
bool
sweep_once
=
in_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{})
<=
KThreadClusterSize
*
KThreadSliceSize
;
const
auto
kernel_main
=
sweep_once
?
kernel_softmax
<
GridwiseSoftmaxSweepOnce
,
InDataType
,
OutDataType
,
AccDataType
,
GridDesc_M_K
>
:
kernel_softmax
<
GridwiseSoftmaxGeneric
,
InDataType
,
OutDataType
,
AccDataType
,
GridDesc_M_K
>
;
float
avg_time
=
0
;
...
...
@@ -167,24 +201,34 @@ struct DeviceSoftmax : public BaseOperator
return
true
;
};
// inLengths: input tensor extent(s) from high to low dimension
// inStrides: input tensor stride(s) from high to low dimension
// reduceDims: the dimension(s) the softmax normalization operate on
// alpha: typeless pointer in host memory storing the alpha scaling value as type AccDataType
// beta: typeless pointer in host memory storing the beta scaling value as type AccDataType
// in_dev: typeless const pointer in device memory storing the input tensor
// out_dev: typeless pointer in device memory storing the output tensor
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
int
>
reduceDims
,
AccDataType
alpha
,
AccDataType
beta
,
const
void
*
alpha
,
const
void
*
beta
,
const
void
*
in_dev
,
void
*
out_dev
)
void
*
out_dev
)
override
{
return
std
::
make_unique
<
Argument
>
(
inLengths
,
inStrides
,
reduceDims
,
alpha
,
beta
,
*
static_cast
<
const
AccDataType
*>
(
alpha
)
,
*
static_cast
<
const
AccDataType
*>
(
beta
)
,
static_cast
<
const
InDataType
*>
(
in_dev
),
static_cast
<
OutDataType
*>
(
out_dev
));
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
();
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
};
std
::
string
GetTypeString
()
const
override
{
...
...
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
2b27d5fc
...
...
@@ -11,8 +11,8 @@ namespace element_wise {
struct
Add
{
template
<
typename
T
>
__host__
__device__
constexpr
void
operator
()(
T
&
y
,
const
T
&
x0
,
const
T
&
x1
)
const
;
template
<
typename
Y
,
typename
X0
,
typename
X1
>
__host__
__device__
constexpr
void
operator
()(
Y
&
y
,
const
X0
&
x0
,
const
X1
&
x1
)
const
;
template
<
>
__host__
__device__
constexpr
void
...
...
@@ -28,6 +28,13 @@ struct Add
y
=
x0
+
x1
;
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
half_t
>
(
half_t
&
y
,
const
float
&
x0
,
const
half_t
&
x1
)
const
{
y
=
type_convert
<
half_t
>
(
x0
)
+
x1
;
};
// Question: should half_t be supported ?
template
<
>
__host__
__device__
constexpr
void
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
View file @
2b27d5fc
...
...
@@ -23,19 +23,19 @@ template <typename GridwiseGemm,
typename
FloatC
,
typename
FloatC0
,
typename
FloatC1
,
typename
D
PtrsGlobal
,
typename
Reduce
PtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
Dxs
InElementwiseOperation
,
typename
Dxs
ReduceAccElementwiseOperation
,
typename
Reduce
InElementwiseOperation
s
,
typename
ReduceAccElementwiseOperation
s
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
D
GridDescriptor_MBlock_MPerBlock
,
typename
Reduce
GridDescriptor_MBlock_MPerBlock
,
typename
Block2CTileMap
,
bool
HasMainKBlockLoop
>
__global__
void
...
...
@@ -46,15 +46,15 @@ __global__ void
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC0
*
__restrict__
p_
c0
_grid
,
const
FloatC1
*
__restrict__
p_
c1
_grid
,
D
PtrsGlobal
p_
d
s_grid
,
const
FloatC0
*
__restrict__
p_
bias
_grid
,
const
FloatC1
*
__restrict__
p_
d0
_grid
,
Reduce
PtrsGlobal
p_
reduce
s_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
C1ElementwiseOperation
c1_element_op
,
const
Dxs
InElementwiseOperation
dxs
_in_element_op
,
const
Dxs
ReduceAccElementwiseOperation
dxs
_out_element_op
,
const
Reduce
InElementwiseOperation
s
reduce
_in_element_op
s
,
const
ReduceAccElementwiseOperation
s
reduce
_out_element_op
s
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
...
@@ -63,7 +63,7 @@ __global__ void
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
const
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
const
D
GridDescriptor_MBlock_MPerBlock
d
_grid_desc_mblock_mperblock
,
const
Reduce
GridDescriptor_MBlock_MPerBlock
reduce
_grid_desc_mblock_mperblock
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...
...
@@ -72,42 +72,42 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_
c0
_grid
,
p_
c1
_grid
,
p_
d
s_grid
,
p_
bias
_grid
,
p_
d0
_grid
,
p_
reduce
s_grid
,
p_shared
,
a_element_op
,
b_element_op
,
c_element_op
,
c1_element_op
,
dxs
_in_element_op
,
dxs
_out_element_op
,
reduce
_in_element_op
s
,
reduce
_out_element_op
s
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
d
_grid_desc_mblock_mperblock
,
reduce
_grid_desc_mblock_mperblock
,
block_2_ctile_map
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
p_
c0
_grid
;
ignore
=
p_
c1
_grid
;
ignore
=
p_
d
s_grid
;
ignore
=
p_
bias
_grid
;
ignore
=
p_
d0
_grid
;
ignore
=
p_
reduce
s_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
c1_element_op
;
ignore
=
dxs
_in_element_op
;
ignore
=
dxs
_out_element_op
;
ignore
=
reduce
_in_element_op
s
;
ignore
=
reduce
_out_element_op
s
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
c0_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
c1_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
d
_grid_desc_mblock_mperblock
;
ignore
=
reduce
_grid_desc_mblock_mperblock
;
ignore
=
block_2_ctile_map
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
...
...
@@ -119,22 +119,22 @@ template <typename FloatAB,
typename
FloatC0
,
typename
FloatC1
,
typename
FloatReduceAcc
,
typename
D
PtrsGlobal
,
typename
Reduce
PtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
Dxs
ReduceOperation
,
typename
Dxs
InElementwiseOperation
,
typename
Dxs
ReduceAccElementwiseOperation
,
typename
ReduceOperation
s
,
typename
Reduce
InElementwiseOperation
s
,
typename
ReduceAccElementwiseOperation
s
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
D
GlobalMemoryDataOperation
,
typename
Reduce
GlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDesc_M_N
,
typename
C0GridDesc_M_N
,
typename
C1GridDesc_M_N
,
typename
D
GridDesc_M
,
typename
Reduce
GridDesc_M
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
...
...
@@ -321,18 +321,18 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
__host__
__device__
static
constexpr
auto
Make
D
GridDescriptor_MBlock_MPerBlock
(
const
D
GridDesc_M
&
d_grid_desc_m
)
Make
Reduce
GridDescriptor_MBlock_MPerBlock
(
const
Reduce
GridDesc_M
&
d_grid_desc_m
)
{
const
auto
M
=
d_grid_desc_m
.
GetLength
(
I0
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
d
_grid_desc_mblock_mperblock
=
transform_tensor_descriptor
(
const
auto
reduce
_grid_desc_mblock_mperblock
=
transform_tensor_descriptor
(
d_grid_desc_m
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{}));
return
d
_grid_desc_mblock_mperblock
;
return
reduce
_grid_desc_mblock_mperblock
;
}
// return block_id to C matrix tile idx (m0, n0) mapping
...
...
@@ -352,36 +352,37 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
C1GridDesc_M_N
{}))
>
;
using
D
GridDescriptor_MBlock_MPerBlock
=
remove_cvref_t
<
decltype
(
Make
D
GridDescriptor_MBlock_MPerBlock
(
D
GridDesc_M
{}))
>
;
using
Reduce
GridDescriptor_MBlock_MPerBlock
=
remove_cvref_t
<
decltype
(
Make
Reduce
GridDescriptor_MBlock_MPerBlock
(
Reduce
GridDesc_M
{}))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}))
>
;
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC0
*
__restrict__
p_c0_grid
,
const
FloatC1
*
__restrict__
p_c1_grid
,
DPtrsGlobal
p_ds_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
C1ElementwiseOperation
&
c1_element_op
,
const
DxsInElementwiseOperation
&
dxs_in_element_op
,
const
DxsReduceAccElementwiseOperation
&
dxs_out_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
const
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
const
DGridDescriptor_MBlock_MPerBlock
&
d_grid_desc_mblock_mperblock
,
const
Block2CTileMap
&
block_2_ctile_map
)
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC0
*
__restrict__
p_bias_grid
,
const
FloatC1
*
__restrict__
p_d0_grid
,
ReducePtrsGlobal
p_reduces_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
C1ElementwiseOperation
&
c1_element_op
,
const
ReduceInElementwiseOperations
&
reduce_in_element_ops
,
const
ReduceAccElementwiseOperations
&
reduce_out_element_ops
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
const
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
const
ReduceGridDescriptor_MBlock_MPerBlock
&
reduce_grid_desc_mblock_mperblock
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
...
...
@@ -390,9 +391,9 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
auto
c0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
c0
_grid
,
c0_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
p_
bias
_grid
,
c0_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
auto
c1_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
c1
_grid
,
c1_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
p_
d0
_grid
,
c1_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
// divide block work by [M, N]
const
auto
block_work_idx
=
...
...
@@ -725,12 +726,12 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
mreduce_per_thread
>
{},
Number
<
nreduce_per_thread
>
{}));
// VGPR
d_
reduce_thread_desc_mperblock
constexpr
auto
d_
reduce_thread_desc_mperblock
=
// VGPR reduce_thread_desc_mperblock
constexpr
auto
reduce_thread_desc_mperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
mreduce_per_thread
>
{}));
// VGPR
d_
reduce_thread_desc_mblock_mperblock
constexpr
auto
d_
reduce_thread_desc_mblock_mperblock
=
// VGPR reduce_thread_desc_mblock_mperblock
constexpr
auto
reduce_thread_desc_mblock_mperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
mreduce_per_thread
>
{}));
auto
c_reduce_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatReduceAcc
>
(
...
...
@@ -759,29 +760,29 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
1
,
true
>
{
c_reduce_block_desc_mperblock_nperblock
,
c_reduce_thread_data_idx_begin
};
auto
dxs_
reduce_thread_copy_vgpr_to_global
=
generate_tuple
(
auto
reduce_
tuple_
thread_copy_vgpr_to_global
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
p_
d
_grid
=
p_
d
s_grid
[
I
];
auto
d_out
_element_op
=
dxs
_out_element_op
[
I
];
auto
p_
reduce
_grid
=
p_
reduce
s_grid
[
I
];
auto
reduce_acc
_element_op
=
reduce
_out_element_op
s
[
I
];
return
ThreadwiseTensorSliceTransfer_v1r3
<
FloatReduceAcc
,
remove_pointer_t
<
decltype
(
p_
d
_grid
)
>
,
decltype
(
d_
reduce_thread_desc_mblock_mperblock
),
decltype
(
d
_grid_desc_mblock_mperblock
),
decltype
(
d_out
_element_op
),
remove_pointer_t
<
decltype
(
p_
reduce
_grid
)
>
,
decltype
(
reduce_thread_desc_mblock_mperblock
),
decltype
(
reduce
_grid_desc_mblock_mperblock
),
decltype
(
reduce_acc
_element_op
),
Sequence
<
1
,
mreduce_per_thread
>
,
Sequence
<
0
,
1
>
,
1
,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
D
GlobalMemoryDataOperation
::
At
(
I
),
Reduce
GlobalMemoryDataOperation
::
At
(
I
),
1
,
false
>
{
d
_grid_desc_mblock_mperblock
,
false
>
{
reduce
_grid_desc_mblock_mperblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
c_reduce_thread_data_idx_begin
[
I0
]),
// mperblock
d_out
_element_op
};
reduce_acc
_element_op
};
},
Number
<
p_
d
s_grid
.
Size
()
>
{});
Number
<
p_
reduce
s_grid
.
Size
()
>
{});
// c0 and c1
constexpr
auto
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock
=
...
...
@@ -909,35 +910,35 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
static_for
<
0
,
p_
d
s_grid
.
Size
(),
1
>
{}([
&
](
auto
In
)
{
auto
&
p_
d
_grid
=
p_
d
s_grid
[
In
];
static_for
<
0
,
p_
reduce
s_grid
.
Size
(),
1
>
{}([
&
](
auto
In
)
{
auto
&
p_
reduce
_grid
=
p_
reduce
s_grid
[
In
];
auto
d
_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
d
_grid
,
d
_grid_desc_mblock_mperblock
.
GetElementSpaceSize
());
auto
reduce
_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
reduce
_grid
,
reduce
_grid_desc_mblock_mperblock
.
GetElementSpaceSize
());
auto
d
_thread_buf
=
auto
reduce
_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatReduceAcc
>
(
d_
reduce_thread_desc_mperblock
.
GetElementSpaceSize
());
reduce_thread_desc_mperblock
.
GetElementSpaceSize
());
auto
&
d
_in_element_op
=
dxs
_in_element_op
[
In
];
auto
&
reduce
_in_element_op
=
reduce
_in_element_op
s
[
In
];
auto
&
d_
reduce_thread_copy_vgpr_to_global
=
dxs_
reduce_thread_copy_vgpr_to_global
(
In
);
auto
&
reduce_thread_copy_vgpr_to_global
=
reduce_
tuple_
thread_copy_vgpr_to_global
(
In
);
using
D
ReduceOperation
=
remove_cvref_t
<
decltype
(
Dxs
ReduceOperation
{}[
In
])
>
;
using
ReduceOperation
=
remove_cvref_t
<
decltype
(
ReduceOperation
s
{}[
In
])
>
;
using
ThreadwiseReduce
=
ThreadwiseReduction
<
FloatReduceAcc
,
decltype
(
c_reduce_thread_desc_mperblock_nperblock
),
decltype
(
d_
reduce_thread_desc_mperblock
),
D
ReduceOperation
,
decltype
(
reduce_thread_desc_mperblock
),
ReduceOperation
,
false
>
;
// Global write Gemm shuffle + reduction
const
auto
d_zero
Val
=
D
ReduceOperation
::
template
GetIdentityValue
<
FloatReduceAcc
>();
const
auto
reduce_identity
Val
=
ReduceOperation
::
template
GetIdentityValue
<
FloatReduceAcc
>();
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
[
&
](
auto
I
)
{
d
_thread_buf
(
I
)
=
d_zero
Val
;
});
[
&
](
auto
I
)
{
reduce
_thread_buf
(
I
)
=
reduce_identity
Val
;
});
// reduce in VGPR
static_for
<
0
,
mreduce_per_thread
,
1
>
{}([
&
](
auto
im
)
{
...
...
@@ -946,26 +947,25 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Number
<
c_reduce_thread_desc_mperblock_nperblock
.
CalculateOffset
(
make_tuple
(
im
,
in
))
>
{};
d
_in_element_op
(
c_reduce_thread_buf
(
offset
),
c_reduce_thread_buf
(
offset
));
reduce
_in_element_op
(
c_reduce_thread_buf
(
offset
),
c_reduce_thread_buf
(
offset
));
});
});
ThreadwiseReduce
::
Reduce
(
c_reduce_thread_buf
,
d
_thread_buf
);
ThreadwiseReduce
::
Reduce
(
c_reduce_thread_buf
,
reduce
_thread_buf
);
// copy from VGPR to Global
d_reduce_thread_copy_vgpr_to_global
.
Run
(
d_reduce_thread_desc_mblock_mperblock
,
make_tuple
(
I0
,
I0
),
d_thread_buf
,
d_grid_desc_mblock_mperblock
,
d_grid_buf
);
reduce_thread_copy_vgpr_to_global
.
Run
(
reduce_thread_desc_mblock_mperblock
,
make_tuple
(
I0
,
I0
),
reduce_thread_buf
,
reduce_grid_desc_mblock_mperblock
,
reduce_grid_buf
);
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
c_global_step
=
sfc_c_global
.
GetForwardStep
(
access_id
);
d_
reduce_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
d
_grid_desc_mblock_mperblock
,
reduce_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
reduce
_grid_desc_mblock_mperblock
,
make_tuple
(
c_global_step
[
I0
],
c_global_step
[
I1
]));
}
});
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment