Commit f5de8b57 (gaoqiong/composable_kernel), unverified signature

Merge branch 'develop' into modified_grouped_gemm_addressing_method

Authored by Chao Liu on Jun 30, 2022; committed via GitHub on Jun 30, 2022.
Parents: e83c7061, fa9a0a5c. The merge touches 76 files in total; this view shows 20 changed files with 1772 additions and 576 deletions (+1772 / -576).
Changed files shown (20 of 76):

include/ck/tensor_operation/gpu/device/device_elementwise.hpp                                       +40   -0
include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp                 +193  -138
include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp                               +57   -0
include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp                           +761  -0
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp                                       +11   -68
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp                          +166  -95
include/ck/tensor_operation/gpu/device/device_normalization.hpp                                     +43   -0
include/ck/tensor_operation/gpu/device/device_softmax.hpp                                           +65   -21
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp                           +9    -2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp              +99   -99
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp                       +84   -84
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp                                           +82   -61
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp                         +17   -2
include/ck/utility/math.hpp                                                                         +2    -0
include/ck/utility/reduction_functions_accumulate.hpp                                               +1    -1
library/include/ck/library/host_tensor/host_tensor.hpp                                              +6    -0
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp                     +2    -5
library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp                  +1    -0
library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp            +49   -0
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp   +84   -0
include/ck/tensor_operation/gpu/device/device_elementwise.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
#include <iostream>
#include <vector>

#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <ck::index_t NumInputTensor,
          ck::index_t NumOutputTensor,
          index_t NDim,
          typename ElementwiseFunctor>
struct DeviceElementwise : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::array<const void*, NumInputTensor> p_inputs,
                        std::array<void*, NumOutputTensor> p_outputs,
                        std::vector<index_t> lengths,
                        std::vector<std::vector<index_t>> input_strides,
                        std::vector<std::vector<index_t>> output_strides,
                        ElementwiseFunctor functor) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <ck::index_t NumInputTensor,
          ck::index_t NumOutputTensor,
          index_t NDim,
          typename ElementwiseFunctor>
using DeviceElementwisePtr = std::unique_ptr<
    DeviceElementwise<NumInputTensor, NumOutputTensor, NDim, ElementwiseFunctor>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
```
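For orientation, below is a minimal hedged sketch (not part of the commit) of how a concrete implementation of this interface might be driven from host code, assuming an instance with two inputs and one output. `run_binary_elementwise` and its parameter names are hypothetical; only the `MakeArgumentPointer`/`MakeInvokerPointer` signatures come from the header above, and CK headers are assumed to be included.

```cpp
// Hedged usage sketch for the DeviceElementwise interface declared above.
// Assumes NumInputTensor == 2 and NumOutputTensor == 1 for the chosen instance.
#include <array>
#include <vector>

template <typename DeviceElementwisePtrT, typename Functor>
float run_binary_elementwise(DeviceElementwisePtrT& op_ptr,
                             const void* p_in0, const void* p_in1, void* p_out,
                             std::vector<ck::index_t> lengths,
                             std::vector<std::vector<ck::index_t>> input_strides,
                             std::vector<std::vector<ck::index_t>> output_strides,
                             Functor functor)
{
    // Pack the two inputs and one output the way the interface expects.
    auto arg_ptr = op_ptr->MakeArgumentPointer(std::array<const void*, 2>{p_in0, p_in1},
                                               std::array<void*, 1>{p_out},
                                               lengths, input_strides, output_strides, functor);

    auto invoker_ptr = op_ptr->MakeInvokerPointer();

    // Run returns the elapsed time reported by the invoker (CK convention).
    return invoker_ptr->Run(arg_ptr.get());
}
```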
include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp (+193 / -138)

```diff
@@ -29,20 +29,20 @@ template <typename ALayout,
           typename ADataType,
           typename BDataType,
           typename CDataType,
-          typename C0DataType,
-          typename C1DataType,
+          typename BiasDataType,
+          typename D0DataType,
           typename GemmAccDataType,
           typename CShuffleDataType,
           typename ReduceAccDataType,
-          typename DPtrsGlobal,
+          typename ReducePtrsGlobal,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          typename C1ElementwiseOperation,
-          typename DxsReduceOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation,
-          typename DGlobalMemoryDataOperation,
+          typename D0ElementwiseOperation,
+          typename ReduceOperations,
+          typename ReduceInElementwiseOperations,
+          typename ReduceAccElementwiseOperations,
+          typename ReduceGlobalMemoryDataOperation,
           GemmSpecialization GemmSpec,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
@@ -77,13 +77,7 @@ template <typename ALayout,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
           index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceGemmBiasAddReduce_Xdl_CShuffle
-    : public DeviceGemmBiasAddReduce<AElementwiseOperation,
-                                     BElementwiseOperation,
-                                     CElementwiseOperation,
-                                     C1ElementwiseOperation,
-                                     DxsInElementwiseOperation,
-                                     DxsReduceAccElementwiseOperation>
+struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceOperations::Size()>
 {
     using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle;
@@ -356,7 +350,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
     }
 
     // assume D is packed tensor
-    static auto MakeDGridDescriptor_M(index_t MRaw)
+    static auto MakeReduceGridDescriptor_M(index_t MRaw)
     {
         const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
@@ -386,7 +380,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
     using CGridDesc_M_N  = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
     using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 0));
     using C1GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
-    using DGridDesc_M    = decltype(MakeDGridDescriptor_M(1));
+    using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
 
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
@@ -394,25 +388,25 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
         GemmAccDataType,
         CShuffleDataType,
         CDataType,
-        C0DataType,
-        C1DataType,
+        BiasDataType,
+        D0DataType,
         ReduceAccDataType,
-        DPtrsGlobal,
+        ReducePtrsGlobal,
         AElementwiseOperation,
         BElementwiseOperation,
         CElementwiseOperation,
-        C1ElementwiseOperation,
-        DxsReduceOperation,
-        DxsInElementwiseOperation,
-        DxsReduceAccElementwiseOperation,
+        D0ElementwiseOperation,
+        ReduceOperations,
+        ReduceInElementwiseOperations,
+        ReduceAccElementwiseOperations,
         InMemoryDataOperationEnum::Set,
-        DGlobalMemoryDataOperation,
+        ReduceGlobalMemoryDataOperation,
         AGridDesc_AK0_M_AK1,
         BGridDesc_BK0_N_BK1,
         CGridDesc_M_N,
         C0GridDesc_M_N,
         C1GridDesc_M_N,
-        DGridDesc_M,
+        ReduceGridDesc_M,
         NumGemmKPrefetchStage,
         BlockSize,
         MPerBlock,
@@ -455,9 +449,9 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
         Argument(const ADataType* p_a_grid,
                  const BDataType* p_b_grid,
                  CDataType* p_c_grid,
-                 const C0DataType* p_c0_grid,
-                 const C1DataType* p_c1_grid,
-                 DPtrsGlobal p_ds_grid,
+                 const BiasDataType* p_bias_grid,
+                 const D0DataType* p_d0_grid,
+                 ReducePtrsGlobal p_reduces_grid,
                  index_t MRaw,
                  index_t NRaw,
                  index_t KRaw,
@@ -468,32 +462,32 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op,
-                 C1ElementwiseOperation c1_element_op,
-                 DxsInElementwiseOperation dxs_in_element_op,
-                 DxsReduceAccElementwiseOperation dxs_out_element_op)
+                 D0ElementwiseOperation d0_element_op,
+                 ReduceInElementwiseOperations reduce_in_element_ops,
+                 ReduceAccElementwiseOperations reduce_out_element_ops)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
               p_c_grid_{p_c_grid},
-              p_c0_grid_{p_c0_grid},
-              p_c1_grid_{p_c1_grid},
-              p_ds_grid_{p_ds_grid},
+              p_bias_grid_{p_bias_grid},
+              p_d0_grid_{p_d0_grid},
+              p_reduces_grid_{p_reduces_grid},
               a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
               b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
               c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
               c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)},
               c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)},
-              d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
+              reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
               c_grid_desc_mblock_mperblock_nblock_nperblock_{},
               c0_grid_desc_mblock_mperblock_nblock_nperblock_{},
               c1_grid_desc_mblock_mperblock_nblock_nperblock_{},
-              d_grid_desc_mblock_mperblock_{},
+              reduce_grid_desc_mblock_mperblock_{},
               block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               c_element_op_{c_element_op},
-              c1_element_op_{c1_element_op},
-              dxs_in_element_op_{dxs_in_element_op},
-              dxs_out_element_op_{dxs_out_element_op}
+              d0_element_op_{d0_element_op},
+              reduce_in_element_ops_{reduce_in_element_ops},
+              reduce_out_element_ops_{reduce_out_element_ops}
         {
             if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                            b_grid_desc_bk0_n_bk1_,
@@ -512,8 +506,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
                     GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         c1_grid_desc_m_n_);
 
-                d_grid_desc_mblock_mperblock_ =
-                    GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
+                reduce_grid_desc_mblock_mperblock_ =
+                    GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
             }
         }
@@ -521,29 +515,30 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
-        const C0DataType* p_c0_grid_;
-        const C1DataType* p_c1_grid_;
-        DPtrsGlobal p_ds_grid_;
+        const BiasDataType* p_bias_grid_;
+        const D0DataType* p_d0_grid_;
+        ReducePtrsGlobal p_reduces_grid_;
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
         CGridDesc_M_N c_grid_desc_m_n_;
         C0GridDesc_M_N c0_grid_desc_m_n_;
         C1GridDesc_M_N c1_grid_desc_m_n_;
-        DGridDesc_M d_grid_desc_m_;
+        ReduceGridDesc_M reduce_grid_desc_m_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;
         typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c0_grid_desc_mblock_mperblock_nblock_nperblock_;
         typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c1_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
+        typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock_;
         typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
         CElementwiseOperation c_element_op_;
-        C1ElementwiseOperation c1_element_op_;
-        DxsInElementwiseOperation dxs_in_element_op_;
-        DxsReduceAccElementwiseOperation dxs_out_element_op_;
+        D0ElementwiseOperation d0_element_op_;
+        ReduceInElementwiseOperations reduce_in_element_ops_;
+        ReduceAccElementwiseOperations reduce_out_element_ops_;
     };
 
     // Invoker
@@ -574,21 +569,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype
                     CDataType,
-                    C0DataType,
-                    C1DataType,
-                    DPtrsGlobal,
+                    BiasDataType,
+                    D0DataType,
+                    ReducePtrsGlobal,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    C1ElementwiseOperation,
-                    DxsInElementwiseOperation,
-                    DxsReduceAccElementwiseOperation,
+                    D0ElementwiseOperation,
+                    ReduceInElementwiseOperations,
+                    ReduceAccElementwiseOperations,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                    typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
+                    typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
                     typename GridwiseGemm::DefaultBlock2CTileMap,
                     true>;
@@ -601,21 +596,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.p_c0_grid_,
-                    arg.p_c1_grid_,
-                    arg.p_ds_grid_,
+                    arg.p_bias_grid_,
+                    arg.p_d0_grid_,
+                    arg.p_reduces_grid_,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,
-                    arg.c1_element_op_,
-                    arg.dxs_in_element_op_,
-                    arg.dxs_out_element_op_,
+                    arg.d0_element_op_,
+                    arg.reduce_in_element_ops_,
+                    arg.reduce_out_element_ops_,
                     arg.a_grid_desc_ak0_m_ak1_,
                     arg.b_grid_desc_bk0_n_bk1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                     arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
                     arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
-                    arg.d_grid_desc_mblock_mperblock_,
+                    arg.reduce_grid_desc_mblock_mperblock_,
                     arg.block_2_ctile_map_);
             }
             else
@@ -624,21 +619,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype
                     CDataType,
-                    C0DataType,
-                    C1DataType,
-                    DPtrsGlobal,
+                    BiasDataType,
+                    D0DataType,
+                    ReducePtrsGlobal,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    C1ElementwiseOperation,
-                    DxsInElementwiseOperation,
-                    DxsReduceAccElementwiseOperation,
+                    D0ElementwiseOperation,
+                    ReduceInElementwiseOperations,
+                    ReduceAccElementwiseOperations,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                    typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
+                    typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
                     typename GridwiseGemm::DefaultBlock2CTileMap,
                     false>;
@@ -651,21 +646,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.p_c0_grid_,
-                    arg.p_c1_grid_,
-                    arg.p_ds_grid_,
+                    arg.p_bias_grid_,
+                    arg.p_d0_grid_,
+                    arg.p_reduces_grid_,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,
-                    arg.c1_element_op_,
-                    arg.dxs_in_element_op_,
-                    arg.dxs_out_element_op_,
+                    arg.d0_element_op_,
+                    arg.reduce_in_element_ops_,
+                    arg.reduce_out_element_ops_,
                     arg.a_grid_desc_ak0_m_ak1_,
                     arg.b_grid_desc_bk0_n_bk1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                     arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
                     arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
-                    arg.d_grid_desc_mblock_mperblock_,
+                    arg.reduce_grid_desc_mblock_mperblock_,
                     arg.block_2_ctile_map_);
             }
@@ -700,45 +695,76 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
     }
 
-    static auto MakeArgument(const ADataType* p_a,
-                             const BDataType* p_b,
-                             CDataType* p_c,
-                             const C0DataType* p_c0,
-                             const C1DataType* p_c1,
-                             DPtrsGlobal p_dxs,
-                             index_t MRaw,
-                             index_t NRaw,
-                             index_t KRaw,
-                             index_t StrideA,
-                             index_t StrideB,
-                             index_t StrideC,
-                             index_t StrideC1,
-                             AElementwiseOperation a_element_op,
-                             BElementwiseOperation b_element_op,
-                             CElementwiseOperation c_element_op,
-                             C1ElementwiseOperation c1_element_op,
-                             DxsInElementwiseOperation dxs_in_element_op,
-                             DxsReduceAccElementwiseOperation dxs_out_element_op)
-    {
-        return Argument{p_a,
-                        p_b,
-                        p_c,
-                        p_c0,
-                        p_c1,
-                        p_dxs,
-                        MRaw,
-                        NRaw,
-                        KRaw,
-                        StrideA,
-                        StrideB,
-                        StrideC,
-                        StrideC1,
-                        a_element_op,
-                        b_element_op,
-                        c_element_op,
-                        c1_element_op,
-                        dxs_in_element_op,
-                        dxs_out_element_op};
+    static constexpr int NumReduce = ReduceOperations::Size();
+
+    static auto MakeArgument(const void* p_a,
+                             const void* p_b,
+                             const void* p_bias,
+                             std::array<const void*, 1> p_ds,
+                             void* p_c,
+                             std::array<void*, NumReduce> p_reduces,
+                             ck::index_t M,
+                             ck::index_t N,
+                             ck::index_t K,
+                             ck::index_t StrideA,
+                             ck::index_t StrideB,
+                             ck::index_t StrideC,
+                             std::array<ck::index_t, 1> StrideDs,
+                             std::array<void*, 3> gemm_element_ops,
+                             std::array<void*, 1> d_element_ops,
+                             std::array<void*, NumReduce> reduce_in_element_op,
+                             std::array<void*, NumReduce> reduce_out_element_op)
+    {
+        ReducePtrsGlobal reduce_tuple = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReducePtrsGlobal{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return static_cast<T*>(p_reduces[I]);
+            },
+            Number<NumReduce>{});
+
+        ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceInElementwiseOperations{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_in_element_op[I]));
+            },
+            Number<NumReduce>{});
+
+        ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceAccElementwiseOperations{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_out_element_op[I]));
+            },
+            Number<NumReduce>{});
+
+        AElementwiseOperation a_element_op  = *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
+        BElementwiseOperation b_element_op  = *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
+        CElementwiseOperation c_element_op  = *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
+        D0ElementwiseOperation d_element_op = *(static_cast<D0ElementwiseOperation*>(d_element_ops[0]));
+
+        return Argument{static_cast<const ADataType*>(p_a),
+                        static_cast<const BDataType*>(p_b),
+                        static_cast<CDataType*>(p_c),
+                        static_cast<const BiasDataType*>(p_bias),
+                        static_cast<const D0DataType*>(p_ds[0]),
+                        reduce_tuple,
+                        M,
+                        N,
+                        K,
+                        StrideA,
+                        StrideB,
+                        StrideC,
+                        StrideDs[0],
+                        a_element_op,
+                        b_element_op,
+                        c_element_op,
+                        d_element_op,
+                        reduce_in_element_ops,
+                        reduce_out_element_ops};
     }
 
     static auto MakeInvoker() { return Invoker{}; }
@@ -747,45 +773,74 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const void* p_a,
-                        const void* p_b,
-                        void* p_c,
-                        const void* p_c0,
-                        const void* p_c1,
-                        void* p_dxs,
-                        index_t MRaw,
-                        index_t NRaw,
-                        index_t KRaw,
-                        index_t StrideA,
-                        index_t StrideB,
-                        index_t StrideC,
-                        index_t StrideC1,
-                        AElementwiseOperation a_element_op,
-                        BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        C1ElementwiseOperation c1_element_op,
-                        DxsInElementwiseOperation dxs_in_element_op,
-                        DxsReduceAccElementwiseOperation dxs_out_element_op,
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        const void* p_bias,
+                        std::array<const void*, 1> p_ds,
+                        void* p_c,
+                        std::array<void*, NumReduce> p_reduces,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t StrideA,
+                        ck::index_t StrideB,
+                        ck::index_t StrideC,
+                        std::array<ck::index_t, 1> StrideDs,
+                        std::array<void*, 3> gemm_element_ops,
+                        std::array<void*, 1> d_element_ops,
+                        std::array<void*, NumReduce> reduce_in_element_op,
+                        std::array<void*, NumReduce> reduce_out_element_op,
                         index_t /* KBatch */ = 1) override
     {
-        DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
+        ReducePtrsGlobal reduce_tuple = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReducePtrsGlobal{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return static_cast<T*>(p_reduces[I]);
+            },
+            Number<NumReduce>{});
+
+        ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceInElementwiseOperations{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_in_element_op[I]));
+            },
+            Number<NumReduce>{});
+
+        ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceAccElementwiseOperations{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_out_element_op[I]));
+            },
+            Number<NumReduce>{});
+
+        AElementwiseOperation a_element_op  = *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
+        BElementwiseOperation b_element_op  = *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
+        CElementwiseOperation c_element_op  = *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
+        D0ElementwiseOperation d_element_op = *(static_cast<D0ElementwiseOperation*>(d_element_ops[0]));
+
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
                                           static_cast<CDataType*>(p_c),
-                                          static_cast<const C0DataType*>(p_c0),
-                                          static_cast<const C1DataType*>(p_c1),
-                                          dxs_tuple,
-                                          MRaw,
-                                          NRaw,
-                                          KRaw,
+                                          static_cast<const BiasDataType*>(p_bias),
+                                          static_cast<const D0DataType*>(p_ds[0]),
+                                          reduce_tuple,
+                                          M,
+                                          N,
+                                          K,
                                           StrideA,
                                           StrideB,
                                           StrideC,
-                                          StrideC1,
+                                          StrideDs[0],
                                           a_element_op,
                                           b_element_op,
                                           c_element_op,
-                                          c1_element_op,
-                                          dxs_in_element_op,
-                                          dxs_out_element_op);
+                                          d_element_op,
+                                          reduce_in_element_ops,
+                                          reduce_out_element_ops);
     }
 
     // polymorphic
@@ -800,7 +855,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
         auto str = std::stringstream();
 
         // clang-format off
-        str << "DeviceGemmReduce_Xdl_CShuffle"
+        str << "DeviceGemmBiasAddReduce_Xdl_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
```
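To make the new type-erased calling convention concrete, here is a minimal hedged sketch (not from the commit) of how a caller might pack the element-wise functors and reduction outputs into the `std::array<void*, ...>` parameters that the reworked `MakeArgument` above expects. It assumes a single reduction output (`NumReduce == 1`) and a single D0 tensor; the helper name and template parameters are hypothetical, only the `MakeArgument` signature comes from the diff.

```cpp
// Hedged sketch of the argument-packing convention used by
// DeviceGemmBiasAddReduce_Xdl_CShuffle::MakeArgument above. CK headers assumed.
#include <array>

template <typename DeviceOpT, typename AOp, typename BOp, typename COp, typename D0Op,
          typename ReduceInOp, typename ReduceOutOp>
auto make_gemm_bias_add_reduce_argument(DeviceOpT& device_op,
                                        const void* p_a, const void* p_b, const void* p_bias,
                                        const void* p_d0, void* p_c, void* p_reduce0,
                                        ck::index_t M, ck::index_t N, ck::index_t K,
                                        ck::index_t StrideA, ck::index_t StrideB,
                                        ck::index_t StrideC, ck::index_t StrideD0,
                                        AOp a_op, BOp b_op, COp c_op, D0Op d0_op,
                                        ReduceInOp reduce_in_op, ReduceOutOp reduce_out_op)
{
    // Functors are handed over as type-erased pointers; the device op casts them
    // back to its compile-time operator types internally (see MakeArgument above).
    std::array<void*, 3> gemm_ops   = {&a_op, &b_op, &c_op};
    std::array<void*, 1> d_ops      = {&d0_op};
    std::array<void*, 1> reduce_in  = {&reduce_in_op};
    std::array<void*, 1> reduce_out = {&reduce_out_op};

    return device_op.MakeArgument(p_a, p_b, p_bias,
                                  std::array<const void*, 1>{p_d0},
                                  p_c,
                                  std::array<void*, 1>{p_reduce0},
                                  M, N, K, StrideA, StrideB, StrideC,
                                  std::array<ck::index_t, 1>{StrideD0},
                                  gemm_ops, d_ops, reduce_in, reduce_out);
}
```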
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp → include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp (+57 / -0 net, renamed and rewritten)

```diff
@@ -2,52 +2,55 @@
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 #include <iostream>
+#include <array>
+
 #include "device_base.hpp"
 
 namespace ck {
 namespace tensor_operation {
 namespace device {
 
+struct DEGridDesc_M0_M1_M2_N0_N1
+{
+    ck::index_t M0_, M1_, M2_, N0_, N1_;
+    ck::index_t stride_M0_, stride_M1_, stride_M2_, stride_N0_, stride_N1_;
+};
+
+// input : A[M, K], B[K, N],
+// input : D[M, N], ...
+// output : E[M, N]
+// C = a_op(A) * b_op(B)
+// E = cde_op(C, D)
 template <typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation>
-struct DeviceBatchedGemmReduce : public BaseOperator
+          typename CDEElementwiseOperation>
+struct DeviceGemmBiasCPermute : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
-                        void* p_c,
-                        void* p_dxs,
+                        const void* p_d,
+                        void* p_e,
                         ck::index_t M,
                         ck::index_t N,
                         ck::index_t K,
                         ck::index_t StrideA,
                         ck::index_t StrideB,
-                        ck::index_t StrideC,
+                        DEGridDesc_M0_M1_M2_N0_N1 d_gride_desc,
+                        DEGridDesc_M0_M1_M2_N0_N1 e_gride_desc,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        DxsInElementwiseOperation dxs_in_element_op,
-                        DxsReduceAccElementwiseOperation dxs_out_element_op,
-                        ck::index_t Batch) = 0;
+                        CDEElementwiseOperation cde_element_op) = 0;
 
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
 
 template <typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation>
-using DeviceBatchedGemmReducePtr =
-    std::unique_ptr<DeviceBatchedGemmReduce<AElementwiseOperation,
-                                            BElementwiseOperation,
-                                            CElementwiseOperation,
-                                            DxsInElementwiseOperation,
-                                            DxsReduceAccElementwiseOperation>>;
+          typename CElementwiseOperation>
+using DeviceGemmBiasCPermutePtr = std::unique_ptr<
+    DeviceGemmBiasCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
 
 } // namespace device
 } // namespace tensor_operation
```
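As a concrete illustration (not part of the commit), the new `DEGridDesc_M0_M1_M2_N0_N1` can be filled as follows for an output E of logical shape [M, N] with M = M0*M1*M2 and N = N0*N1 stored in a permuted physical layout. All sizes and strides below are made-up example values; only the struct itself comes from the header above.

```cpp
// Hedged example: logical M = 2*3*4 = 24, N = 5*6 = 30, stored physically in
// (M0, N0, M1, M2, N1) order with N1 innermost. Strides are illustrative only.
inline ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 make_example_e_desc()
{
    ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 desc{};

    desc.M0_ = 2; desc.M1_ = 3; desc.M2_ = 4; // M = 24
    desc.N0_ = 5; desc.N1_ = 6;               // N = 30

    desc.stride_N1_ = 1;             // innermost dimension
    desc.stride_M2_ = 6;             // N1
    desc.stride_M1_ = 6 * 4;         // N1 * M2
    desc.stride_N0_ = 6 * 4 * 3;     // N1 * M2 * M1
    desc.stride_M0_ = 6 * 4 * 3 * 5; // N1 * M2 * M1 * N0

    return desc;
}
```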
include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"

namespace ck {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatDsPointer,
          typename FloatE,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename Block2ETileMap,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_bias_c_permute(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatDsPointer p_ds_grid,
            FloatE* __restrict__ p_e_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
                                                  p_ds_grid,
                                                  p_e_grid,
                                                  p_shared,
                                                  a_element_op,
                                                  b_element_op,
                                                  cde_element_op,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  e_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  block_2_etile_map);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_ds_grid;
    ignore = p_e_grid;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
    ignore = a_grid_desc_ak0_m_ak1;
    ignore = b_grid_desc_bk0_n_bk1;
    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = block_2_etile_map;
#endif
}

} // namespace ck

namespace ck {
namespace tensor_operation {
namespace device {

// input : A[M, K], or A[K, N]
// input : B[K, N], or A[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template <typename ALayout,
          typename BLayout,
          typename CDELayout,
          typename ADataType,
          typename BDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename DDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOperation,
                                                                  BElementwiseOperation,
                                                                  CDEElementwiseOperation>
{
    using DeviceOp = DeviceGemmBiasCPermute_Xdl;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr index_t NumDTensor = I1;

    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));
            }
        }();

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;

        const auto MPad = M - MRaw;
        const auto KPad = K - KRaw;

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            assert(K % AK1 == 0);

            const auto AK0 = K / AK1;

            const auto a_grid_desc_m_k = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M, but not K
            assert(KRaw % AK1 == 0);

            const auto AK0 = KRaw / AK1;

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_right_pad_transform(MRaw, MPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K, but not M
            assert(K % AK1 == 0);

            const auto AK0 = K / AK1;

            const auto a_grid_desc_m_k = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(MRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // not pad M or K
            assert(KRaw % AK1 == 0);

            const auto AK0 = KRaw / AK1;

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(MRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
    }

    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
            }
        }();

        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;

        const auto NPad = N - NRaw;
        const auto KPad = K - KRaw;

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            assert(K % BK1 == 0);

            const auto BK0 = K / BK1;

            const auto b_grid_desc_n_k = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_right_pad_transform(NRaw, NPad), make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N, but not K
            assert(KRaw % BK1 == 0);

            const auto BK0 = KRaw / BK1;

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K, but not N
            assert(K % BK1 == 0);

            const auto BK0 = K / BK1;

            const auto b_grid_desc_n_k = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // not pad N or K
            assert(KRaw % BK1 == 0);

            const auto BK0 = KRaw / BK1;

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
    }

    static auto MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1 d_e_grid_desc)
    {
        index_t M0 = d_e_grid_desc.M0_;
        index_t M1 = d_e_grid_desc.M1_;
        index_t M2 = d_e_grid_desc.M2_;
        index_t N0 = d_e_grid_desc.N0_;
        index_t N1 = d_e_grid_desc.N1_;

        index_t stride_M0 = d_e_grid_desc.stride_M0_;
        index_t stride_M1 = d_e_grid_desc.stride_M1_;
        index_t stride_M2 = d_e_grid_desc.stride_M2_;
        index_t stride_N0 = d_e_grid_desc.stride_N0_;
        index_t stride_N1 = d_e_grid_desc.stride_N1_;

        const auto MRaw = M0 * M1 * M2;
        const auto NRaw = N0 * N1;

        const auto c_grid_desc_mraw_nraw = [&]() {
            const auto c_grid_desc_m0_m1_m2_n0_n1 = make_naive_tensor_descriptor(
                make_tuple(M0, M1, M2, N0, N1),
                make_tuple(stride_M0, stride_M1, stride_M2, stride_N0, stride_N1));

            return transform_tensor_descriptor(
                c_grid_desc_m0_m1_m2_n0_n1,
                make_tuple(make_merge_transform(make_tuple(M0, M1, M2)),
                           make_merge_transform(make_tuple(N0, N1))),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }();

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;

        const auto MPad = M - MRaw;
        const auto NPad = N - NRaw;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M, but not N
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N, but not M
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
    }

    using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
    using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
    using EGridDesc_M_N       = decltype(MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1{}));

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
        ADataType, // TODO: distinguish A/B datatype
        GemmAccDataType,
        CShuffleDataType,
        ck::Tuple<DDataType>,
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
        EGridDesc_M_N,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const void* p_a_grid,
                 const void* p_b_grid,
                 const void* p_d_grid,
                 void* p_e_grid,
                 index_t MRaw,
                 index_t NRaw,
                 index_t KRaw,
                 index_t StrideA,
                 index_t StrideB,
                 DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc,
                 DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
              p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
              p_ds_grid_{}, // FIXME
              p_e_grid_{static_cast<EDataType*>(p_e_grid)},
              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
              b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
              ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
              e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_grid_desc)},
              e_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
            if(MRaw != d_grid_desc.M0_ * d_grid_desc.M1_ * d_grid_desc.M2_)
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            if(NRaw != d_grid_desc.N0_ * d_grid_desc.N1_)
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                           b_grid_desc_bk0_n_bk1_,
                                           e_grid_desc_m_n_,
                                           block_2_etile_map_))
            {
                e_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        e_grid_desc_m_n_);

                p_ds_grid_(I0) = static_cast<const DDataType*>(p_d_grid);

                const auto d_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N(d_grid_desc);

                ds_grid_desc_mblock_mperblock_nblock_nperblock_(I0) =
                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        d_grid_desc_m_n);
            }
        }

        //  private:
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        typename GridwiseGemm::DsGridPointer p_ds_grid_;
        EDataType* p_e_grid_;
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        StaticallyIndexedArray<
            typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
            NumDTensor>
            ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
                                                             // type from E
        EGridDesc_M_N e_grid_desc_m_n_;
        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            e_grid_desc_mblock_mperblock_nblock_nperblock_;
        typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                            arg.b_grid_desc_bk0_n_bk1_,
                                            arg.e_grid_desc_m_n_,
                                            arg.block_2_etile_map_))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);

            const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
                           arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

                const auto kernel = kernel_gemm_bias_c_permute<
                    GridwiseGemm,
                    ADataType, // TODO: distiguish A/B datatype
                    typename GridwiseGemm::DsGridPointer,
                    EDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    DeviceOp::BGridDesc_BK0_N_BK1,
                    ck::StaticallyIndexedArray<
                        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                        NumDTensor>,
                    typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemm::DefaultBlock2ETileMap,
                    has_main_loop>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_ds_grid_,
                                              arg.p_e_grid_,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.cde_element_op_,
                                              arg.a_grid_desc_ak0_m_ak1_,
                                              arg.b_grid_desc_bk0_n_bk1_,
                                              arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.block_2_etile_map_);
            };

            float ave_time = 0;

            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                ave_time = launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                ave_time = launch_kernel(integral_constant<bool, false>{});
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                           arg.b_grid_desc_bk0_n_bk1_,
                                           arg.e_grid_desc_m_n_,
                                           arg.block_2_etile_map_);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const void* p_a,
                             const void* p_b,
                             const void* p_d,
                             void* p_e,
                             index_t MRaw,
                             index_t NRaw,
                             index_t KRaw,
                             index_t StrideA,
                             index_t StrideB,
                             DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc,
                             DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_d,
                        p_e,
                        MRaw,
                        NRaw,
                        KRaw,
                        StrideA,
                        StrideB,
                        d_grid_desc,
                        e_grid_desc,
                        a_element_op,
                        b_element_op,
                        cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      const void* p_d,
                                                      void* p_e,
                                                      index_t MRaw,
                                                      index_t NRaw,
                                                      index_t KRaw,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc,
                                                      DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CDEElementwiseOperation cde_element_op) override
    {
        return std::make_unique<Argument>(p_a,
                                          p_b,
                                          p_d,
                                          p_e,
                                          MRaw,
                                          NRaw,
                                          KRaw,
                                          StrideA,
                                          StrideB,
                                          d_grid_desc,
                                          e_grid_desc,
                                          a_element_op,
                                          b_element_op,
                                          cde_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGemmBiasCPermute_Xdl"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
```
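Below is a short, hedged usage sketch for the device op defined above. `DeviceOpInstance` stands for some fully specialized `DeviceGemmBiasCPermute_Xdl<...>` (the long template argument list is omitted); the pointers are assumed to be valid device buffers of the matching data types, and CK headers plus `<stdexcept>` are assumed to be included. Only the `MakeArgument`/`MakeInvoker`/`IsSupportedArgument`/`Run` calls come from the file above.

```cpp
// Hedged sketch: exercising DeviceGemmBiasCPermute_Xdl through its non-virtual API.
template <typename DeviceOpInstance, typename AOp, typename BOp, typename CDEOp>
float run_gemm_bias_c_permute(const void* p_a, const void* p_b, const void* p_d, void* p_e,
                              ck::index_t M, ck::index_t N, ck::index_t K,
                              ck::index_t StrideA, ck::index_t StrideB,
                              ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 d_desc,
                              ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 e_desc,
                              AOp a_op, BOp b_op, CDEOp cde_op)
{
    auto device_op = DeviceOpInstance{};
    auto invoker   = device_op.MakeInvoker();

    auto argument = device_op.MakeArgument(p_a, p_b, p_d, p_e, M, N, K, StrideA, StrideB,
                                           d_desc, e_desc, a_op, b_op, cde_op);

    if(!DeviceOpInstance::IsSupportedArgument(argument))
    {
        throw std::runtime_error("unsupported problem size / layout for this instance");
    }

    // Run returns the elapsed time reported by launch_and_time_kernel.
    return invoker.Run(argument, StreamConfig{});
}
```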
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp (+11 / -68)

```diff
@@ -9,91 +9,34 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation>
+template <ck::index_t NumDTensor, ck::index_t NumReduce>
 struct DeviceGemmReduce : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
+                        const void* p_bias,
+                        std::array<const void*, NumDTensor> p_ds,
                         void* p_c,
-                        void* p_dxs,
+                        std::array<void*, NumReduce> p_reduces,
                         ck::index_t M,
                         ck::index_t N,
                         ck::index_t K,
                         ck::index_t StrideA,
                         ck::index_t StrideB,
                         ck::index_t StrideC,
-                        AElementwiseOperation a_element_op,
-                        BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        DxsInElementwiseOperation dxs_in_element_op,
-                        DxsReduceAccElementwiseOperation dxs_out_element_op,
+                        std::array<ck::index_t, NumDTensor> StrideDs,
+                        std::array<void*, 3> gemm_element_ops,
+                        std::array<void*, NumDTensor> d_element_ops,
+                        std::array<void*, NumReduce> reduce_in_element_ops,
+                        std::array<void*, NumReduce> reduce_out_element_ops,
                         ck::index_t BatchCount = 1) = 0;
 
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
 
-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation>
-using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
-                                                             BElementwiseOperation,
-                                                             CElementwiseOperation,
-                                                             DxsInElementwiseOperation,
-                                                             DxsReduceAccElementwiseOperation>>;
-
-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          typename C1ElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation>
-struct DeviceGemmBiasAddReduce : public BaseOperator
-{
-    virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const void* p_a,
-                        const void* p_b,
-                        void* p_c,
-                        const void* p_c0,
-                        const void* p_c1,
-                        void* p_dxs,
-                        ck::index_t M,
-                        ck::index_t N,
-                        ck::index_t K,
-                        ck::index_t StrideA,
-                        ck::index_t StrideB,
-                        ck::index_t StrideC,
-                        ck::index_t StrideC1,
-                        AElementwiseOperation a_element_op,
-                        BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        C1ElementwiseOperation c1_element_op,
-                        DxsInElementwiseOperation dxs_in_element_op,
-                        DxsReduceAccElementwiseOperation dxs_out_element_op,
-                        ck::index_t BatchCount = 1) = 0;
-
-    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
-};
-
-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          typename C1ElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation>
-using DeviceGemmBiasAddReducePtr =
-    std::unique_ptr<DeviceGemmBiasAddReduce<AElementwiseOperation,
-                                            BElementwiseOperation,
-                                            CElementwiseOperation,
-                                            C1ElementwiseOperation,
-                                            DxsInElementwiseOperation,
-                                            DxsReduceAccElementwiseOperation>>;
+template <ck::index_t NumDTensor, ck::index_t NumReduce>
+using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<NumDTensor, NumReduce>>;
 
 } // namespace device
 } // namespace tensor_operation
```
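A brief hedged illustration of what the slimmed-down, type-erased interface buys: the opaque handle now depends only on the two integer parameters, so heterogeneous GEMM+reduce instances can be held behind one pointer type (for example, a GEMM that emits mean and mean-of-squares, matching the device_gemm_mean_squaremean_instance header added in this merge). The helper below is hypothetical; it assumes `NumDTensor == 0`, `NumReduce == 2`, and that passing a null `p_bias` is acceptable for a plain GEMM+reduce instance. Only the `MakeArgumentPointer` signature comes from the interface above.

```cpp
// Hedged sketch: the type-erased handle for a GEMM with two reduction outputs
// and no extra D tensors. CK headers and <array> are assumed to be included.
using GemmMeanSquareMeanPtr = ck::tensor_operation::device::DeviceGemmReducePtr<0, 2>;

inline auto make_gemm_reduce_argument(GemmMeanSquareMeanPtr& op,
                                      const void* p_a, const void* p_b, void* p_c,
                                      std::array<void*, 2> p_reduces,
                                      ck::index_t M, ck::index_t N, ck::index_t K,
                                      ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
                                      std::array<void*, 3> gemm_ops,
                                      std::array<void*, 2> reduce_in_ops,
                                      std::array<void*, 2> reduce_out_ops)
{
    return op->MakeArgumentPointer(p_a, p_b,
                                   /* p_bias (assumed unused here) */ nullptr,
                                   /* p_ds   */ std::array<const void*, 0>{},
                                   p_c, p_reduces,
                                   M, N, K, StrideA, StrideB, StrideC,
                                   /* StrideDs      */ std::array<ck::index_t, 0>{},
                                   gemm_ops,
                                   /* d_element_ops */ std::array<void*, 0>{},
                                   reduce_in_ops, reduce_out_ops);
}
```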
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
View file @
f5de8b57
...
...
@@ -32,14 +32,14 @@ template <typename ALayout,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
ReduceAccDataType
,
typename
D
PtrsGlobal
,
typename
Reduce
PtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Dxs
ReduceOperation
,
typename
Dxs
InElementwiseOperation
,
typename
Dxs
ReduceAccElementwiseOperation
,
typename
D
GlobalMemoryDataOperation
,
typename
ReduceOperation
s
,
typename
Reduce
InElementwiseOperation
s
,
typename
ReduceAccElementwiseOperation
s
,
typename
Reduce
GlobalMemoryDataOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
...
...
@@ -74,11 +74,7 @@ template <typename ALayout,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsReduceAccElementwiseOperation
>
struct
DeviceGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
0
,
ReduceOperations
::
Size
()
>
{
using
DeviceOp
=
DeviceGemmReduce_Xdl_CShuffle
;
...
...
@@ -350,8 +346,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
}
}
// assume
D
is packed tensor
static
auto
Make
D
GridDescriptor_M
(
index_t
MRaw
)
// assume
Reduce
is packed tensor
static
auto
Make
Reduce
GridDescriptor_M
(
index_t
MRaw
)
{
const
auto
d_grid_desc_mraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MRaw
));
...
...
@@ -379,7 +375,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
D
GridDesc_M
=
decltype
(
Make
D
GridDescriptor_M
(
1
));
using
Reduce
GridDesc_M
=
decltype
(
Make
Reduce
GridDescriptor_M
(
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
...
...
@@ -388,19 +384,19 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CShuffleDataType
,
CDataType
,
ReduceAccDataType
,
D
PtrsGlobal
,
Reduce
PtrsGlobal
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
Dxs
ReduceOperation
,
Dxs
InElementwiseOperation
,
Dxs
ReduceAccElementwiseOperation
,
ReduceOperation
s
,
Reduce
InElementwiseOperation
s
,
ReduceAccElementwiseOperation
s
,
InMemoryDataOperationEnum
::
Set
,
D
GlobalMemoryDataOperation
,
Reduce
GlobalMemoryDataOperation
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
D
GridDesc_M
,
Reduce
GridDesc_M
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -443,7 +439,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
D
PtrsGlobal
p_
d
s_grid
,
Reduce
PtrsGlobal
p_
reduce
s_grid
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
...
...
@@ -453,24 +449,24 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
Dxs
InElementwiseOperation
dxs
_in_element_op
,
Dxs
ReduceAccElementwiseOperation
dxs
_out_element_op
)
Reduce
InElementwiseOperation
s
reduce
_in_element_op
s
,
ReduceAccElementwiseOperation
s
reduce
_out_element_op
s
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_
d
s_grid_
{
p_
d
s_grid
},
p_
reduce
s_grid_
{
p_
reduce
s_grid
},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
)},
d
_grid_desc_m_
{
DeviceOp
::
Make
D
GridDescriptor_M
(
MRaw
)},
reduce
_grid_desc_m_
{
DeviceOp
::
Make
Reduce
GridDescriptor_M
(
MRaw
)},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
d
_grid_desc_mblock_mperblock_
{},
reduce
_grid_desc_mblock_mperblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
},
dxs
_in_element_op_
{
dxs
_in_element_op
},
dxs
_out_element_op_
{
dxs
_out_element_op
}
reduce
_in_element_op
s
_
{
reduce
_in_element_op
s
},
reduce
_out_element_op
s
_
{
reduce
_out_element_op
s
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
...
...
@@ -481,8 +477,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                     GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         c_grid_desc_m_n_);

-                d_grid_desc_mblock_mperblock_ =
-                    GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
+                reduce_grid_desc_mblock_mperblock_ =
+                    GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
             }
         }
...
@@ -490,20 +486,21 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
-        DPtrsGlobal p_ds_grid_;
+        ReducePtrsGlobal p_reduces_grid_;
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
         CGridDesc_M_N c_grid_desc_m_n_;
-        DGridDesc_M d_grid_desc_m_;
+        ReduceGridDesc_M reduce_grid_desc_m_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
+        typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
+            reduce_grid_desc_mblock_mperblock_;
         typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
         CElementwiseOperation c_element_op_;
-        DxsInElementwiseOperation dxs_in_element_op_;
-        DxsReduceAccElementwiseOperation dxs_out_element_op_;
+        ReduceInElementwiseOperations reduce_in_element_ops_;
+        ReduceAccElementwiseOperations reduce_out_element_ops_;
     };

     // Invoker
...
@@ -528,7 +525,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
             std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                       << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;

-            std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}"
+            std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}"
                       << std::endl;
         }
 #endif
...
@@ -554,16 +551,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype
                     CDataType,
-                    DPtrsGlobal,
+                    ReducePtrsGlobal,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    DxsInElementwiseOperation,
-                    DxsReduceAccElementwiseOperation,
+                    ReduceInElementwiseOperations,
+                    ReduceAccElementwiseOperations,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                    typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
+                    typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
                     typename GridwiseGemm::DefaultBlock2CTileMap,
                     true>;
...
@@ -576,16 +573,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.p_ds_grid_,
+                    arg.p_reduces_grid_,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,
-                    arg.dxs_in_element_op_,
-                    arg.dxs_out_element_op_,
+                    arg.reduce_in_element_ops_,
+                    arg.reduce_out_element_ops_,
                     arg.a_grid_desc_ak0_m_ak1_,
                     arg.b_grid_desc_bk0_n_bk1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                    arg.d_grid_desc_mblock_mperblock_,
+                    arg.reduce_grid_desc_mblock_mperblock_,
                     arg.block_2_ctile_map_);
             }
             else
...
@@ -594,16 +591,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype
                     CDataType,
-                    DPtrsGlobal,
+                    ReducePtrsGlobal,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    DxsInElementwiseOperation,
-                    DxsReduceAccElementwiseOperation,
+                    ReduceInElementwiseOperations,
+                    ReduceAccElementwiseOperations,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                    typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
+                    typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
                     typename GridwiseGemm::DefaultBlock2CTileMap,
                     false>;
...
@@ -616,16 +613,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.p_ds_grid_,
+                    arg.p_reduces_grid_,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,
-                    arg.dxs_in_element_op_,
-                    arg.dxs_out_element_op_,
+                    arg.reduce_in_element_ops_,
+                    arg.reduce_out_element_ops_,
                     arg.a_grid_desc_ak0_m_ak1_,
                     arg.b_grid_desc_bk0_n_bk1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                    arg.d_grid_desc_mblock_mperblock_,
+                    arg.reduce_grid_desc_mblock_mperblock_,
                     arg.block_2_ctile_map_);
             }
@@ -660,37 +657,75 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
     }

-    static auto MakeArgument(const ADataType* p_a,
-                             const BDataType* p_b,
-                             CDataType* p_c,
-                             DPtrsGlobal p_dxs,
-                             index_t MRaw,
-                             index_t NRaw,
-                             index_t KRaw,
-                             index_t StrideA,
-                             index_t StrideB,
-                             index_t StrideC,
-                             AElementwiseOperation a_element_op,
-                             BElementwiseOperation b_element_op,
-                             CElementwiseOperation c_element_op,
-                             DxsInElementwiseOperation dxs_in_element_op,
-                             DxsReduceAccElementwiseOperation dxs_out_element_op)
+    static constexpr int NumReduce = ReduceOperations::Size();
+    static auto MakeArgument(const void* p_a,
+                             const void* p_b,
+                             const void* p_bias,
+                             std::array<const void*, 0> p_ds,
+                             void* p_c,
+                             std::array<void*, NumReduce> p_reduces,
+                             ck::index_t M,
+                             ck::index_t N,
+                             ck::index_t K,
+                             ck::index_t StrideA,
+                             ck::index_t StrideB,
+                             ck::index_t StrideC,
+                             std::array<ck::index_t, 0> StrideDs,
+                             std::array<void*, 3> gemm_element_ops,
+                             std::array<void*, 0> d_element_ops,
+                             std::array<void*, NumReduce> reduce_in_element_op,
+                             std::array<void*, NumReduce> reduce_out_element_op)
     {
-        return Argument{p_a,
-                        p_b,
-                        p_c,
-                        p_dxs,
-                        MRaw,
-                        NRaw,
-                        KRaw,
+        (void)p_bias;
+        (void)p_ds;
+        (void)StrideDs;
+        (void)d_element_ops;
+
+        ReducePtrsGlobal reduce_tuple = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReducePtrsGlobal{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return static_cast<T*>(p_reduces[I]);
+            },
+            Number<NumReduce>{});
+
+        ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceInElementwiseOperations{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_in_element_op[I]));
+            },
+            Number<NumReduce>{});
+
+        ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceAccElementwiseOperations{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_out_element_op[I]));
+            },
+            Number<NumReduce>{});
+
+        AElementwiseOperation a_element_op =
+            *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
+        BElementwiseOperation b_element_op =
+            *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
+        CElementwiseOperation c_element_op =
+            *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
+
+        return Argument{static_cast<const ADataType*>(p_a),
+                        static_cast<const BDataType*>(p_b),
+                        static_cast<CDataType*>(p_c),
+                        reduce_tuple,
+                        M,
+                        N,
+                        K,
                         StrideA,
                         StrideB,
                         StrideC,
                         a_element_op,
                         b_element_op,
                         c_element_op,
-                        dxs_in_element_op,
-                        dxs_out_element_op};
+                        reduce_in_element_ops,
+                        reduce_out_element_ops};
     }

     static auto MakeInvoker() { return Invoker{}; }
...
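Note on the pattern above: the type-erased std::array<void*, NumReduce> arguments are turned back into typed tuples by indexing a default-constructed tuple to learn the element type of each slot. A minimal stand-alone C++17 sketch of the same idea, outside CK (the Square/Identity operator types here are hypothetical placeholders, not types from this header):

// Standalone sketch (not CK code): rebuild a typed std::tuple of operator objects
// from a type-erased array of void*, mirroring the generate_tuple calls above.
#include <array>
#include <cassert>
#include <tuple>
#include <utility>

struct Square  { float operator()(float x) const { return x * x; } }; // hypothetical op
struct Identity { float operator()(float x) const { return x; } };    // hypothetical op

using OpsTuple = std::tuple<Square, Identity>;

template <std::size_t... Is>
OpsTuple make_ops_impl(const std::array<void*, sizeof...(Is)>& erased, std::index_sequence<Is...>)
{
    // For element I, cast the void* back to the type slot I of OpsTuple expects.
    return OpsTuple{*static_cast<std::tuple_element_t<Is, OpsTuple>*>(erased[Is])...};
}

OpsTuple make_ops(const std::array<void*, std::tuple_size_v<OpsTuple>>& erased)
{
    return make_ops_impl(erased, std::make_index_sequence<std::tuple_size_v<OpsTuple>>{});
}

int main()
{
    Square sq;
    Identity id;
    std::array<void*, 2> erased{&sq, &id};
    OpsTuple ops = make_ops(erased);
    assert(std::get<0>(ops)(3.0f) == 9.0f);
}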
@@ -699,37 +734,73 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
     std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                       const void* p_b,
+                                                      const void* p_bias,
+                                                      std::array<const void*, 0> p_ds,
                                                       void* p_c,
-                                                      void* p_dxs,
-                                                      index_t MRaw,
-                                                      index_t NRaw,
-                                                      index_t KRaw,
-                                                      index_t StrideA,
-                                                      index_t StrideB,
-                                                      index_t StrideC,
-                                                      AElementwiseOperation a_element_op,
-                                                      BElementwiseOperation b_element_op,
-                                                      CElementwiseOperation c_element_op,
-                                                      DxsInElementwiseOperation dxs_in_element_op,
-                                                      DxsReduceAccElementwiseOperation dxs_out_element_op,
-                                                      index_t /* KBatch */ = 1) override
+                                                      std::array<void*, NumReduce> p_reduces,
+                                                      ck::index_t M,
+                                                      ck::index_t N,
+                                                      ck::index_t K,
+                                                      ck::index_t StrideA,
+                                                      ck::index_t StrideB,
+                                                      ck::index_t StrideC,
+                                                      std::array<ck::index_t, 0> StrideDs,
+                                                      std::array<void*, 3> gemm_element_ops,
+                                                      std::array<void*, 0> d_element_ops,
+                                                      std::array<void*, NumReduce> reduce_in_element_op,
+                                                      std::array<void*, NumReduce> reduce_out_element_op,
+                                                      ck::index_t = 1) override
     {
-        DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
+        (void)p_bias;
+        (void)p_ds;
+        (void)StrideDs;
+        (void)d_element_ops;
+
+        ReducePtrsGlobal reduce_tuple = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReducePtrsGlobal{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return static_cast<T*>(p_reduces[I]);
+            },
+            Number<NumReduce>{});
+
+        ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceInElementwiseOperations{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_in_element_op[I]));
+            },
+            Number<NumReduce>{});
+
+        ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceAccElementwiseOperations{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_out_element_op[I]));
+            },
+            Number<NumReduce>{});
+
+        AElementwiseOperation a_element_op =
+            *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
+        BElementwiseOperation b_element_op =
+            *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
+        CElementwiseOperation c_element_op =
+            *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
+
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
                                           static_cast<CDataType*>(p_c),
-                                          dxs_tuple,
-                                          MRaw,
-                                          NRaw,
-                                          KRaw,
+                                          reduce_tuple,
+                                          M,
+                                          N,
+                                          K,
                                           StrideA,
                                           StrideB,
                                           StrideC,
                                           a_element_op,
                                           b_element_op,
                                           c_element_op,
-                                          dxs_in_element_op,
-                                          dxs_out_element_op);
+                                          reduce_in_element_ops,
+                                          reduce_out_element_ops);
     }

     // polymorphic
...
include/ck/tensor_operation/gpu/device/device_normalization.hpp  0 → 100644 (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

struct DeviceNormalization : public BaseOperator
{
    // inLengths: input tensor extent(s) from high to low dimension
    // inStrides: input tensor stride(s) from high to low dimension
    // reduceDims: the dimension(s) the normalization operation is applied
    // alpha: typeless pointer in host memory storing the alpha scaling value of type AccDataType
    // beta: typeless pointer in host memory storing the beta scaling value of type AccDataType
    // in_dev: typeless const pointer in device memory storing the input tensor
    // out_dev: typeless pointer in device memory storing the output tensor
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
                                                              const std::vector<index_t> inStrides,
                                                              const std::vector<int> reduceDims,
                                                              const void* alpha,
                                                              const void* beta,
                                                              const void* in_dev,
                                                              void* out_dev) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;

    virtual index_t GetRank() const = 0;

    virtual index_t GetNumReduceDim() const = 0;
};

using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
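A caller would normally drive any DeviceNormalization-derived operator purely through this type-erased interface. The following is a hedged sketch only: it assumes `op` was obtained from an instance factory, that the device pointers describe a row-major M x K tensor, and that the invoker exposes the usual Run(const BaseArgument*) entry point (any extra defaulted stream/timing parameters vary between CK versions):

// Sketch: launch a normalization op through the DeviceNormalization interface.
#include <vector>

#include "ck/tensor_operation/gpu/device/device_normalization.hpp"

void run_normalization(ck::tensor_operation::device::DeviceNormalizationPtr& op,
                       const void* p_in,  // device memory, row-major M x K
                       void* p_out,       // device memory, same shape
                       ck::index_t M,
                       ck::index_t K)
{
    float alpha = 1.0f; // scaling values live in host memory, passed as typeless pointers
    float beta  = 0.0f;

    std::vector<ck::index_t> lengths{M, K};
    std::vector<ck::index_t> strides{K, 1};
    std::vector<int> reduce_dims{1}; // normalize along the last (K) dimension

    auto arg     = op->MakeArgumentPointer(lengths, strides, reduce_dims, &alpha, &beta, p_in, p_out);
    auto invoker = op->MakeInvokerPointer();

    invoker->Run(arg.get()); // assumption: defaulted extra arguments cover stream config/timing
}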
include/ck/tensor_operation/gpu/device/device_softmax.hpp
@@ -9,6 +9,7 @@
 #include "ck/utility/reduction_operator.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
...
@@ -33,8 +34,15 @@ template <typename InDataType,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
           index_t OutDstVectorSize>
-struct DeviceSoftmax : public BaseOperator
+struct DeviceSoftmax : public DeviceNormalization
 {
+    static constexpr index_t kRank         = Rank;
+    static constexpr index_t kNumReduceDim = NumReduceDim;
+
+    virtual index_t GetRank() const override { return kRank; }
+
+    virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }
+
     using PassThrough = tensor_operation::element_wise::PassThrough;

     // Used for freeloading of some handy functions from DeviceReduceMultiBlock
...
@@ -61,18 +69,33 @@ struct DeviceSoftmax : public BaseOperator
     using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));

-    using GridwiseReduce = GridwiseSoftmax_mk_to_mk<InDataType,
-                                                    OutDataType,
-                                                    AccDataType,
-                                                    GridDesc_M_K,
-                                                    BlockSize,
-                                                    MThreadClusterSize,
-                                                    KThreadClusterSize,
-                                                    MThreadSliceSize,
-                                                    KThreadSliceSize,
-                                                    InSrcVectorDim,
-                                                    InSrcVectorSize,
-                                                    OutDstVectorSize>;
+    using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType,
+                                                            OutDataType,
+                                                            AccDataType,
+                                                            GridDesc_M_K,
+                                                            BlockSize,
+                                                            MThreadClusterSize,
+                                                            KThreadClusterSize,
+                                                            MThreadSliceSize,
+                                                            KThreadSliceSize,
+                                                            InSrcVectorDim,
+                                                            InSrcVectorSize,
+                                                            OutDstVectorSize,
+                                                            false>;
+
+    using GridwiseSoftmaxSweepOnce = GridwiseSoftmax_mk_to_mk<InDataType,
+                                                              OutDataType,
+                                                              AccDataType,
+                                                              GridDesc_M_K,
+                                                              BlockSize,
+                                                              MThreadClusterSize,
+                                                              KThreadClusterSize,
+                                                              MThreadSliceSize,
+                                                              KThreadSliceSize,
+                                                              InSrcVectorDim,
+                                                              InSrcVectorSize,
+                                                              OutDstVectorSize,
+                                                              true>;

     struct Argument : public Reduction::Argument
     {
...
@@ -121,8 +144,19 @@ struct DeviceSoftmax : public BaseOperator
             const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
                 arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);

-            const auto kernel_main =
-                kernel_softmax<GridwiseReduce, InDataType, OutDataType, AccDataType, GridDesc_M_K>;
+            bool sweep_once =
+                in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
+
+            const auto kernel_main = sweep_once ? kernel_softmax<GridwiseSoftmaxSweepOnce,
+                                                                 InDataType,
+                                                                 OutDataType,
+                                                                 AccDataType,
+                                                                 GridDesc_M_K>
+                                                : kernel_softmax<GridwiseSoftmaxGeneric,
+                                                                 InDataType,
+                                                                 OutDataType,
+                                                                 AccDataType,
+                                                                 GridDesc_M_K>;

             float avg_time = 0;
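For reference, the dispatch above comes down to a single check: one block tile along K covers KThreadClusterSize * KThreadSliceSize elements, and the single-sweep kernel is only valid when the whole reduce length fits inside it. A tiny illustrative helper (not part of this header) expressing that check:

// Illustrative only: the kernel is compiled in two flavors and the host picks one
// based on whether a single K tile (cluster size x slice size) covers the reduce length.
#include <cstdint>

constexpr bool fits_single_sweep(std::int64_t reduce_length,
                                 std::int64_t k_thread_cluster_size,
                                 std::int64_t k_thread_slice_size)
{
    return reduce_length <= k_thread_cluster_size * k_thread_slice_size;
}

static_assert(fits_single_sweep(1024, 64, 16));  // 64 * 16 = 1024 -> one sweep
static_assert(!fits_single_sweep(2048, 64, 16)); // needs multiple K tiles

int main() {}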
@@ -167,24 +201,34 @@ struct DeviceSoftmax : public BaseOperator
         return true;
     };

+    // inLengths: input tensor extent(s) from high to low dimension
+    // inStrides: input tensor stride(s) from high to low dimension
+    // reduceDims: the dimension(s) the softmax normalization operate on
+    // alpha: typeless pointer in host memory storing the alpha scaling value as type AccDataType
+    // beta: typeless pointer in host memory storing the beta scaling value as type AccDataType
+    // in_dev: typeless const pointer in device memory storing the input tensor
+    // out_dev: typeless pointer in device memory storing the output tensor
     std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
                                                       const std::vector<index_t> inStrides,
                                                       const std::vector<int> reduceDims,
-                                                      AccDataType alpha,
-                                                      AccDataType beta,
+                                                      const void* alpha,
+                                                      const void* beta,
                                                       const void* in_dev,
-                                                      void* out_dev)
+                                                      void* out_dev) override
     {
         return std::make_unique<Argument>(inLengths,
                                           inStrides,
                                           reduceDims,
-                                          alpha,
-                                          beta,
+                                          *static_cast<const AccDataType*>(alpha),
+                                          *static_cast<const AccDataType*>(beta),
                                           static_cast<const InDataType*>(in_dev),
                                           static_cast<OutDataType*>(out_dev));
     };

-    std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); };
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override { return std::make_unique<Invoker>(); };

     std::string GetTypeString() const override
     {
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
@@ -11,8 +11,8 @@ namespace element_wise {
 struct Add
 {
-    template <typename T>
-    __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
+    template <typename Y, typename X0, typename X1>
+    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;

     template <>
     __host__ __device__ constexpr void
...
@@ -28,6 +28,13 @@ struct Add
         y = x0 + x1;
     };

+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
+    {
+        y = type_convert<half_t>(x0) + x1;
+    };
+
     // Question: should half_t be supported ?
     template <>
     __host__ __device__ constexpr void
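The heterogeneous primary template plus the new specialization let an fp32 value (for example a bias or partial result kept in float) be added to an fp16 operand with an explicit down-convert. A hedged usage sketch, assuming the CK include paths are set up and the translation unit is built with hipcc (half_t and type_convert come from ck/utility/data_type.hpp):

// Hedged sketch of calling the Add functor with mixed float/half_t operands.
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"

int main()
{
    ck::tensor_operation::element_wise::Add add{};

    float      partial = 1.25f;                               // fp32 partial result
    ck::half_t bias    = ck::type_convert<ck::half_t>(0.75f); // fp16 operand
    ck::half_t y;

    // selects the (half_t&, const float&, const half_t&) overload shown above
    add(y, partial, bias);
    return 0;
}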
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
@@ -23,19 +23,19 @@ template <typename GridwiseGemm,
           typename FloatC,
           typename FloatC0,
           typename FloatC1,
-          typename DPtrsGlobal,
+          typename ReducePtrsGlobal,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename C1ElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation,
+          typename ReduceInElementwiseOperations,
+          typename ReduceAccElementwiseOperations,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-          typename DGridDescriptor_MBlock_MPerBlock,
+          typename ReduceGridDescriptor_MBlock_MPerBlock,
           typename Block2CTileMap,
           bool HasMainKBlockLoop>
 __global__ void
...
@@ -46,15 +46,15 @@ __global__ void
             const FloatAB* __restrict__ p_a_grid,
             const FloatAB* __restrict__ p_b_grid,
             FloatC* __restrict__ p_c_grid,
-            const FloatC0* __restrict__ p_c0_grid,
-            const FloatC1* __restrict__ p_c1_grid,
-            DPtrsGlobal p_ds_grid,
+            const FloatC0* __restrict__ p_bias_grid,
+            const FloatC1* __restrict__ p_d0_grid,
+            ReducePtrsGlobal p_reduces_grid,
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
             const CElementwiseOperation c_element_op,
             const C1ElementwiseOperation c1_element_op,
-            const DxsInElementwiseOperation dxs_in_element_op,
-            const DxsReduceAccElementwiseOperation dxs_out_element_op,
+            const ReduceInElementwiseOperations reduce_in_element_ops,
+            const ReduceAccElementwiseOperations reduce_out_element_ops,
             const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
             const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
             const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
@@ -63,7 +63,7 @@ __global__ void
                 c0_grid_desc_mblock_mperblock_nblock_nperblock,
             const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 c1_grid_desc_mblock_mperblock_nblock_nperblock,
-            const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
+            const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...
@@ -72,42 +72,42 @@ __global__ void
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
-                                                  p_c0_grid,
-                                                  p_c1_grid,
-                                                  p_ds_grid,
+                                                  p_bias_grid,
+                                                  p_d0_grid,
+                                                  p_reduces_grid,
                                                   p_shared,
                                                   a_element_op,
                                                   b_element_op,
                                                   c_element_op,
                                                   c1_element_op,
-                                                  dxs_in_element_op,
-                                                  dxs_out_element_op,
+                                                  reduce_in_element_ops,
+                                                  reduce_out_element_ops,
                                                   a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   c0_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   c1_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                  d_grid_desc_mblock_mperblock,
+                                                  reduce_grid_desc_mblock_mperblock,
                                                   block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
     ignore = p_c_grid;
-    ignore = p_c0_grid;
-    ignore = p_c1_grid;
-    ignore = p_ds_grid;
+    ignore = p_bias_grid;
+    ignore = p_d0_grid;
+    ignore = p_reduces_grid;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
     ignore = c1_element_op;
-    ignore = dxs_in_element_op;
-    ignore = dxs_out_element_op;
+    ignore = reduce_in_element_ops;
+    ignore = reduce_out_element_ops;
     ignore = a_grid_desc_ak0_m_ak1;
     ignore = b_grid_desc_bk0_n_bk1;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
-    ignore = d_grid_desc_mblock_mperblock;
+    ignore = reduce_grid_desc_mblock_mperblock;
     ignore = block_2_ctile_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
...
@@ -119,22 +119,22 @@ template <typename FloatAB,
           typename FloatC0,
           typename FloatC1,
           typename FloatReduceAcc,
-          typename DPtrsGlobal,
+          typename ReducePtrsGlobal,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename C1ElementwiseOperation,
-          typename DxsReduceOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation,
+          typename ReduceOperations,
+          typename ReduceInElementwiseOperations,
+          typename ReduceAccElementwiseOperations,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          typename DGlobalMemoryDataOperation,
+          typename ReduceGlobalMemoryDataOperation,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDesc_M_N,
           typename C0GridDesc_M_N,
           typename C1GridDesc_M_N,
-          typename DGridDesc_M,
+          typename ReduceGridDesc_M,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
           index_t MPerBlock,
@@ -321,18 +321,18 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     }

     __host__ __device__ static constexpr auto
-    MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m)
+    MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
     {
         const auto M      = d_grid_desc_m.GetLength(I0);
         const auto MBlock = M / MPerBlock;

-        const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor(
+        const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
             d_grid_desc_m,
             make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
             make_tuple(Sequence<0>{}),
             make_tuple(Sequence<0, 1>{}));

-        return d_grid_desc_mblock_mperblock;
+        return reduce_grid_desc_mblock_mperblock;
     }

     // return block_id to C matrix tile idx (m0, n0) mapping
...
@@ -352,36 +352,37 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>;
-    using DGridDescriptor_MBlock_MPerBlock =
-        remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>;
+    using ReduceGridDescriptor_MBlock_MPerBlock =
+        remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;

     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

     template <bool HasMainKBlockLoop, typename Block2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
-                               const FloatC0* __restrict__ p_c0_grid,
-                               const FloatC1* __restrict__ p_c1_grid,
-                               DPtrsGlobal p_ds_grid,
+                               const FloatC0* __restrict__ p_bias_grid,
+                               const FloatC1* __restrict__ p_d0_grid,
+                               ReducePtrsGlobal p_reduces_grid,
                                void* __restrict__ p_shared,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
                                const CElementwiseOperation& c_element_op,
                                const C1ElementwiseOperation& c1_element_op,
-                               const DxsInElementwiseOperation& dxs_in_element_op,
-                               const DxsReduceAccElementwiseOperation& dxs_out_element_op,
+                               const ReduceInElementwiseOperations& reduce_in_element_ops,
+                               const ReduceAccElementwiseOperations& reduce_out_element_ops,
                                const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                                const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                                const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                                const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    c0_grid_desc_mblock_mperblock_nblock_nperblock,
                                const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    c1_grid_desc_mblock_mperblock_nblock_nperblock,
-                               const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
+                               const ReduceGridDescriptor_MBlock_MPerBlock&
+                                   reduce_grid_desc_mblock_mperblock,
                                const Block2CTileMap& block_2_ctile_map)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
...
@@ -390,9 +391,9 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
         auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_c0_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+            p_bias_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
         auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_c1_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+            p_d0_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

         // divide block work by [M, N]
         const auto block_work_idx =
...
@@ -725,12 +726,12 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             make_naive_tensor_descriptor_packed(
                 make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));

-        // VGPR d_reduce_thread_desc_mperblock
-        constexpr auto d_reduce_thread_desc_mperblock =
+        // VGPR reduce_thread_desc_mperblock
+        constexpr auto reduce_thread_desc_mperblock =
             make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));

-        // VGPR d_reduce_thread_desc_mblock_mperblock
-        constexpr auto d_reduce_thread_desc_mblock_mperblock =
+        // VGPR reduce_thread_desc_mblock_mperblock
+        constexpr auto reduce_thread_desc_mblock_mperblock =
             make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));

         auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
...
@@ -759,29 +760,29 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 1,
                 true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};

-        auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple(
+        auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
             [&](auto I) {
-                auto p_d_grid         = p_ds_grid[I];
-                auto d_out_element_op = dxs_out_element_op[I];
+                auto p_reduce_grid         = p_reduces_grid[I];
+                auto reduce_acc_element_op = reduce_out_element_ops[I];

                 return ThreadwiseTensorSliceTransfer_v1r3<
                     FloatReduceAcc,
-                    remove_pointer_t<decltype(p_d_grid)>,
-                    decltype(d_reduce_thread_desc_mblock_mperblock),
-                    decltype(d_grid_desc_mblock_mperblock),
-                    decltype(d_out_element_op),
+                    remove_pointer_t<decltype(p_reduce_grid)>,
+                    decltype(reduce_thread_desc_mblock_mperblock),
+                    decltype(reduce_grid_desc_mblock_mperblock),
+                    decltype(reduce_acc_element_op),
                     Sequence<1, mreduce_per_thread>,
                     Sequence<0, 1>,
                     1,
                     CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
-                    DGlobalMemoryDataOperation::At(I),
+                    ReduceGlobalMemoryDataOperation::At(I),
                     1,
-                    false>{d_grid_desc_mblock_mperblock,
+                    false>{reduce_grid_desc_mblock_mperblock,
                            make_multi_index(block_work_idx[I0],                  // mblock
                                             c_reduce_thread_data_idx_begin[I0]), // mperblock
-                           d_out_element_op};
+                           reduce_acc_element_op};
             },
-            Number<p_ds_grid.Size()>{});
+            Number<p_reduces_grid.Size()>{});

         // c0 and c1
         constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
@@ -909,35 +910,35 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 c_grid_buf);

-            static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
-                auto& p_d_grid = p_ds_grid[In];
+            static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
+                auto& p_reduce_grid = p_reduces_grid[In];

-                auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
+                auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                    p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());

-                auto d_thread_buf =
+                auto reduce_thread_buf =
                     make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
-                        d_reduce_thread_desc_mperblock.GetElementSpaceSize());
+                        reduce_thread_desc_mperblock.GetElementSpaceSize());

-                auto& d_in_element_op = dxs_in_element_op[In];
+                auto& reduce_in_element_op = reduce_in_element_ops[In];

-                auto& d_reduce_thread_copy_vgpr_to_global = dxs_reduce_thread_copy_vgpr_to_global(In);
+                auto& reduce_thread_copy_vgpr_to_global = reduce_tuple_thread_copy_vgpr_to_global(In);

-                using DReduceOperation = remove_cvref_t<decltype(DxsReduceOperation{}[In])>;
+                using ReduceOperation = remove_cvref_t<decltype(ReduceOperations{}[In])>;

                 using ThreadwiseReduce =
                     ThreadwiseReduction<FloatReduceAcc,
                                         decltype(c_reduce_thread_desc_mperblock_nperblock),
-                                        decltype(d_reduce_thread_desc_mperblock),
-                                        DReduceOperation,
+                                        decltype(reduce_thread_desc_mperblock),
+                                        ReduceOperation,
                                         false>;

                 // Global write Gemm shuffle + reduction
-                const auto d_zeroVal = DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
+                const auto reduce_identityVal =
+                    ReduceOperation::template GetIdentityValue<FloatReduceAcc>();

                 static_for<0, mreduce_per_thread, 1>{}(
-                    [&](auto I) { d_thread_buf(I) = d_zeroVal; });
+                    [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });

                 // reduce in VGPR
                 static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
...
@@ -946,26 +947,25 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                         Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
                             make_tuple(im, in))>{};

-                        d_in_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset));
+                        reduce_in_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset));
                     });
                 });

-                ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf);
+                ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);

                 // copy from VGPR to Global
-                d_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
-                                                        make_tuple(I0, I0),
-                                                        d_thread_buf,
-                                                        d_grid_desc_mblock_mperblock,
-                                                        d_grid_buf);
+                reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
+                                                      make_tuple(I0, I0),
+                                                      reduce_thread_buf,
+                                                      reduce_grid_desc_mblock_mperblock,
+                                                      reduce_grid_buf);

                 if constexpr(access_id < num_access - 1)
                 {
                     constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
-                    d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                        d_grid_desc_mblock_mperblock,
+                    reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                        reduce_grid_desc_mblock_mperblock,
                         make_tuple(c_global_step[I0], c_global_step[I1]));
                 }
             });
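Conceptually, the renamed epilogue does the same work as before: for each requested reduction it applies the per-element input operator to the shuffled C tile held in registers, reduces across the N direction within the thread, and writes one partial value per row through the per-reduction VGPR-to-global transfer. A stand-alone sketch of that per-thread flow on plain arrays (names and shapes here are illustrative only, not CK types):

// Standalone sketch of the per-thread reduce epilogue: apply a unary op to each
// C value, reduce along N, then emit one partial per row.
#include <array>
#include <cstdio>
#include <functional>

template <int MPerThread, int NPerThread, typename InElementOp, typename ReduceOp>
std::array<float, MPerThread> reduce_tile(const std::array<float, MPerThread * NPerThread>& c_tile,
                                          InElementOp in_op,
                                          ReduceOp reduce_op,
                                          float identity)
{
    std::array<float, MPerThread> partial{};
    for(int m = 0; m < MPerThread; ++m)
    {
        float acc = identity;
        for(int n = 0; n < NPerThread; ++n)
        {
            acc = reduce_op(acc, in_op(c_tile[m * NPerThread + n]));
        }
        partial[m] = acc; // in the kernel this value goes through the VGPR->global transfer
    }
    return partial;
}

int main()
{
    std::array<float, 2 * 4> c{1, 2, 3, 4, 5, 6, 7, 8};
    auto square = [](float x) { return x * x; };          // e.g. a mean-square style pre-op
    auto sums   = reduce_tile<2, 4>(c, square, std::plus<float>{}, 0.0f);
    std::printf("%f %f\n", sums[0], sums[1]);             // prints 30 and 174
}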
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
@@ -21,16 +21,16 @@ namespace ck {
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
-          typename DPtrsGlobal,
+          typename ReducePtrsGlobal,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation,
+          typename ReduceInElementwiseOperations,
+          typename ReduceAccElementwiseOperations,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-          typename DGridDescriptor_MBlock_MPerBlock,
+          typename ReduceGridDescriptor_MBlock_MPerBlock,
           typename Block2CTileMap,
           bool HasMainKBlockLoop>
 __global__ void
...
@@ -41,17 +41,17 @@ __global__ void
         const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
-        DPtrsGlobal p_ds_grid,
+        ReducePtrsGlobal p_reduces_grid,
         const AElementwiseOperation a_element_op,
         const BElementwiseOperation b_element_op,
         const CElementwiseOperation c_element_op,
-        const DxsInElementwiseOperation dxs_in_element_op,
-        const DxsReduceAccElementwiseOperation dxs_out_element_op,
+        const ReduceInElementwiseOperations reduce_in_element_ops,
+        const ReduceAccElementwiseOperations reduce_out_element_ops,
         const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
         const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
         const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
+        const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
         const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...
@@ -60,32 +60,32 @@ __global__ void
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
-                                                  p_ds_grid,
+                                                  p_reduces_grid,
                                                   p_shared,
                                                   a_element_op,
                                                   b_element_op,
                                                   c_element_op,
-                                                  dxs_in_element_op,
-                                                  dxs_out_element_op,
+                                                  reduce_in_element_ops,
+                                                  reduce_out_element_ops,
                                                   a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                  d_grid_desc_mblock_mperblock,
+                                                  reduce_grid_desc_mblock_mperblock,
                                                   block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
     ignore = p_c_grid;
-    ignore = p_ds_grid;
+    ignore = p_reduces_grid;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
-    ignore = dxs_in_element_op;
-    ignore = dxs_out_element_op;
+    ignore = reduce_in_element_ops;
+    ignore = reduce_out_element_ops;
     ignore = a_grid_desc_ak0_m_ak1;
     ignore = b_grid_desc_bk0_n_bk1;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
-    ignore = d_grid_desc_mblock_mperblock;
+    ignore = reduce_grid_desc_mblock_mperblock;
     ignore = block_2_ctile_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
...
@@ -95,19 +95,19 @@ template <typename FloatAB,
           typename FloatCShuffle,
           typename FloatC,
           typename FloatReduceAcc,
-          typename DPtrsGlobal,
+          typename ReducePtrsGlobal,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          typename DxsReduceOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation,
+          typename ReduceOperations,
+          typename ReduceInElementwiseOperations,
+          typename ReduceAccElementwiseOperations,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          typename DGlobalMemoryDataOperation,
+          typename ReduceGlobalMemoryDataOperation,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDesc_M_N,
-          typename DGridDesc_M,
+          typename ReduceGridDesc_M,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
           index_t MPerBlock,
...
@@ -293,18 +293,18 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     }

     __host__ __device__ static constexpr auto
-    MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m)
+    MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
     {
         const auto M      = d_grid_desc_m.GetLength(I0);
         const auto MBlock = M / MPerBlock;

-        const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor(
+        const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
             d_grid_desc_m,
             make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
             make_tuple(Sequence<0>{}),
             make_tuple(Sequence<0, 1>{}));

-        return d_grid_desc_mblock_mperblock;
+        return reduce_grid_desc_mblock_mperblock;
     }

     // return block_id to C matrix tile idx (m0, n0) mapping
...
@@ -318,29 +318,30 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
-    using DGridDescriptor_MBlock_MPerBlock =
-        remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>;
+    using ReduceGridDescriptor_MBlock_MPerBlock =
+        remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;

     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

     template <bool HasMainKBlockLoop, typename Block2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
-                               DPtrsGlobal p_ds_grid,
+                               ReducePtrsGlobal p_reduces_grid,
                                void* __restrict__ p_shared,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
                                const CElementwiseOperation& c_element_op,
-                               const DxsInElementwiseOperation& dxs_in_element_op,
-                               const DxsReduceAccElementwiseOperation& dxs_out_element_op,
+                               const ReduceInElementwiseOperations& reduce_in_element_ops,
+                               const ReduceAccElementwiseOperations& reduce_out_element_ops,
                                const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                                const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                                const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
-                               const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
+                               const ReduceGridDescriptor_MBlock_MPerBlock&
+                                   reduce_grid_desc_mblock_mperblock,
                                const Block2CTileMap& block_2_ctile_map)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -706,12 +707,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             make_naive_tensor_descriptor_packed(
                 make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));

-        // VGPR d_reduce_thread_desc_mperblock
-        constexpr auto d_reduce_thread_desc_mperblock =
+        // VGPR reduce_thread_desc_mperblock
+        constexpr auto reduce_thread_desc_mperblock =
             make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));

-        // VGPR d_reduce_thread_desc_mblock_mperblock
-        constexpr auto d_reduce_thread_desc_mblock_mperblock =
+        // VGPR reduce_thread_desc_mblock_mperblock
+        constexpr auto reduce_thread_desc_mblock_mperblock =
             make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));

         auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
...
@@ -740,29 +741,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 1,
                 true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};

-        auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple(
+        auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
             [&](auto I) {
-                auto p_d_grid         = p_ds_grid[I];
-                auto d_out_element_op = dxs_out_element_op[I];
+                auto p_reduce_grid         = p_reduces_grid[I];
+                auto reduce_acc_element_op = reduce_out_element_ops[I];

                 return ThreadwiseTensorSliceTransfer_v1r3<
                     FloatReduceAcc,
-                    remove_pointer_t<decltype(p_d_grid)>,
-                    decltype(d_reduce_thread_desc_mblock_mperblock),
-                    decltype(d_grid_desc_mblock_mperblock),
-                    decltype(d_out_element_op),
+                    remove_pointer_t<decltype(p_reduce_grid)>,
+                    decltype(reduce_thread_desc_mblock_mperblock),
+                    decltype(reduce_grid_desc_mblock_mperblock),
+                    decltype(reduce_acc_element_op),
                     Sequence<1, mreduce_per_thread>,
                     Sequence<0, 1>,
                     1,
                     CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
-                    DGlobalMemoryDataOperation::At(I),
+                    ReduceGlobalMemoryDataOperation::At(I),
                     1,
-                    false>{d_grid_desc_mblock_mperblock,
+                    false>{reduce_grid_desc_mblock_mperblock,
                            make_multi_index(block_work_idx[I0],                  // mblock
                                             c_reduce_thread_data_idx_begin[I0]), // mperblock
-                           d_out_element_op};
+                           reduce_acc_element_op};
             },
-            Number<p_ds_grid.Size()>{});
+            Number<p_reduces_grid.Size()>{});

         constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
...
@@ -797,35 +798,35 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 make_tuple(I0, I0),
                 c_reduce_thread_buf);

-            static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
-                auto& p_d_grid = p_ds_grid[In];
+            static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
+                auto& p_reduce_grid = p_reduces_grid[In];

-                auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
+                auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                    p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());

-                auto d_thread_buf =
+                auto reduce_thread_buf =
                     make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
-                        d_reduce_thread_desc_mperblock.GetElementSpaceSize());
+                        reduce_thread_desc_mperblock.GetElementSpaceSize());

-                auto& d_in_element_op = dxs_in_element_op[In];
+                auto& reduce_in_element_op = reduce_in_element_ops[In];

-                auto& d_reduce_thread_copy_vgpr_to_global = dxs_reduce_thread_copy_vgpr_to_global(In);
+                auto& reduce_thread_copy_vgpr_to_global = reduce_tuple_thread_copy_vgpr_to_global(In);

-                using DReduceOperation = remove_cvref_t<decltype(DxsReduceOperation{}[In])>;
+                using ReduceOperation = remove_cvref_t<decltype(ReduceOperations{}[In])>;

                 using ThreadwiseReduce =
                     ThreadwiseReduction<FloatReduceAcc,
                                         decltype(c_reduce_thread_desc_mperblock_nperblock),
-                                        decltype(d_reduce_thread_desc_mperblock),
-                                        DReduceOperation,
+                                        decltype(reduce_thread_desc_mperblock),
+                                        ReduceOperation,
                                         false>;

                 // Global write Gemm shuffle + reduction
-                const auto d_identityVal = DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
+                const auto reduce_identityVal =
+                    ReduceOperation::template GetIdentityValue<FloatReduceAcc>();

                 static_for<0, mreduce_per_thread, 1>{}(
-                    [&](auto I) { d_thread_buf(I) = d_identityVal; });
+                    [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });

                 // reduce in VGPR
                 static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
...
@@ -834,26 +835,25 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                         Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
                             make_tuple(im, in))>{};

-                        d_in_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset));
+                        reduce_in_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset));
                     });
                 });

-                ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf);
+                ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);

                 // copy from VGPR to Global
-                d_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
-                                                        make_tuple(I0, I0),
-                                                        d_thread_buf,
-                                                        d_grid_desc_mblock_mperblock,
-                                                        d_grid_buf);
+                reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
+                                                      make_tuple(I0, I0),
+                                                      reduce_thread_buf,
+                                                      reduce_grid_desc_mblock_mperblock,
+                                                      reduce_grid_buf);

                 if constexpr(access_id < num_access - 1)
                 {
                     constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
-                    d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                        d_grid_desc_mblock_mperblock,
+                    reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                        reduce_grid_desc_mblock_mperblock,
                         make_tuple(c_global_step[I0], c_global_step[I1]));
                 }
             });
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp
@@ -49,7 +49,8 @@ template <typename InDataType,
           index_t KThreadSliceSize,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
-          index_t OutDstVectorSize>
+          index_t OutDstVectorSize,
+          bool SweepOnce>
 struct GridwiseSoftmax_mk_to_mk
 {
     static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
...
@@ -75,19 +76,6 @@ struct GridwiseSoftmax_mk_to_mk
     using ThreadReduceDstDesc_M =
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

-    using BlockwiseMaxReduce = PartitionedBlockwiseReduction<AccDataType,
-                                                             BlockSize,
-                                                             ThreadClusterLengths_M_K,
-                                                             ThreadClusterArrangeOrder,
-                                                             reduce::Max,
-                                                             false>; // PropagateNan
-
-    using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType,
-                                                    ThreadReduceSrcDesc_M_K,
-                                                    ThreadReduceDstDesc_M,
-                                                    reduce::Max,
-                                                    false>; // PropagateNan
-
     using PassThroughOp = tensor_operation::element_wise::PassThrough;

     static constexpr auto I0 = Number<0>{};
...
@@ -105,6 +93,11 @@ struct GridwiseSoftmax_mk_to_mk
                                AccDataType beta,
                                OutDataType* const __restrict__ p_out_value_global)
     {
+        if constexpr(SweepOnce)
+        {
+            num_k_block_tile_iteration = 1;
+        }
+
         // LDS
         __shared__ AccDataType p_reduce_work_buffer[BlockSize];
...
@@ -149,6 +142,20 @@ struct GridwiseSoftmax_mk_to_mk
         constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
             make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

+        // Normally, 0 as invalid element value is adequate since 0 makes no contribution to
+        // accumulated result. However, in stable softmax, all values 0s or not are subtracted by
+        // another value_max. As numbers become non-zero, effectively it allows invalid values to
+        // slip through and contribute to the accumulated result.
+        //
+        // The trick here is leveraging the fact that many math functions (add, sub, exp, ...)
+        // propagate NaNs when operands have NaNs involved. By initialiing invalid element value
+        // with NaN, an invalid value doing math manipulations is still NaN, which in turn can still
+        // be identified as an invalid value. We can then discard the invalid values which
+        // originally failed the bound check during accumulation. This allows to ignore values that
+        // failed bound check even after multiple math manipulations.
+        //
+        // NOTE: reset coordinate after every step because the same threadwise copy will sweep
+        // through global memory 3 times back and forth
         auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
                                                                     AccDataType,
                                                                     GridDesc_M_K,
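The comment block above is the key idea behind the new InvalidElementAsNaN load path: out-of-bounds lanes are filled with NaN, stay NaN through the x - max / exp manipulations, and are then skipped during accumulation. A small stand-alone sketch of that accumulate-while-ignoring-NaN behaviour (plain C++, mirroring but not reusing AccumulateWithNanIgnore):

// Standalone sketch: invalid lanes are marked with NaN; arithmetic keeps them NaN,
// and the accumulation step drops NaN values so they never pollute the sum.
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

float sum_ignoring_nan(const std::vector<float>& vals)
{
    float acc = 0.0f;
    for(float v : vals)
    {
        if(!std::isnan(v)) // lanes that failed the bound check stay NaN and are skipped here
        {
            acc += v;
        }
    }
    return acc;
}

int main()
{
    const float nan = std::numeric_limits<float>::quiet_NaN();
    // pretend the last two lanes failed the bound check and were loaded as NaN
    std::vector<float> lane_vals{0.5f, 1.5f, nan, nan};
    // even after "stable softmax"-style math the invalid lanes remain NaN
    for(float& v : lane_vals) { v = std::exp(v - 1.5f); }
    assert(std::abs(sum_ignoring_nan(lane_vals) - (std::exp(-1.0f) + 1.0f)) < 1e-6f);
}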
@@ -158,7 +165,8 @@ struct GridwiseSoftmax_mk_to_mk
                                                                     InSrcVectorDim,
                                                                     InSrcVectorSize,
                                                                     1,
-                                                                    false>(
+                                                                    true /* ResetCoordAfterRun */,
+                                                                    true /* InvalidElementAsNaN */>(
             in_grid_desc_m_k,
             make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                              block_local_id * reduceSizePerBlock +
...
@@ -198,21 +206,39 @@ struct GridwiseSoftmax_mk_to_mk
                              block_local_id * reduceSizePerBlock +
                                  thread_k_cluster_id * KThreadSliceSize),
             PassThroughOp{});

-        constexpr auto in_thread_copy_fwd_step = make_multi_index(0, K_BlockTileSize);
-        constexpr auto in_thread_copy_bwd_step = make_multi_index(0, -K_BlockTileSize);
+        constexpr auto in_thread_copy_fwd_step =
+            make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
+        constexpr auto in_thread_copy_bwd_step =
+            make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);

         ///
         /// max(x)
         ///
-        const auto in_global_val_buf_oob_non_zero = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_value_global,
-            in_grid_desc_m_k.GetElementSpaceSize(),
-            reduce::Max::template GetIdentityValue<InDataType>());
+        using BlockwiseMaxReduce = PartitionedBlockwiseReduction<
+            AccDataType,
+            BlockSize,
+            ThreadClusterLengths_M_K,
+            ThreadClusterArrangeOrder,
+            reduce::Max,
+            false, // param ignored
+            detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
+
+        using ThreadwiseMaxReduce = ThreadwiseReduction<
+            AccDataType,
+            ThreadReduceSrcDesc_M_K,
+            ThreadReduceDstDesc_M,
+            reduce::Max,
+            false, // param ignored
+            detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
+
+        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_value_global, in_grid_desc_m_k.GetElementSpaceSize());

         index_t reducedTiles = 0;
         do
         {
             threadwise_src_load.Run(in_grid_desc_m_k,
-                                    in_global_val_buf_oob_non_zero,
+                                    in_global_val_buf,
                                     thread_buffer_desc,
                                     make_tuple(I0, I0),
                                     in_thread_buf);
...
@@ -232,26 +258,6 @@ struct GridwiseSoftmax_mk_to_mk
         ///
         /// sum(exp(x - max(x)))
         ///
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
             accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
         });

-        // Normally, 0 as invalid element value is adequate since 0 makes no contribution to
-        // accumulated result. However, in stable softmax, all values 0s or not are subtracted by
-        // another value_max. As numbers become non-zero, effectively it allows invalid values to
-        // slip through and contribute to the accumulated result.
-        //
-        // The trick here is leveraging the fact that many math functions (add, sub, exp, ...)
-        // propagate NaNs when operands have NaNs involved. By initialiing invalid element value
-        // with NaN, an invalid value doing math manipulations is still NaN, which in turn can still
-        // be identified as an invalid value. We can then discard the invalid values which
-        // originally failed the bound check during accumulation. This allows to ignore values that
-        // failed bound check even after multiple math manipulations.
-        const auto in_global_val_buf_oob_nan = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_value_global,
-            in_grid_desc_m_k.GetElementSpaceSize(),
-            NumericLimits<InDataType>::QuietNaN());
-
         using BlockwiseSumReduce = PartitionedBlockwiseReduction<AccDataType,
                                                                  BlockSize,
...
@@ -272,22 +278,25 @@ struct GridwiseSoftmax_mk_to_mk
         reducedTiles = 0;
         do
         {
-            threadwise_src_load.Run(in_grid_desc_m_k,
-                                    in_global_val_buf_oob_nan,
-                                    thread_buffer_desc,
-                                    make_tuple(I0, I0),
-                                    in_thread_buf);
+            if constexpr(!SweepOnce)
+            {
+                threadwise_src_load.Run(in_grid_desc_m_k,
+                                        in_global_val_buf,
+                                        thread_buffer_desc,
+                                        make_tuple(I0, I0),
+                                        in_thread_buf);
+            }

             // do element-wise pre-reduction operation
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                 static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                     constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
-                    in_thread_buf(Number<offset>{}) =
+                    out_thread_buf(Number<offset>{}) =
                         math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
                 });
             });
-            ThreadwiseSumReduce::Reduce(in_thread_buf, accu_value_buf);
+            ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf);

             threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
...
@@ -309,11 +318,14 @@ struct GridwiseSoftmax_mk_to_mk
         {
             do
             {
-                threadwise_src_load.Run(in_grid_desc_m_k,
-                                        in_global_val_buf_oob_nan,
-                                        thread_buffer_desc,
-                                        make_tuple(I0, I0),
-                                        in_thread_buf);
+                if constexpr(!SweepOnce)
+                {
+                    threadwise_src_load.Run(in_grid_desc_m_k,
+                                            in_global_val_buf,
+                                            thread_buffer_desc,
+                                            make_tuple(I0, I0),
+                                            in_thread_buf);
+                }

                 static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                     // out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
...
@@ -340,18 +352,27 @@ struct GridwiseSoftmax_mk_to_mk
         }
         else
         {
+            StaticBuffer<AddressSpaceEnum::Vgpr,
+                         AccDataType,
+                         MThreadSliceSize * KThreadSliceSize,
+                         true>
+                in_prior_dst_buf;
             do
             {
-                threadwise_src_load.Run(in_grid_desc_m_k,
-                                        in_global_val_buf_oob_nan,
-                                        thread_buffer_desc,
-                                        make_tuple(I0, I0),
-                                        in_thread_buf);
+                if constexpr(!SweepOnce)
+                {
+                    threadwise_src_load.Run(in_grid_desc_m_k,
+                                            in_global_val_buf,
+                                            thread_buffer_desc,
+                                            make_tuple(I0, I0),
+                                            in_thread_buf);
+                }

                 threadwise_dst_load.Run(out_grid_desc_m_k,
                                         out_global_val_buf,
                                         thread_buffer_desc,
                                         make_tuple(I0, I0),
-                                        out_thread_buf);
+                                        in_prior_dst_buf);

                 static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                     // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
                     static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
...
@@ -360,7 +381,7 @@ struct GridwiseSoftmax_mk_to_mk
                         out_thread_buf(Number<offset>{}) =
                             alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
                                 accu_value_buf(iM) +
-                            beta * out_thread_buf(Number<offset>{});
+                            beta * in_prior_dst_buf(Number<offset>{});
                     });
                 });
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
@@ -236,9 +236,14 @@ template <typename SrcData,
           index_t SrcScalarPerVector,
           index_t SrcScalarStrideInVector,
           bool SrcResetCoordinateAfterRun,
+          bool InvalidElementAsNaN = false,
           typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
 struct ThreadwiseTensorSliceTransfer_v2
 {
+    static_assert((InvalidElementAsNaN && !std::is_integral<DstData>::value) ||
+                      (!InvalidElementAsNaN),
+                  "Filling invalid element as NaN is only for floating point types");
+
     static constexpr index_t nDim = SliceLengths::Size();

     using Index = MultiIndex<nDim>;
...
@@ -318,8 +323,18 @@ struct ThreadwiseTensorSliceTransfer_v2
                     dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
                                              i * src_scalar_step_in_vector);

-                dst_buf(Number<dst_offset>{}) =
-                    type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
+                if constexpr(InvalidElementAsNaN)
+                {
+                    dst_buf(Number<dst_offset>{}) =
+                        is_src_valid
+                            ? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
+                            : NumericLimits<DstData>::QuietNaN();
+                }
+                else
+                {
+                    dst_buf(Number<dst_offset>{}) =
+                        type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
+                }
             });

             if constexpr(idx_1d.value != num_access - 1)
include/ck/utility/math.hpp
@@ -148,6 +148,8 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
 template <typename T>
 __device__ T exp(T x);

+// TODO: add f16 support using v_exp_f16
+
 template <>
 __device__ float exp<float>(float x)
 {
include/ck/utility/reduction_functions_accumulate.hpp
@@ -17,7 +17,7 @@ struct AccumulateWithNanIgnore
 {
     __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
     {
-        if(!isnan(currVal))
+        if(!ck::math::isnan(currVal))
         {
             ReduceOperation{}(accuVal, currVal);
         }
library/include/ck/library/host_tensor/host_tensor.hpp
@@ -222,6 +222,12 @@ struct Tensor
     Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {}

+    Tensor& operator=(const Tensor& other)
+    {
+        mDesc = other.mDesc;
+        mData = other.mData;
+    }
+
     template <typename F>
     void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
     {
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp
...
...
@@ -26,12 +26,11 @@ struct ReferenceSoftmax : public device::BaseOperator
Tensor
<
OutDataType
>&
out
,
AccDataType
alpha
,
AccDataType
beta
,
const
index_t
rank
,
const
std
::
vector
<
index_t
>
sm_reduce_dims
)
:
in_
(
in
),
out_
(
out
),
alpha_
(
alpha
),
beta_
(
beta
),
sm_reduce_dims_
(
sm_reduce_dims
)
{
// std::cout << "debug: scalar dims: ";
for
(
in
t
i
=
0
;
i
<
rank
;
i
++
)
for
(
size_
t
i
=
0
;
i
<
in
.
mDesc
.
GetNumOfDimension
()
;
i
++
)
{
if
(
std
::
find
(
sm_reduce_dims
.
begin
(),
sm_reduce_dims
.
end
(),
i
)
==
sm_reduce_dims
.
end
())
...
...
@@ -47,7 +46,6 @@ struct ReferenceSoftmax : public device::BaseOperator
Tensor
<
OutDataType
>&
out_
;
AccDataType
alpha_
;
AccDataType
beta_
;
index_t
rank_
;
std
::
vector
<
index_t
>
sm_reduce_dims_
;
std
::
vector
<
index_t
>
sm_scalar_dims_
;
// dim after internal max/sum reduction
};
...
...
@@ -136,10 +134,9 @@ struct ReferenceSoftmax : public device::BaseOperator
Tensor
<
OutDataType
>&
out
,
AccDataType
alpha
,
AccDataType
beta
,
const
index_t
rank
,
const
std
::
vector
<
index_t
>
sm_reduce_dims
)
{
return
Argument
{
in
,
out
,
alpha
,
beta
,
rank
,
sm_reduce_dims
};
return
Argument
{
in
,
out
,
alpha
,
beta
,
sm_reduce_dims
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
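With these changes the reference softmax derives the rank from the input tensor's descriptor (in.mDesc.GetNumOfDimension()) instead of taking it as a constructor argument, and MakeArgument no longer forwards a rank. The constructor's loop classifies every dimension not listed in sm_reduce_dims as a kept ("scalar") dimension; a standalone sketch of that split (names are illustrative):

// Sketch: every dimension index not in sm_reduce_dims becomes a scalar (kept) dimension.
#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<int> scalar_dims_from(std::size_t num_dims, const std::vector<int>& sm_reduce_dims)
{
    std::vector<int> sm_scalar_dims;
    for(std::size_t i = 0; i < num_dims; ++i)
    {
        if(std::find(sm_reduce_dims.begin(), sm_reduce_dims.end(), static_cast<int>(i)) ==
           sm_reduce_dims.end())
        {
            sm_scalar_dims.push_back(static_cast<int>(i));
        }
    }
    return sm_scalar_dims;
}

// e.g. a rank-3 [N, H, W] tensor with sm_reduce_dims = {2} yields scalar dims {0, 1}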
library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp
View file @ f5de8b57
...
@@ -4,6 +4,7 @@
 #pragma once

 #include <vector>
+#include "ck/utility/functional2.hpp"

 namespace ck {
 namespace tensor_operation {
...
library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp
0 → 100644
View file @ f5de8b57

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

using Normalize = ck::tensor_operation::element_wise::Normalize;

using DeviceNormalizeFromMeanMeanSquarePtr =
    ck::tensor_operation::device::DeviceElementwisePtr<5, 1, 2, Normalize>;

void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(
    std::vector<DeviceNormalizeFromMeanMeanSquarePtr>& instances);

template <typename InputType,
          typename MeanType,
          typename MeanSquareType,
          typename GammaDataType,
          typename BetaDataType,
          typename OutputType>
auto get_device_normalize_from_mean_meansquare_instances()
{
    std::vector<DeviceNormalizeFromMeanMeanSquarePtr> op_ptrs;

    if constexpr(is_same<InputType, half_t>::value && is_same<MeanType, float>::value &&
                 is_same<MeanSquareType, float>::value &&
                 is_same<GammaDataType, half_t>::value && is_same<BetaDataType, half_t>::value &&
                 is_same<OutputType, half_t>::value)
    {
        ck::tensor_operation::device::
            add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(op_ptrs);
    }

    return op_ptrs;
}

} // namespace device
} // namespace tensor_operation
} // namespace ck
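This new header is an instance factory: when the requested data-type combination matches the registered f16/f32 set, it returns type-erased pointers to elementwise Normalize instances. A hedged usage sketch for enumerating them (it assumes the type-erased base operator exposes GetTypeString(), as other CK instance listings do, and that the program is linked against the CK device instance library):

// Hedged sketch: list the registered normalize-from-mean/mean-square instances.
#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp"

int main()
{
    using ck::half_t;

    auto op_ptrs = ck::tensor_operation::device::
        get_device_normalize_from_mean_meansquare_instances<half_t, float, float,
                                                            half_t, half_t, half_t>();

    std::cout << "found " << op_ptrs.size() << " normalize instances\n";
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << "\n"; // assumption: BaseOperator::GetTypeString()
}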
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp
0 → 100644
View file @ f5de8b57

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using DeviceGemmAddAddMeanSquareMeanPtr =
    ck::tensor_operation::device::DeviceGemmReducePtr<1, 2>;

void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
    std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
    std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
    std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
    std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
auto get_device_gemm_add_add_mean_squaremean_instances()
{
    std::vector<DeviceGemmAddAddMeanSquareMeanPtr> op_ptrs;

    if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                 is_same<CDataType, half_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
                    op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
                    op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
                    op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
                    op_ptrs);
        }
    }

    return op_ptrs;
}

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
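Instance selection in this header is purely compile-time: the data types gate the outer branch, and the A/B/C layouts pick one of the four xdl_cshuffle instance lists (C is always row-major). A hedged sketch of requesting one layout combination at the call site (enumeration only; it assumes linkage against the CK instance library that defines the add_* functions):

// Hedged sketch: f16 GEMM + bias/add + mean/square-mean instances for
// row-major A, column-major B, row-major C.
#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp"

int main()
{
    using ck::half_t;
    namespace gemm_layout = ck::tensor_layout::gemm;

    auto op_ptrs = ck::tensor_operation::device::device_gemm_instance::
        get_device_gemm_add_add_mean_squaremean_instances<half_t,
                                                          half_t,
                                                          half_t,
                                                          gemm_layout::RowMajor,
                                                          gemm_layout::ColumnMajor,
                                                          gemm_layout::RowMajor>();

    std::cout << "registered instances: " << op_ptrs.size() << "\n";
}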