Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0b11569f
Commit
0b11569f
authored
Jul 01, 2022
by
Chao Liu
Browse files
Merge remote-tracking branch 'origin/develop' into batched_gemm_c_permute
parents
e8d3a0fb
fa9a0a5c
Changes
554
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1303 additions
and
320 deletions
+1303
-320
include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp
include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp
+3
-0
include/ck/tensor_operation/gpu/device/device_gemm_bias_activation.hpp
...nsor_operation/gpu/device/device_gemm_bias_activation.hpp
+3
-0
include/ck/tensor_operation/gpu/device/device_gemm_bias_activation_add.hpp
..._operation/gpu/device/device_gemm_bias_activation_add.hpp
+3
-0
include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
...n/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
+196
-138
include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp
...ensor_operation/gpu/device/device_gemm_bias_c_permute.hpp
+57
-0
include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp
...r_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp
+761
-0
include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp
include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp
+3
-0
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
...ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
+3
-0
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
+3
-0
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
...ude/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
+14
-68
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
..._operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
+169
-95
include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp
...ude/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp
+44
-0
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
+3
-0
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp
...peration/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp
+3
-0
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp
.../gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp
+3
-0
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp
.../device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp
+3
-0
include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp
.../tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp
+3
-0
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
...ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
+5
-2
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
...operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
+21
-17
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+3
-0
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
...
...
include/ck/tensor_operation/gpu/device/device_gemm_bias_activation.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
...
...
include/ck/tensor_operation/gpu/device/device_gemm_bias_activation_add.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef DEVICE_GEMM_BIAS_ACTIVATION_ADD_HPP
#define DEVICE_GEMM_BIAS_ACTIVATION_ADD_HPP
...
...
include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
...
...
@@ -26,20 +29,20 @@ template <typename ALayout,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
C0
DataType
,
typename
C1
DataType
,
typename
Bias
DataType
,
typename
D0
DataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
ReduceAccDataType
,
typename
D
PtrsGlobal
,
typename
Reduce
PtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C1
ElementwiseOperation
,
typename
Dxs
ReduceOperation
,
typename
Dxs
InElementwiseOperation
,
typename
Dxs
ReduceAccElementwiseOperation
,
typename
D
GlobalMemoryDataOperation
,
typename
D0
ElementwiseOperation
,
typename
ReduceOperation
s
,
typename
Reduce
InElementwiseOperation
s
,
typename
ReduceAccElementwiseOperation
s
,
typename
Reduce
GlobalMemoryDataOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
...
...
@@ -74,13 +77,7 @@ template <typename ALayout,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGemmBiasAddReduce_Xdl_CShuffle
:
public
DeviceGemmBiasAddReduce
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
C1ElementwiseOperation
,
DxsInElementwiseOperation
,
DxsReduceAccElementwiseOperation
>
struct
DeviceGemmBiasAddReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
1
,
ReduceOperations
::
Size
()
>
{
using
DeviceOp
=
DeviceGemmBiasAddReduce_Xdl_CShuffle
;
...
...
@@ -353,7 +350,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
}
// assume D is packed tensor
static
auto
Make
D
GridDescriptor_M
(
index_t
MRaw
)
static
auto
Make
Reduce
GridDescriptor_M
(
index_t
MRaw
)
{
const
auto
d_grid_desc_mraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MRaw
));
...
...
@@ -383,7 +380,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
C0GridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
0
));
using
C1GridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
D
GridDesc_M
=
decltype
(
Make
D
GridDescriptor_M
(
1
));
using
Reduce
GridDesc_M
=
decltype
(
Make
Reduce
GridDescriptor_M
(
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
...
...
@@ -391,25 +388,25 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GemmAccDataType
,
CShuffleDataType
,
CDataType
,
C0
DataType
,
C1
DataType
,
Bias
DataType
,
D0
DataType
,
ReduceAccDataType
,
D
PtrsGlobal
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
C1
ElementwiseOperation
,
Dxs
ReduceOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
D0
ElementwiseOperation
,
ReduceOperation
s
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
InMemoryDataOperationEnum
::
Set
,
D
GlobalMemoryDataOperation
,
Reduce
GlobalMemoryDataOperation
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
C0GridDesc_M_N
,
C1GridDesc_M_N
,
D
GridDesc_M
,
Reduce
GridDesc_M
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -452,9 +449,9 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
const
C0
DataType
*
p_
c0
_grid
,
const
C1
DataType
*
p_
c1
_grid
,
D
PtrsGlobal
p_
d
s_grid
,
const
Bias
DataType
*
p_
bias
_grid
,
const
D0
DataType
*
p_
d0
_grid
,
Reduce
PtrsGlobal
p_
reduce
s_grid
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
...
...
@@ -465,32 +462,32 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
C1
ElementwiseOperation
c1
_element_op
,
Dxs
InElementwiseOperation
dxs
_in_element_op
,
Dxs
ReduceAccElementwiseOperation
dxs
_out_element_op
)
D0
ElementwiseOperation
d0
_element_op
,
Reduce
InElementwiseOperation
s
reduce
_in_element_op
s
,
ReduceAccElementwiseOperation
s
reduce
_out_element_op
s
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_
c0
_grid_
{
p_
c0
_grid
},
p_
c1
_grid_
{
p_
c1
_grid
},
p_
d
s_grid_
{
p_
d
s_grid
},
p_
bias
_grid_
{
p_
bias
_grid
},
p_
d0
_grid_
{
p_
d0
_grid
},
p_
reduce
s_grid_
{
p_
reduce
s_grid
},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
)},
c0_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
0
)},
c1_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC1
)},
d
_grid_desc_m_
{
DeviceOp
::
Make
D
GridDescriptor_M
(
MRaw
)},
reduce
_grid_desc_m_
{
DeviceOp
::
Make
Reduce
GridDescriptor_M
(
MRaw
)},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
c0_grid_desc_mblock_mperblock_nblock_nperblock_
{},
c1_grid_desc_mblock_mperblock_nblock_nperblock_
{},
d
_grid_desc_mblock_mperblock_
{},
reduce
_grid_desc_mblock_mperblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
},
c1
_element_op_
{
c1
_element_op
},
dxs
_in_element_op_
{
dxs
_in_element_op
},
dxs
_out_element_op_
{
dxs
_out_element_op
}
d0
_element_op_
{
d0
_element_op
},
reduce
_in_element_op
s
_
{
reduce
_in_element_op
s
},
reduce
_out_element_op
s
_
{
reduce
_out_element_op
s
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
...
...
@@ -509,8 +506,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c1_grid_desc_m_n_
);
d
_grid_desc_mblock_mperblock_
=
GridwiseGemm
::
Make
D
GridDescriptor_MBlock_MPerBlock
(
d
_grid_desc_m_
);
reduce
_grid_desc_mblock_mperblock_
=
GridwiseGemm
::
Make
Reduce
GridDescriptor_MBlock_MPerBlock
(
reduce
_grid_desc_m_
);
}
}
...
...
@@ -518,29 +515,30 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
const
C0
DataType
*
p_
c0
_grid_
;
const
C1
DataType
*
p_
c1
_grid_
;
D
PtrsGlobal
p_
d
s_grid_
;
const
Bias
DataType
*
p_
bias
_grid_
;
const
D0
DataType
*
p_
d0
_grid_
;
Reduce
PtrsGlobal
p_
reduce
s_grid_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
C0GridDesc_M_N
c0_grid_desc_m_n_
;
C1GridDesc_M_N
c1_grid_desc_m_n_
;
D
GridDesc_M
d
_grid_desc_m_
;
Reduce
GridDesc_M
reduce
_grid_desc_m_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c0_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
DGridDescriptor_MBlock_MPerBlock
d_grid_desc_mblock_mperblock_
;
typename
GridwiseGemm
::
ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
C1
ElementwiseOperation
c1
_element_op_
;
Dxs
InElementwiseOperation
dxs
_in_element_op_
;
Dxs
ReduceAccElementwiseOperation
dxs
_out_element_op_
;
D0
ElementwiseOperation
d0
_element_op_
;
Reduce
InElementwiseOperation
s
reduce
_in_element_op
s
_
;
ReduceAccElementwiseOperation
s
reduce
_out_element_op
s
_
;
};
// Invoker
...
...
@@ -571,21 +569,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
C0
DataType
,
C1
DataType
,
D
PtrsGlobal
,
Bias
DataType
,
D0
DataType
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
C1
ElementwiseOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
D0
ElementwiseOperation
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
D
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
Reduce
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
true
>
;
...
...
@@ -598,21 +596,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_
c0
_grid_
,
arg
.
p_
c1
_grid_
,
arg
.
p_
d
s_grid_
,
arg
.
p_
bias
_grid_
,
arg
.
p_
d0
_grid_
,
arg
.
p_
reduce
s_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c1
_element_op_
,
arg
.
dxs
_in_element_op_
,
arg
.
dxs
_out_element_op_
,
arg
.
d0
_element_op_
,
arg
.
reduce
_in_element_op
s
_
,
arg
.
reduce
_out_element_op
s
_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c0_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c1_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
d
_grid_desc_mblock_mperblock_
,
arg
.
reduce
_grid_desc_mblock_mperblock_
,
arg
.
block_2_ctile_map_
);
}
else
...
...
@@ -621,21 +619,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
C0
DataType
,
C1
DataType
,
D
PtrsGlobal
,
Bias
DataType
,
D0
DataType
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
C1
ElementwiseOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
D0
ElementwiseOperation
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
D
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
Reduce
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
false
>
;
...
...
@@ -648,21 +646,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_
c0
_grid_
,
arg
.
p_
c1
_grid_
,
arg
.
p_
d
s_grid_
,
arg
.
p_
bias
_grid_
,
arg
.
p_
d0
_grid_
,
arg
.
p_
reduce
s_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c1
_element_op_
,
arg
.
dxs
_in_element_op_
,
arg
.
dxs
_out_element_op_
,
arg
.
d0
_element_op_
,
arg
.
reduce
_in_element_op
s
_
,
arg
.
reduce
_out_element_op
s
_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c0_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c1_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
d
_grid_desc_mblock_mperblock_
,
arg
.
reduce
_grid_desc_mblock_mperblock_
,
arg
.
block_2_ctile_map_
);
}
...
...
@@ -697,45 +695,76 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
const
C0DataType
*
p_c0
,
const
C1DataType
*
p_c1
,
DPtrsGlobal
p_dxs
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC1
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
C1ElementwiseOperation
c1_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsReduceAccElementwiseOperation
dxs_out_element_op
)
static
constexpr
int
NumReduce
=
ReduceOperations
::
Size
();
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_bias
,
std
::
array
<
const
void
*
,
1
>
p_ds
,
void
*
p_c
,
std
::
array
<
void
*
,
NumReduce
>
p_reduces
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
std
::
array
<
ck
::
index_t
,
1
>
StrideDs
,
std
::
array
<
void
*
,
3
>
gemm_element_ops
,
std
::
array
<
void
*
,
1
>
d_element_ops
,
std
::
array
<
void
*
,
NumReduce
>
reduce_in_element_op
,
std
::
array
<
void
*
,
NumReduce
>
reduce_out_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_c
,
p_c0
,
p_c1
,
p_dxs
,
MRaw
,
NRaw
,
KRaw
,
ReducePtrsGlobal
reduce_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReducePtrsGlobal
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
static_cast
<
T
*>
(
p_reduces
[
I
]);
},
Number
<
NumReduce
>
{});
ReduceInElementwiseOperations
reduce_in_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceInElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_in_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
ReduceAccElementwiseOperations
reduce_out_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceAccElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_out_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
AElementwiseOperation
a_element_op
=
*
(
static_cast
<
AElementwiseOperation
*>
(
gemm_element_ops
[
0
]));
BElementwiseOperation
b_element_op
=
*
(
static_cast
<
BElementwiseOperation
*>
(
gemm_element_ops
[
1
]));
CElementwiseOperation
c_element_op
=
*
(
static_cast
<
CElementwiseOperation
*>
(
gemm_element_ops
[
2
]));
D0ElementwiseOperation
d_element_op
=
*
(
static_cast
<
D0ElementwiseOperation
*>
(
d_element_ops
[
0
]));
return
Argument
{
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
const
BiasDataType
*>
(
p_bias
),
static_cast
<
const
D0DataType
*>
(
p_ds
[
0
]),
reduce_tuple
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
Stride
C1
,
Stride
Ds
[
0
]
,
a_element_op
,
b_element_op
,
c_element_op
,
c1
_element_op
,
dxs
_in_element_op
,
dxs
_out_element_op
};
d
_element_op
,
reduce
_in_element_op
s
,
reduce
_out_element_op
s
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -744,45 +773,74 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_bias
,
std
::
array
<
const
void
*
,
1
>
p_ds
,
void
*
p_c
,
const
void
*
p_c0
,
const
void
*
p_c1
,
void
*
p_dxs
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC1
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
C1ElementwiseOperation
c1_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsReduceAccElementwiseOperation
dxs_out_element_op
,
std
::
array
<
void
*
,
NumReduce
>
p_reduces
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
std
::
array
<
ck
::
index_t
,
1
>
StrideDs
,
std
::
array
<
void
*
,
3
>
gemm_element_ops
,
std
::
array
<
void
*
,
1
>
d_element_ops
,
std
::
array
<
void
*
,
NumReduce
>
reduce_in_element_op
,
std
::
array
<
void
*
,
NumReduce
>
reduce_out_element_op
,
index_t
/* KBatch */
=
1
)
override
{
DPtrsGlobal
dxs_tuple
=
*
(
static_cast
<
DPtrsGlobal
*>
(
p_dxs
));
ReducePtrsGlobal
reduce_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReducePtrsGlobal
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
static_cast
<
T
*>
(
p_reduces
[
I
]);
},
Number
<
NumReduce
>
{});
ReduceInElementwiseOperations
reduce_in_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceInElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_in_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
ReduceAccElementwiseOperations
reduce_out_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceAccElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_out_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
AElementwiseOperation
a_element_op
=
*
(
static_cast
<
AElementwiseOperation
*>
(
gemm_element_ops
[
0
]));
BElementwiseOperation
b_element_op
=
*
(
static_cast
<
BElementwiseOperation
*>
(
gemm_element_ops
[
1
]));
CElementwiseOperation
c_element_op
=
*
(
static_cast
<
CElementwiseOperation
*>
(
gemm_element_ops
[
2
]));
D0ElementwiseOperation
d_element_op
=
*
(
static_cast
<
D0ElementwiseOperation
*>
(
d_element_ops
[
0
]));
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
const
C0
DataType
*>
(
p_
c0
),
static_cast
<
const
C1
DataType
*>
(
p_
c1
),
dxs
_tuple
,
M
Raw
,
N
Raw
,
K
Raw
,
static_cast
<
const
Bias
DataType
*>
(
p_
bias
),
static_cast
<
const
D0
DataType
*>
(
p_
ds
[
0
]
),
reduce
_tuple
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
Stride
C1
,
Stride
Ds
[
0
]
,
a_element_op
,
b_element_op
,
c_element_op
,
c1
_element_op
,
dxs
_in_element_op
,
dxs
_out_element_op
);
d
_element_op
,
reduce
_in_element_op
s
,
reduce
_out_element_op
s
);
}
// polymorphic
...
...
@@ -797,7 +855,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGemmReduce_Xdl_CShuffle"
str
<<
"DeviceGemm
BiasAdd
Reduce_Xdl_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp
0 → 100644
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
struct
DEGridDesc_M0_M1_M2_N0_N1
{
ck
::
index_t
M0_
,
M1_
,
M2_
,
N0_
,
N1_
;
ck
::
index_t
stride_M0_
,
stride_M1_
,
stride_M2_
,
stride_N0_
,
stride_N1_
;
};
// input : A[M, K], B[K, N],
// input : D[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D)
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
struct
DeviceGemmBiasCPermute
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_d
,
void
*
p_e
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
DEGridDesc_M0_M1_M2_N0_N1
d_gride_desc
,
DEGridDesc_M0_M1_M2_N0_N1
e_gride_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGemmBiasCPermutePtr
=
std
::
unique_ptr
<
DeviceGemmBiasCPermute
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp
0 → 100644
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace
ck
{
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatDsPointer
,
typename
FloatE
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2ETileMap
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_bias_c_permute
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatDsPointer
p_ds_grid
,
FloatE
*
__restrict__
p_e_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_ds_grid
,
p_e_grid
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_etile_map
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
block_2_etile_map
;
#endif
}
}
// namespace ck
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// input : A[M, K], or A[K, N]
// input : B[K, N], or A[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template
<
typename
ALayout
,
typename
BLayout
,
typename
CDELayout
,
typename
ADataType
,
typename
BDataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
DDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGemmBiasCPermute_Xdl
:
public
DeviceGemmBiasCPermute
<
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
using
DeviceOp
=
DeviceGemmBiasCPermute_Xdl
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
index_t
NumDTensor
=
I1
;
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
I1
,
StrideA
));
}
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
{
// not pad M or K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
}
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
I1
,
StrideB
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
StrideB
,
I1
));
}
}();
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both N and K
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad N, but not K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
NRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// not pad N or K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
}
static
auto
MakeEGridDescriptor_M_N
(
DEGridDesc_M0_M1_M2_N0_N1
d_e_grid_desc
)
{
index_t
M0
=
d_e_grid_desc
.
M0_
;
index_t
M1
=
d_e_grid_desc
.
M1_
;
index_t
M2
=
d_e_grid_desc
.
M2_
;
index_t
N0
=
d_e_grid_desc
.
N0_
;
index_t
N1
=
d_e_grid_desc
.
N1_
;
index_t
stride_M0
=
d_e_grid_desc
.
stride_M0_
;
index_t
stride_M1
=
d_e_grid_desc
.
stride_M1_
;
index_t
stride_M2
=
d_e_grid_desc
.
stride_M2_
;
index_t
stride_N0
=
d_e_grid_desc
.
stride_N0_
;
index_t
stride_N1
=
d_e_grid_desc
.
stride_N1_
;
const
auto
MRaw
=
M0
*
M1
*
M2
;
const
auto
NRaw
=
N0
*
N1
;
const
auto
c_grid_desc_mraw_nraw
=
[
&
]()
{
const
auto
c_grid_desc_m0_m1_m2_n0_n1
=
make_naive_tensor_descriptor
(
make_tuple
(
M0
,
M1
,
M2
,
N0
,
N1
),
make_tuple
(
stride_M0
,
stride_M1
,
stride_M2
,
stride_N0
,
stride_N1
));
return
transform_tensor_descriptor
(
c_grid_desc_m0_m1_m2_n0_n1
,
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
)),
make_merge_transform
(
make_tuple
(
N0
,
N1
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
NPad
=
N
-
NRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad M, but not N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad N, but not M
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or N
return
c_grid_desc_mraw_nraw
;
}
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
(
DEGridDesc_M0_M1_M2_N0_N1
{}));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
CShuffleDataType
,
ck
::
Tuple
<
DDataType
>
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
EGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
false
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
false
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
void
*
p_a_grid
,
const
void
*
p_b_grid
,
const
void
*
p_d_grid
,
void
*
p_e_grid
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
DEGridDesc_M0_M1_M2_N0_N1
d_grid_desc
,
DEGridDesc_M0_M1_M2_N0_N1
e_grid_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
// FIXME
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
(
e_grid_desc
)},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
{
if
(
MRaw
!=
d_grid_desc
.
M0_
*
d_grid_desc
.
M1_
*
d_grid_desc
.
M2_
)
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
if
(
NRaw
!=
d_grid_desc
.
N0_
*
d_grid_desc
.
N1_
)
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
p_ds_grid_
(
I0
)
=
static_cast
<
const
DDataType
*>
(
p_d_grid
);
const
auto
d_grid_desc_m_n
=
DeviceOp
::
MakeEGridDescriptor_M_N
(
d_grid_desc
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
(
I0
)
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
d_grid_desc_m_n
);
}
}
// private:
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
StaticallyIndexedArray
<
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
// FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N
e_grid_desc_m_n_
;
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
DefaultBlock2ETileMap
block_2_etile_map_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_gemm_bias_c_permute
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
ck
::
StaticallyIndexedArray
<
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2ETileMap
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_etile_map_
);
};
float
ave_time
=
0
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
return
ave_time
;
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_d
,
void
*
p_e
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
DEGridDesc_M0_M1_M2_N0_N1
d_grid_desc
,
DEGridDesc_M0_M1_M2_N0_N1
e_grid_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_d
,
p_e
,
MRaw
,
NRaw
,
KRaw
,
StrideA
,
StrideB
,
d_grid_desc
,
e_grid_desc
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_d
,
void
*
p_e
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
DEGridDesc_M0_M1_M2_N0_N1
d_grid_desc
,
DEGridDesc_M0_M1_M2_N0_N1
e_grid_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_d
,
p_e
,
MRaw
,
NRaw
,
KRaw
,
StrideA
,
StrideB
,
d_grid_desc
,
e_grid_desc
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGemmBiasCPermute_Xdl"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
...
...
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include "device_base.hpp"
...
...
@@ -6,91 +9,34 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsReduceAccElementwiseOperation
>
template
<
ck
::
index_t
NumDTensor
,
ck
::
index_t
NumReduce
>
struct
DeviceGemmReduce
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_bias
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_c
,
void
*
p_dxs
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsReduceAccElementwiseOperation
dxs_out_element_op
,
ck
::
index_t
BatchCount
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsReduceAccElementwiseOperation
>
using
DeviceGemmReducePtr
=
std
::
unique_ptr
<
DeviceGemmReduce
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsReduceAccElementwiseOperation
>>
;
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsReduceAccElementwiseOperation
>
struct
DeviceGemmBiasAddReduce
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
const
void
*
p_c0
,
const
void
*
p_c1
,
void
*
p_dxs
,
std
::
array
<
void
*
,
NumReduce
>
p_reduces
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
ck
::
index_t
StrideC1
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
C1ElementwiseOperation
c1_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsReduceAccElementwiseOperation
dxs_out_element_op
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
std
::
array
<
void
*
,
3
>
gemm_element_ops
,
std
::
array
<
void
*
,
NumDTensor
>
d_element_ops
,
std
::
array
<
void
*
,
NumReduce
>
reduce_in_element_ops
,
std
::
array
<
void
*
,
NumReduce
>
reduce_out_element_ops
,
ck
::
index_t
BatchCount
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsReduceAccElementwiseOperation
>
using
DeviceGemmBiasAddReducePtr
=
std
::
unique_ptr
<
DeviceGemmBiasAddReduce
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
C1ElementwiseOperation
,
DxsInElementwiseOperation
,
DxsReduceAccElementwiseOperation
>>
;
template
<
ck
::
index_t
NumDTensor
,
ck
::
index_t
NumReduce
>
using
DeviceGemmReducePtr
=
std
::
unique_ptr
<
DeviceGemmReduce
<
NumDTensor
,
NumReduce
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
...
...
@@ -29,14 +32,14 @@ template <typename ALayout,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
ReduceAccDataType
,
typename
D
PtrsGlobal
,
typename
Reduce
PtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Dxs
ReduceOperation
,
typename
Dxs
InElementwiseOperation
,
typename
Dxs
ReduceAccElementwiseOperation
,
typename
D
GlobalMemoryDataOperation
,
typename
ReduceOperation
s
,
typename
Reduce
InElementwiseOperation
s
,
typename
ReduceAccElementwiseOperation
s
,
typename
Reduce
GlobalMemoryDataOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
...
...
@@ -71,11 +74,7 @@ template <typename ALayout,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsReduceAccElementwiseOperation
>
struct
DeviceGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
0
,
ReduceOperations
::
Size
()
>
{
using
DeviceOp
=
DeviceGemmReduce_Xdl_CShuffle
;
...
...
@@ -347,8 +346,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
}
}
// assume
D
is packed tensor
static
auto
Make
D
GridDescriptor_M
(
index_t
MRaw
)
// assume
Reduce
is packed tensor
static
auto
Make
Reduce
GridDescriptor_M
(
index_t
MRaw
)
{
const
auto
d_grid_desc_mraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MRaw
));
...
...
@@ -376,7 +375,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
D
GridDesc_M
=
decltype
(
Make
D
GridDescriptor_M
(
1
));
using
Reduce
GridDesc_M
=
decltype
(
Make
Reduce
GridDescriptor_M
(
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
...
...
@@ -385,19 +384,19 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CShuffleDataType
,
CDataType
,
ReduceAccDataType
,
D
PtrsGlobal
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
Dxs
ReduceOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
ReduceOperation
s
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
InMemoryDataOperationEnum
::
Set
,
D
GlobalMemoryDataOperation
,
Reduce
GlobalMemoryDataOperation
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
D
GridDesc_M
,
Reduce
GridDesc_M
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -440,7 +439,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
D
PtrsGlobal
p_
d
s_grid
,
Reduce
PtrsGlobal
p_
reduce
s_grid
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
...
...
@@ -450,24 +449,24 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
Dxs
InElementwiseOperation
dxs
_in_element_op
,
Dxs
ReduceAccElementwiseOperation
dxs
_out_element_op
)
Reduce
InElementwiseOperation
s
reduce
_in_element_op
s
,
ReduceAccElementwiseOperation
s
reduce
_out_element_op
s
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_
d
s_grid_
{
p_
d
s_grid
},
p_
reduce
s_grid_
{
p_
reduce
s_grid
},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
)},
d
_grid_desc_m_
{
DeviceOp
::
Make
D
GridDescriptor_M
(
MRaw
)},
reduce
_grid_desc_m_
{
DeviceOp
::
Make
Reduce
GridDescriptor_M
(
MRaw
)},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
d
_grid_desc_mblock_mperblock_
{},
reduce
_grid_desc_mblock_mperblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
},
dxs
_in_element_op_
{
dxs
_in_element_op
},
dxs
_out_element_op_
{
dxs
_out_element_op
}
reduce
_in_element_op
s
_
{
reduce
_in_element_op
s
},
reduce
_out_element_op
s
_
{
reduce
_out_element_op
s
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
...
...
@@ -478,8 +477,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
d
_grid_desc_mblock_mperblock_
=
GridwiseGemm
::
Make
D
GridDescriptor_MBlock_MPerBlock
(
d
_grid_desc_m_
);
reduce
_grid_desc_mblock_mperblock_
=
GridwiseGemm
::
Make
Reduce
GridDescriptor_MBlock_MPerBlock
(
reduce
_grid_desc_m_
);
}
}
...
...
@@ -487,20 +486,21 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
D
PtrsGlobal
p_
d
s_grid_
;
Reduce
PtrsGlobal
p_
reduce
s_grid_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
D
GridDesc_M
d
_grid_desc_m_
;
Reduce
GridDesc_M
reduce
_grid_desc_m_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
DGridDescriptor_MBlock_MPerBlock
d_grid_desc_mblock_mperblock_
;
typename
GridwiseGemm
::
ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
Dxs
InElementwiseOperation
dxs
_in_element_op_
;
Dxs
ReduceAccElementwiseOperation
dxs
_out_element_op_
;
Reduce
InElementwiseOperation
s
reduce
_in_element_op
s
_
;
ReduceAccElementwiseOperation
s
reduce
_out_element_op
s
_
;
};
// Invoker
...
...
@@ -525,7 +525,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.
d
_grid_desc_m_{ " << arg.
d
_grid_desc_m_.GetLength(I0) << "}"
std::cout << "arg.
reduce
_grid_desc_m_{ " << arg.
reduce
_grid_desc_m_.GetLength(I0) << "}"
<< std::endl;
}
#endif
...
...
@@ -551,16 +551,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
D
PtrsGlobal
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
D
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
Reduce
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
true
>
;
...
...
@@ -573,16 +573,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_
d
s_grid_
,
arg
.
p_
reduce
s_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
dxs
_in_element_op_
,
arg
.
dxs
_out_element_op_
,
arg
.
reduce
_in_element_op
s
_
,
arg
.
reduce
_out_element_op
s
_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
d
_grid_desc_mblock_mperblock_
,
arg
.
reduce
_grid_desc_mblock_mperblock_
,
arg
.
block_2_ctile_map_
);
}
else
...
...
@@ -591,16 +591,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
D
PtrsGlobal
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
D
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
Reduce
GridDescriptor_MBlock_MPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
false
>
;
...
...
@@ -613,16 +613,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_
d
s_grid_
,
arg
.
p_
reduce
s_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
dxs
_in_element_op_
,
arg
.
dxs
_out_element_op_
,
arg
.
reduce
_in_element_op
s
_
,
arg
.
reduce
_out_element_op
s
_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
d
_grid_desc_mblock_mperblock_
,
arg
.
reduce
_grid_desc_mblock_mperblock_
,
arg
.
block_2_ctile_map_
);
}
...
...
@@ -657,37 +657,75 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
DPtrsGlobal
p_dxs
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsReduceAccElementwiseOperation
dxs_out_element_op
)
static
constexpr
int
NumReduce
=
ReduceOperations
::
Size
();
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_bias
,
std
::
array
<
const
void
*
,
0
>
p_ds
,
void
*
p_c
,
std
::
array
<
void
*
,
NumReduce
>
p_reduces
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
std
::
array
<
ck
::
index_t
,
0
>
StrideDs
,
std
::
array
<
void
*
,
3
>
gemm_element_ops
,
std
::
array
<
void
*
,
0
>
d_element_ops
,
std
::
array
<
void
*
,
NumReduce
>
reduce_in_element_op
,
std
::
array
<
void
*
,
NumReduce
>
reduce_out_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_c
,
p_dxs
,
MRaw
,
NRaw
,
KRaw
,
(
void
)
p_bias
;
(
void
)
p_ds
;
(
void
)
StrideDs
;
(
void
)
d_element_ops
;
ReducePtrsGlobal
reduce_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReducePtrsGlobal
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
static_cast
<
T
*>
(
p_reduces
[
I
]);
},
Number
<
NumReduce
>
{});
ReduceInElementwiseOperations
reduce_in_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceInElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_in_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
ReduceAccElementwiseOperations
reduce_out_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceAccElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_out_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
AElementwiseOperation
a_element_op
=
*
(
static_cast
<
AElementwiseOperation
*>
(
gemm_element_ops
[
0
]));
BElementwiseOperation
b_element_op
=
*
(
static_cast
<
BElementwiseOperation
*>
(
gemm_element_ops
[
1
]));
CElementwiseOperation
c_element_op
=
*
(
static_cast
<
CElementwiseOperation
*>
(
gemm_element_ops
[
2
]));
return
Argument
{
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
reduce_tuple
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
,
dxs
_in_element_op
,
dxs
_out_element_op
};
reduce
_in_element_op
s
,
reduce
_out_element_op
s
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -696,37 +734,73 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_bias
,
std
::
array
<
const
void
*
,
0
>
p_ds
,
void
*
p_c
,
void
*
p_dx
s
,
index_t
M
Raw
,
index_t
N
Raw
,
index_t
K
Raw
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b
_element_op
,
CElementwiseOperation
c
_element_op
,
DxsInElementwiseOperation
dxs
_in_element_op
,
DxsReduceAccElementwiseOperation
dxs
_out_element_op
,
index_t
/* KBatch */
=
1
)
override
std
::
array
<
void
*
,
NumReduce
>
p_reduce
s
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
std
::
array
<
ck
::
index_t
,
0
>
StrideDs
,
std
::
array
<
void
*
,
3
>
gemm
_element_op
s
,
std
::
array
<
void
*
,
0
>
d
_element_op
s
,
std
::
array
<
void
*
,
NumReduce
>
reduce
_in_element_op
,
std
::
array
<
void
*
,
NumReduce
>
reduce
_out_element_op
,
ck
::
index_t
=
1
)
override
{
DPtrsGlobal
dxs_tuple
=
*
(
static_cast
<
DPtrsGlobal
*>
(
p_dxs
));
(
void
)
p_bias
;
(
void
)
p_ds
;
(
void
)
StrideDs
;
(
void
)
d_element_ops
;
ReducePtrsGlobal
reduce_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReducePtrsGlobal
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
static_cast
<
T
*>
(
p_reduces
[
I
]);
},
Number
<
NumReduce
>
{});
ReduceInElementwiseOperations
reduce_in_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceInElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_in_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
ReduceAccElementwiseOperations
reduce_out_element_ops
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
tmp
=
ReduceAccElementwiseOperations
{}[
I
];
using
T
=
remove_pointer_t
<
decltype
(
tmp
)
>
;
return
*
(
static_cast
<
T
*>
(
reduce_out_element_op
[
I
]));
},
Number
<
NumReduce
>
{});
AElementwiseOperation
a_element_op
=
*
(
static_cast
<
AElementwiseOperation
*>
(
gemm_element_ops
[
0
]));
BElementwiseOperation
b_element_op
=
*
(
static_cast
<
BElementwiseOperation
*>
(
gemm_element_ops
[
1
]));
CElementwiseOperation
c_element_op
=
*
(
static_cast
<
CElementwiseOperation
*>
(
gemm_element_ops
[
2
]));
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
dxs
_tuple
,
M
Raw
,
N
Raw
,
K
Raw
,
reduce
_tuple
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
,
dxs
_in_element_op
,
dxs
_out_element_op
);
reduce
_in_element_op
s
,
reduce
_out_element_op
s
);
}
// polymorphic
...
...
include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp
0 → 100644
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGemmSplitK
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
KBatch
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGemmSplitKPtr
=
std
::
unique_ptr
<
DeviceGemmSplitK
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
...
...
@@ -7,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm
_splitk
.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp"
#include "ck/device_utility/device_prop.hpp"
...
...
@@ -54,7 +57,7 @@ template <typename ADataType,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
>
struct
DeviceGemmXdlSplitK
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
:
public
DeviceGemm
SplitK
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
...
...
@@ -7,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm
_splitk
.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp"
#include "ck/device_utility/device_prop.hpp"
...
...
@@ -56,7 +59,7 @@ template <typename ADataType,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXDL
>
struct
DeviceGemmXdlSplitKCShuffle
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
:
public
DeviceGemm
SplitK
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -417,21 +420,22 @@ struct DeviceGemmXdlSplitKCShuffle
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
.
GetElementSpaceSize
()
*
sizeof
(
CDataType
)));
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
};
if
(
has_main_k0_block_loop
)
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment