gaoqiong / composable_kernel · Commits · 6805df0e

Commit 6805df0e, authored Jun 18, 2022 by Chao Liu

    Merge remote-tracking branch 'origin/develop' into gelu

Parents: 1fdbe3fe, e4584d91
Changes: 68

Showing 20 changed files with 2941 additions and 568 deletions (+2941, -568)
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp                             +28   -30
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp  +244  -28
include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp                    +813  -0
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp                                          +55   -11
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp                             +28   -27
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp                                     +24   -31
include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp                                 +7    -11
include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp                                    +5    -8
include/ck/tensor_operation/gpu/device/device_unary_elementwise.hpp                                    +178  -0
include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp                                  +101  -60
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp                              +145  -56
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp                                     +46   -279
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                               +119  -0
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp                              +10   -10
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp                              +10   -10
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp                 +989  -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp                          +6    -5
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp                               +3    -1
include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp                                     +1    -1
include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp                                 +129  -0
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp

@@ -557,11 +557,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
             float ave_time = 0;

-            using Add       = ck::tensor_operation::binary_element_wise::Add<CDataType, CDataType, CDataType>;
-            using Substract = ck::tensor_operation::binary_element_wise::Substract<CDataType, CDataType, CDataType>;
+            using Add      = ck::tensor_operation::element_wise::Add;
+            using Subtract = ck::tensor_operation::element_wise::Subtract;

             using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
                                                                 CDataType,
                                                                 CDataType,
                                                                 CDataType,
@@ -573,19 +571,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                                                 AScalarPerVector,
                                                                 BScalarPerVector,
                                                                 CScalarPerVector>;

-            using GridwiseBinSubstract = GridwiseBinaryElementwise_1D<CDataType,
+            using GridwiseBinSubtract = GridwiseBinaryElementwise_1D<CDataType,
                                                                       CDataType,
                                                                       CDataType,
                                                                       CDataType,
                                                                       CGridDesc_M,
                                                                       CGridDesc_M,
                                                                       CGridDesc_M,
-                                                                      Substract,
+                                                                      Subtract,
                                                                       MPerThread,
                                                                       AScalarPerVector,
                                                                       BScalarPerVector,
                                                                       CScalarPerVector>;

             const auto add_kernel = kernel_binary_elementwise_1d<GridwiseBinAdd,
                                                                  CDataType,
                                                                  CDataType,
                                                                  CDataType,
@@ -593,14 +591,14 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                                                  CGridDesc_M,
                                                                  CGridDesc_M,
                                                                  Add>;

-            const auto substract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubstract,
+            const auto subtract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubtract,
                                                                        CDataType,
                                                                        CDataType,
                                                                        CDataType,
                                                                        CGridDesc_M,
                                                                        CGridDesc_M,
                                                                        CGridDesc_M,
-                                                                       Substract>;
+                                                                       Subtract>;

             if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
@@ -653,7 +651,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                 // c_real = aux - aux_2
                 ave_time += launch_and_time_kernel(stream_config,
-                                                   substract_kernel,
+                                                   subtract_kernel,
                                                    dim3(grid_size),
                                                    dim3(BlockSize),
                                                    0,
@@ -663,7 +661,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                                    arg.c_grid_desc_m_,
                                                    arg.c_grid_desc_m_,
                                                    arg.c_grid_desc_m_,
-                                                   Substract{});
+                                                   Subtract{});

                 ave_time += launch_and_time_kernel(stream_config,
@@ -764,7 +762,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                 // c_real = aux - aux_2
                 ave_time += launch_and_time_kernel(stream_config,
-                                                   substract_kernel,
+                                                   subtract_kernel,
                                                    dim3(grid_size),
                                                    dim3(BlockSize),
                                                    0,
@@ -774,7 +772,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                                    arg.c_grid_desc_m_,
                                                    arg.c_grid_desc_m_,
                                                    arg.c_grid_desc_m_,
-                                                   Substract{});
+                                                   Subtract{});

                 ave_time += launch_and_time_kernel(stream_config,
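The change above swaps the per-type binary_element_wise::Add/Substract specializations for the single element_wise::Add and element_wise::Subtract functors. A minimal host-only sketch of how such functors behave when applied element by element, mirroring the "c_real = aux - aux_2" step the kernel comment describes; the demo namespace and names below are hypothetical stand-ins, not the library header:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    namespace demo {
    // stand-ins for ck::tensor_operation::element_wise::Add / Subtract
    struct Add      { template <typename T> void operator()(T& y, const T& x0, const T& x1) const { y = x0 + x1; } };
    struct Subtract { template <typename T> void operator()(T& y, const T& x0, const T& x1) const { y = x0 - x1; } };
    } // namespace demo

    int main()
    {
        std::vector<float> aux{1.0f, 2.0f}, aux_2{0.5f, 0.5f}, c_real(2);

        demo::Subtract subtract{};
        for(std::size_t i = 0; i < c_real.size(); ++i)
            subtract(c_real[i], aux[i], aux_2[i]); // c_real = aux - aux_2, as in the kernel comment

        assert(c_real[0] == 0.5f && c_real[1] == 1.5f);
        return 0;
    }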
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp

@@ -11,6 +11,7 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "gridwise_gemm_xdlops_bwd_weight.hpp"
+#include "gridwise_unary_elementwise_1d.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -628,6 +629,54 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
             1);
     }

+    // type convert descs
+    template <typename Desc_M0>
+    static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
+    {
+        const auto m0           = desc_m0.GetLength(I0);
+        const index_t loop_step = gridSize * blockSize * 4;
+        const auto pad          = math::integer_least_multiple(m0, loop_step) - m0;
+        const auto desc_m0_pad =
+            transform_tensor_descriptor(desc_m0,
+                                        make_tuple(make_right_pad_transform(m0, pad)),
+                                        make_tuple(Sequence<0>{}),
+                                        make_tuple(Sequence<0>{}));
+        return desc_m0_pad;
+    }
+
+    template <index_t Dim>
+    static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
+                                  const std::vector<index_t>& stride,
+                                  index_t gridSize,
+                                  index_t blockSize)
+    {
+        auto tupleOfShape  = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
+        auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
+
+        // nd desc - [s0, s1, s2, ...]
+        const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
+
+        // merge nd to 1d desc - [s0 * s1 * ...]
+        if constexpr(Dim > 1)
+        {
+            const auto desc_m0 = transform_tensor_descriptor(
+                desc,
+                make_tuple(make_merge_transform(tupleOfShape)),
+                make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
+                make_tuple(Sequence<0>{}));
+
+            return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
+        }
+        else
+            return PadDescriptor_M0_1d(desc, gridSize, blockSize);
+    }
+
+    using TypeConvertFunctor =
+        ck::tensor_operation::element_wise::UnaryTypeConvert<ck::bhalf_t, float>;
+    using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1));
+    using GridwiseUEltwise =
+        GridwiseUnaryElementwise_1D<AccDataType, InDataType, GridDesc_M0, TypeConvertFunctor, 4>;
+
     using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());

     using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
@@ -733,6 +782,55 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
             true,
             true>;

+    using GridwiseGemmAtomicAddFloatBf16Splitk = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
+        BlockSize,
+        ADataType, // TODO: distinguish A/B datatype
+        AccDataType,
+        AccDataType,
+        InMemoryDataOperationEnum::AtomicAdd,
+        AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N,
+        AElementwiseOperation, BElementwiseOperation, CElementwiseOperation,
+        MPerBlock, NPerBlock, K0PerBlock, MPerXdl, NPerXdl, K1, MXdlPerWave, NXdlPerWave,
+        ABlockTransferThreadClusterLengths_K0_M_K1,
+        ABlockTransferThreadClusterArrangeOrder,
+        ABlockTransferSrcAccessOrder,
+        ABlockTransferSrcVectorDim,
+        ABlockTransferSrcScalarPerVector,
+        ABlockTransferDstScalarPerVector_K1,
+        false, // AThreadTransferSrcResetCoordinateAfterRun,
+        ABlockLdsAddExtraM, ABlockLdsM1PerBlock, ABlockLdsM0PerBlock, ABlockLdsM1Padding,
+        BBlockTransferThreadClusterLengths_K0_N_K1,
+        BBlockTransferThreadClusterArrangeOrder,
+        BBlockTransferSrcAccessOrder,
+        BBlockTransferSrcVectorDim,
+        BBlockTransferSrcScalarPerVector,
+        BBlockTransferDstScalarPerVector_K1,
+        false, // BThreadTransferSrcResetCoordinateAfterRun,
+        BBlockLdsAddExtraN, BBlockLdsN1PerBlock, BBlockLdsN0PerBlock, BBlockLdsN1Padding,
+        CShuffleMXdlPerWavePerShuffle,
+        CShuffleNXdlPerWavePerShuffle,
+        CBlockTransferScalarPerVector_NWaveNPerXdl,
+        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+        true,
+        true>;
+
     // Argument
     using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
         decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
@@ -910,41 +1008,159 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                                                   arg.block_2_ctile_map_);
             };

+            // run kernel for bf16 with splitk
+            const auto run_bf16_splitk = [&](const auto& kernel) {
+                hipGetErrorString(hipMemset(
+                    arg.p_workspace_,
+                    0,
+                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
+                        sizeof(AccDataType)));
+
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
+                                                  dim3(grid_size),
+                                                  dim3(BlockSize),
+                                                  0,
+                                                  arg.p_a_grid_,
+                                                  arg.p_b_grid_,
+                                                  static_cast<AccDataType*>(arg.p_workspace_),
+                                                  arg.a_grid_desc_kbatch_k0_m_k1_,
+                                                  arg.b_grid_desc_kbatch_k0_n_k1_,
+                                                  arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                                  arg.a_element_op_,
+                                                  arg.b_element_op_,
+                                                  arg.c_element_op_,
+                                                  arg.block_2_ctile_map_);
+            };
+
+            // kernel for type conversion
+            std::vector<std::size_t> filter_dims{static_cast<std::size_t>(arg.Conv_K_),
+                                                 static_cast<std::size_t>(arg.Conv_C_)};
+
+            filter_dims.insert(std::end(filter_dims),
+                               std::begin(arg.filter_spatial_lengths_),
+                               std::end(arg.filter_spatial_lengths_));
+
+            int tensor_size =
+                std::accumulate(filter_dims.begin(), filter_dims.end(), 1, std::multiplies<int>{});
+
+            const index_t type_convert_grid_size = GridwiseUEltwise::CalculateGridSize(tensor_size);
+            GridDesc_M0 a_grid_desc_m0_ =
+                MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256);
+            GridDesc_M0 b_grid_desc_m0_ =
+                MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256);
+
+            if(!GridwiseUEltwise::CheckValidity(a_grid_desc_m0_, b_grid_desc_m0_))
+            {
+                throw std::runtime_error("wrong! GridwiseUnaryElementwise_1D has invalid setting");
+            }
+
+            // run kernel for type conversion
+            void* p_c_grid_tmp_            = static_cast<void*>(arg.p_c_grid_);
+            InDataType* p_c_grid_tmp_bf16_ = static_cast<InDataType*>(p_c_grid_tmp_);
+
+            const auto Run_type_convert = [&](const auto& kernel) {
+                float elapsed_time =
+                    launch_and_time_kernel(stream_config,
+                                           kernel,
+                                           dim3(type_convert_grid_size),
+                                           dim3(256),
+                                           0,
+                                           static_cast<AccDataType*>(arg.p_workspace_),
+                                           p_c_grid_tmp_bf16_,
+                                           a_grid_desc_m0_,
+                                           b_grid_desc_m0_,
+                                           TypeConvertFunctor{});
+                return elapsed_time;
+            };
+
             if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
             {
                 if(has_main_k0_block_loop)
                 {
-                    const auto kernel = kernel_gemm_xdlops_bwd_weight<
-                        GridwiseGemm,
-                        ADataType, // TODO: distiguish A/B datatype
-                        CDataType,
-                        remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
-                        remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                        remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
-                        OutElementwiseOperation,
-                        InElementwiseOperation,
-                        WeiElementwiseOperation,
-                        remove_reference_t<DeviceOp::Block2CTileMap>,
-                        true>;
-
-                    Run(kernel);
+                    if(kbatch == 1)
+                    {
+                        const auto kernel = kernel_gemm_xdlops_bwd_weight<
+                            GridwiseGemm,
+                            ADataType, // TODO: distiguish A/B datatype
+                            CDataType,
+                            remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            OutElementwiseOperation,
+                            InElementwiseOperation,
+                            WeiElementwiseOperation,
+                            remove_reference_t<DeviceOp::Block2CTileMap>,
+                            true>;
+
+                        Run(kernel);
+                    }
+                    else
+                    {
+                        const auto kernel_type_convert =
+                            kernel_unary_elementwise_1d<GridwiseUEltwise,
+                                                        AccDataType,
+                                                        InDataType,
+                                                        GridDesc_M0,
+                                                        TypeConvertFunctor>;
+
+                        const auto kernel_conv = kernel_gemm_xdlops_bwd_weight<
+                            GridwiseGemmAtomicAddFloatBf16Splitk,
+                            ADataType, // TODO: distiguish A/B datatype
+                            AccDataType,
+                            remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            OutElementwiseOperation,
+                            InElementwiseOperation,
+                            WeiElementwiseOperation,
+                            remove_reference_t<DeviceOp::Block2CTileMap>,
+                            true>;
+
+                        run_bf16_splitk(kernel_conv);
+                        ave_time += Run_type_convert(kernel_type_convert);
+                    }
                 }
                 else
                 {
-                    const auto kernel = kernel_gemm_xdlops_bwd_weight<
-                        GridwiseGemm,
-                        ADataType, // TODO: distiguish A/B datatype
-                        CDataType,
-                        remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
-                        remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                        remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
-                        OutElementwiseOperation,
-                        InElementwiseOperation,
-                        WeiElementwiseOperation,
-                        remove_reference_t<DeviceOp::Block2CTileMap>,
-                        false>;
-
-                    Run(kernel);
+                    if(kbatch == 1)
+                    {
+                        const auto kernel = kernel_gemm_xdlops_bwd_weight<
+                            GridwiseGemm,
+                            ADataType, // TODO: distiguish A/B datatype
+                            CDataType,
+                            remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            OutElementwiseOperation,
+                            InElementwiseOperation,
+                            WeiElementwiseOperation,
+                            remove_reference_t<DeviceOp::Block2CTileMap>,
+                            false>;
+
+                        Run(kernel);
+                    }
+                    else
+                    {
+                        const auto kernel = kernel_gemm_xdlops_bwd_weight<
+                            GridwiseGemmAtomicAddFloatBf16Splitk,
+                            ADataType, // TODO: distiguish A/B datatype
+                            AccDataType,
+                            remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            OutElementwiseOperation,
+                            InElementwiseOperation,
+                            WeiElementwiseOperation,
+                            remove_reference_t<DeviceOp::Block2CTileMap>,
+                            false>;
+
+                        run_bf16_splitk(kernel);
+                    }
                 }
             }
             else
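A rough CPU-side sketch of the split-k bf16 flow added above, under the assumption that partial weight-gradient results are accumulated into a float workspace and converted to bf16 once at the end. The bf16 conversion helper below is a hypothetical truncating stand-in, not ck's type_convert, and the numbers are made up:

    #include <cstdint>
    #include <cstring>
    #include <vector>

    using bf16_bits = std::uint16_t; // storage stand-in for ck::bhalf_t

    static bf16_bits float_to_bf16(float f)
    {
        std::uint32_t u = 0;
        std::memcpy(&u, &f, sizeof(u));
        return static_cast<bf16_bits>(u >> 16); // truncating conversion, for illustration only
    }

    int main()
    {
        const int tensor_size = 8;
        const int kbatch      = 4;

        std::vector<float> workspace(tensor_size, 0.0f); // AccDataType workspace (zeroed like the hipMemset above)
        std::vector<bf16_bits> weight_grad(tensor_size); // final bf16 weight gradient

        // stand-in for the AtomicAdd accumulation across k-batches (serial on the CPU)
        for(int kb = 0; kb < kbatch; ++kb)
            for(int i = 0; i < tensor_size; ++i)
                workspace[i] += 0.25f; // partial result contributed by one k-batch

        // stand-in for the GridwiseUnaryElementwise_1D + UnaryTypeConvert pass
        for(int i = 0; i < tensor_size; ++i)
            weight_grad[i] = float_to_bf16(workspace[i]);

        return 0;
    }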
include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp (new file, 0 → 100644)

    (diff collapsed in this view; +813 lines)
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp

@@ -6,19 +6,18 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-template <typename DPtrsGlobal,
-          typename AElementwiseOperation,
+template <typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename DxsInElementwiseOperation,
-          typename DxsAccElementwiseOperation>
+          typename DxsReduceAccElementwiseOperation>
 struct DeviceGemmReduce : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                               const void* p_b,
                                                               void* p_c,
-                                                              DPtrsGlobal p_dxs,
+                                                              void* p_dxs,
                                                               ck::index_t M,
                                                               ck::index_t N,
                                                               ck::index_t K,
@@ -29,24 +28,69 @@ struct DeviceGemmReduce : public BaseOperator
                                                               BElementwiseOperation b_element_op,
                                                               CElementwiseOperation c_element_op,
                                                               DxsInElementwiseOperation dxs_in_element_op,
-                                                              DxsAccElementwiseOperation dxs_out_element_op,
+                                                              DxsReduceAccElementwiseOperation dxs_out_element_op,
                                                               ck::index_t BatchCount = 1) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

-template <typename DPtrsGlobal,
-          typename AElementwiseOperation,
+template <typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename DxsInElementwiseOperation,
-          typename DxsAccElementwiseOperation>
-using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
-                                                             AElementwiseOperation,
+          typename DxsReduceAccElementwiseOperation>
+using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
                                                              BElementwiseOperation,
                                                              CElementwiseOperation,
                                                              DxsInElementwiseOperation,
-                                                             DxsAccElementwiseOperation>>;
+                                                             DxsReduceAccElementwiseOperation>>;
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          typename C1ElementwiseOperation,
+          typename DxsInElementwiseOperation,
+          typename DxsReduceAccElementwiseOperation>
+struct DeviceGemmBiasAddReduce : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
+                                                              const void* p_b,
+                                                              void* p_c,
+                                                              const void* p_c0,
+                                                              const void* p_c1,
+                                                              void* p_dxs,
+                                                              ck::index_t M,
+                                                              ck::index_t N,
+                                                              ck::index_t K,
+                                                              ck::index_t StrideA,
+                                                              ck::index_t StrideB,
+                                                              ck::index_t StrideC,
+                                                              ck::index_t StrideC1,
+                                                              AElementwiseOperation a_element_op,
+                                                              BElementwiseOperation b_element_op,
+                                                              CElementwiseOperation c_element_op,
+                                                              C1ElementwiseOperation c1_element_op,
+                                                              DxsInElementwiseOperation dxs_in_element_op,
+                                                              DxsReduceAccElementwiseOperation dxs_out_element_op,
+                                                              ck::index_t BatchCount = 1) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          typename C1ElementwiseOperation,
+          typename DxsInElementwiseOperation,
+          typename DxsReduceAccElementwiseOperation>
+using DeviceGemmBiasAddReducePtr =
+    std::unique_ptr<DeviceGemmBiasAddReduce<AElementwiseOperation,
+                                            BElementwiseOperation,
+                                            CElementwiseOperation,
+                                            C1ElementwiseOperation,
+                                            DxsInElementwiseOperation,
+                                            DxsReduceAccElementwiseOperation>>;

 } // namespace device
 } // namespace tensor_operation
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp

@@ -32,7 +32,7 @@ template <typename ALayout,
           typename CElementwiseOperation,
           typename DxsReduceOperation,
           typename DxsInElementwiseOperation,
-          typename DxsAccElementwiseOperation,
+          typename DxsReduceAccElementwiseOperation,
           typename DGlobalMemoryDataOperation,
           GemmSpecialization GemmSpec,
           index_t NumGemmKPrefetchStage,
@@ -68,12 +68,11 @@ template <typename ALayout,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
           index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
-                                                               AElementwiseOperation,
+struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
                                                                BElementwiseOperation,
                                                                CElementwiseOperation,
                                                                DxsInElementwiseOperation,
-                                                               DxsAccElementwiseOperation>
+                                                               DxsReduceAccElementwiseOperation>
 {
     using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
@@ -389,7 +388,7 @@ struct DeviceGemmReduce_Xdl_CShuffle
             CElementwiseOperation,
             DxsReduceOperation,
             DxsInElementwiseOperation,
-            DxsAccElementwiseOperation,
-            InMemoryDataOperationEnum::Set,
+            DxsReduceAccElementwiseOperation,
+            DGlobalMemoryDataOperation,
             AGridDesc_AK0_M_AK1,
@@ -449,7 +448,7 @@ struct DeviceGemmReduce_Xdl_CShuffle
                  BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op,
                  DxsInElementwiseOperation dxs_in_element_op,
-                 DxsAccElementwiseOperation dxs_out_element_op)
+                 DxsReduceAccElementwiseOperation dxs_out_element_op)
             : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
@@ -498,7 +497,7 @@ struct DeviceGemmReduce_Xdl_CShuffle
         BElementwiseOperation b_element_op_;
         CElementwiseOperation c_element_op_;
         DxsInElementwiseOperation dxs_in_element_op_;
-        DxsAccElementwiseOperation dxs_out_element_op_;
+        DxsReduceAccElementwiseOperation dxs_out_element_op_;
     };

     // Invoker
@@ -554,7 +553,7 @@ struct DeviceGemmReduce_Xdl_CShuffle
                     BElementwiseOperation,
                     CElementwiseOperation,
                     DxsInElementwiseOperation,
-                    DxsAccElementwiseOperation,
+                    DxsReduceAccElementwiseOperation,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -594,7 +593,7 @@ struct DeviceGemmReduce_Xdl_CShuffle
                     BElementwiseOperation,
                     CElementwiseOperation,
                     DxsInElementwiseOperation,
-                    DxsAccElementwiseOperation,
+                    DxsReduceAccElementwiseOperation,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -669,7 +668,7 @@ struct DeviceGemmReduce_Xdl_CShuffle
                              BElementwiseOperation b_element_op,
                              CElementwiseOperation c_element_op,
                              DxsInElementwiseOperation dxs_in_element_op,
-                             DxsAccElementwiseOperation dxs_out_element_op)
+                             DxsReduceAccElementwiseOperation dxs_out_element_op)
     {
         return Argument{p_a,
                         p_b,
@@ -691,27 +690,29 @@ struct DeviceGemmReduce_Xdl_CShuffle
     static auto MakeInvoker() { return Invoker{}; }

     // polymorphic
     std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                       const void* p_b,
                                                       void* p_c,
-                                                      DPtrsGlobal p_dxs,
+                                                      void* p_dxs,
                                                       index_t MRaw,
                                                       index_t NRaw,
                                                       index_t KRaw,
                                                       index_t StrideA,
                                                       index_t StrideB,
                                                       index_t StrideC,
                                                       AElementwiseOperation a_element_op,
                                                       BElementwiseOperation b_element_op,
                                                       CElementwiseOperation c_element_op,
                                                       DxsInElementwiseOperation dxs_in_element_op,
-                                                      DxsAccElementwiseOperation dxs_out_element_op,
+                                                      DxsReduceAccElementwiseOperation dxs_out_element_op,
                                                       index_t /* KBatch */ = 1) override
     {
+        DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
+
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
                                           static_cast<CDataType*>(p_c),
-                                          p_dxs,
+                                          dxs_tuple,
                                           MRaw,
                                           NRaw,
                                           KRaw,
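The signature change above moves the tuple of reduction-output pointers behind a void* so the virtual interface no longer depends on DPtrsGlobal; the concrete device op casts the pointer back to its own tuple type. A small stand-alone sketch of that pattern, where DPtrsGlobal is a hypothetical std::tuple rather than the library's Tuple type:

    #include <cassert>
    #include <tuple>

    using DPtrsGlobal = std::tuple<float*, float*>; // e.g. pointers to two reduction outputs

    static void make_argument(void* p_dxs)
    {
        // the concrete operation knows the real tuple type and recovers it from the void*
        DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
        assert(std::get<0>(dxs_tuple) != nullptr && std::get<1>(dxs_tuple) != nullptr);
    }

    int main()
    {
        float d0 = 0.0f, d1 = 0.0f;
        DPtrsGlobal dxs{&d0, &d1};

        make_argument(&dxs); // caller passes the typed tuple through the type-erased interface
        return 0;
    }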
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp

@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdl
         {
             grid_size_ = 0;

-            gemm_descs_args_workspace_ = nullptr;
+            p_workspace_ = nullptr;

             group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
@@ -437,8 +437,6 @@ struct DeviceGroupedGemmXdl
         std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;

-        void* gemm_descs_args_workspace_;
-
         index_t grid_size_;
     };
@@ -488,7 +486,7 @@ struct DeviceGroupedGemmXdl
             }

-            hipGetErrorString(hipMemcpy(arg.gemm_descs_args_workspace_,
+            hipGetErrorString(hipMemcpy(arg.p_workspace_,
                                         arg.gemm_desc_kernel_arg_.data(),
                                         arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
                                         hipMemcpyHostToDevice));
@@ -507,17 +505,17 @@ struct DeviceGroupedGemmXdl
                     CElementwiseOperation,
                     true>;

                 ave_time = launch_and_time_kernel(
                     stream_config,
                     kernel,
                     dim3(arg.grid_size_),
                     dim3(BlockSize),
                     0,
-                    cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
+                    cast_pointer_to_constant_address_space(arg.p_workspace_),
                     arg.gemm_desc_kernel_arg_.size(),
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_);
             }
             else
             {
@@ -531,17 +529,17 @@ struct DeviceGroupedGemmXdl
                     CElementwiseOperation,
                     false>;

                 ave_time = launch_and_time_kernel(
                     stream_config,
                     kernel,
                     dim3(arg.grid_size_),
                     dim3(BlockSize),
                     0,
-                    cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
+                    cast_pointer_to_constant_address_space(arg.p_workspace_),
                     arg.gemm_desc_kernel_arg_.size(),
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_);
             }

             return ave_time;
@@ -635,11 +633,6 @@ struct DeviceGroupedGemmXdl
         {
             return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmDescKernelArg);
         }
-
-        void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
-        {
-            dynamic_cast<Argument*>(p_arg)->gemm_descs_args_workspace_ = workspace_ptr;
-        }
     };

 } // namespace device
include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp

@@ -35,14 +35,13 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
     using IndexDataType = int32_t;

-    using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
+    using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;

     using InElementwiseOperation =
-        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
+        typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;

     using AccElementwiseOperation =
-        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
+        typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;

     static constexpr index_t InSrcOutDstVectorDim = 0;
     // for NHWC, the dim C is the vector Dim for both input and output in memory, which is
@@ -178,13 +177,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
             invariant_lowest_length_ = C;
             reduce_lowest_length_    = window_spatial_lengths[1];

-            // TODO: is this correct?
-            if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG)
-            {
-                ck::index_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
-
-                in_element_op_  = InElementwiseOperation{divider};
-                acc_element_op_ = AccElementwiseOperation{divider};
-            }
+            int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1];
+
+            std::tie(in_element_op_, acc_element_op_) =
+                reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
         }

         const InDataType* p_in_dev_;
include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp

@@ -61,12 +61,9 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
     static constexpr bool use_multiblock =
         (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);

-    static constexpr bool out_type_compatible_with_atomic_op =
-        std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;
-
-    static_assert(!use_multiblock || (use_multiblock && out_type_compatible_with_atomic_op),
-                  "The OutDataType must support the atomic operation for using MultiBlock reduction");
+    static_assert(
+        ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation, OutDataType>::value,
+        "The OutDataType must support the specified OutMemoryDataOperation!");

     static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
                   "MultiBlock reduction can only be used when outputing index is not required");
@@ -349,7 +346,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
             if constexpr(use_multiblock)
             {
-                const auto identityVal = ck::reduce::GetIdentityValueueForInMemoryDataOperation<OutDataType>(
+                const auto identityVal = ck::reduce::GetIdentityValueForInMemoryDataOperation<OutDataType>(
                     OutMemoryDataOperation);

                 const auto kernel_pre =
@@ -492,7 +489,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
         auto str = std::stringstream();

         // clang-format off
-        str << "DeviceReduceMultiBlockAtomicAdd<" << BlockSize << ",";
+        str << (OutMemoryDataOperation == InMemoryDataOperationEnum::Set ? "DeviceReduceBlockWise<" : "DeviceReduceMultiBlock<") << BlockSize << ",";
         str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
         str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
         str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
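The new static_assert above delegates the output-type check to a trait instead of an ad-hoc boolean. A hypothetical mirror of that idea follows; note the library trait is spelled InMemoryDataOperatonSupportedOnDataType, while the sketch uses a corrected name and its own whitelist, so it is an illustration of the pattern, not the library's trait:

    #include <type_traits>

    enum class InMemoryDataOperationEnum { Set, AtomicAdd };

    // whitelist of output types an in-memory AtomicAdd may target (Set accepts anything here)
    template <InMemoryDataOperationEnum Op, typename T>
    struct InMemoryDataOperationSupportedOnDataType
        : std::integral_constant<bool,
                                 Op == InMemoryDataOperationEnum::Set ||
                                     std::is_same<T, float>::value || std::is_same<T, double>::value>
    {
    };

    template <InMemoryDataOperationEnum Op, typename OutDataType>
    struct DeviceReduceSketch
    {
        static_assert(InMemoryDataOperationSupportedOnDataType<Op, OutDataType>::value,
                      "The OutDataType must support the specified OutMemoryDataOperation!");
    };

    int main()
    {
        DeviceReduceSketch<InMemoryDataOperationEnum::AtomicAdd, float> ok{}; // compiles
        (void)ok;
        // DeviceReduceSketch<InMemoryDataOperationEnum::AtomicAdd, int> bad{}; // would fail the static_assert
        return 0;
    }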
include/ck/tensor_operation/gpu/device/device_unary_elementwise.hpp (new file, 0 → 100644)

#pragma once
#include <iostream>
#include <vector>

#include "device.hpp"
#include "device_base.hpp"
#include "gridwise_unary_elementwise_1d.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ADataType,
          typename BDataType,
          typename ElementwiseFunctor,
          index_t Dim,
          index_t ScalarPerVector>
struct DeviceUnaryElementwise : public BaseOperator
{
    static constexpr auto I0 = Number<0>{};

    template <typename Desc_M0>
    static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
    {
        const auto m0           = desc_m0.GetLength(I0);
        const index_t loop_step = gridSize * blockSize * ScalarPerVector;
        const auto pad          = math::integer_least_multiple(m0, loop_step) - m0;
        const auto desc_m0_pad =
            transform_tensor_descriptor(desc_m0,
                                        make_tuple(make_right_pad_transform(m0, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return desc_m0_pad;
    }

    static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
                                  const std::vector<index_t>& stride,
                                  index_t gridSize,
                                  index_t blockSize)
    {
        auto tupleOfShape  = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
        auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});

        // nd desc - [s0, s1, s2, ...]
        const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);

        // merge nd to 1d desc - [s0 * s1 * ...]
        if constexpr(Dim > 1)
        {
            const auto desc_m0 = transform_tensor_descriptor(
                desc,
                make_tuple(make_merge_transform(tupleOfShape)),
                make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
                make_tuple(Sequence<0>{}));

            return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
        }
        else
            return PadDescriptor_M0_1d(desc, gridSize, blockSize);
    }

    using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
    using GridwiseUEltwise =
        GridwiseUnaryElementwise_1D<ADataType, BDataType, GridDesc_M0, ElementwiseFunctor, ScalarPerVector>;

    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a,
                 BDataType* p_b,
                 const std::vector<index_t>& shape,
                 const std::vector<index_t>& stride_a,
                 const std::vector<index_t>& stride_b,
                 ElementwiseFunctor functor)
            : p_a_(p_a),
              p_b_(p_b),
              shape_(shape),
              functor_(functor),
              blockSize_(256) // FIXME - Calculate the grid size by number of CU in the future
        {
            index_t tensor_size =
                std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>{});
            gridSize_       = GridwiseUEltwise::CalculateGridSize(tensor_size);
            a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_);
            b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_);
        }

        const ADataType* p_a_;
        BDataType* p_b_;
        std::vector<int> shape_;
        GridDesc_M0 a_grid_desc_m0_;
        GridDesc_M0 b_grid_desc_m0_;
        ElementwiseFunctor functor_;
        index_t blockSize_;
        index_t gridSize_;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto kernel = kernel_unary_elementwise_1d<GridwiseUEltwise,
                                                            ADataType,
                                                            BDataType,
                                                            GridDesc_M0,
                                                            ElementwiseFunctor>;

            float elapsed_time = launch_and_time_kernel(stream_config,
                                                        kernel,
                                                        dim3(arg.gridSize_),
                                                        dim3(arg.blockSize_),
                                                        0,
                                                        arg.p_a_,
                                                        arg.p_b_,
                                                        arg.a_grid_desc_m0_,
                                                        arg.b_grid_desc_m0_,
                                                        arg.functor_);
            return elapsed_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if(pArg == nullptr)
            return false;

        if(pArg->shape_.back() % ScalarPerVector != 0)
            return false;

        return true;
    };

    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      void* p_b,
                                                      std::vector<index_t> shape,
                                                      std::vector<index_t> stride_a,
                                                      std::vector<index_t> stride_b,
                                                      ElementwiseFunctor functor)
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<BDataType*>(p_b),
                                          shape,
                                          stride_a,
                                          stride_b,
                                          functor);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceBinaryElementwise"
            << "<"
            << "ScalarPerVector = " << ScalarPerVector
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
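A small stand-alone illustration of the padding rule used by PadDescriptor_M0_1d in the new device op: the flattened length is right-padded up to a multiple of gridSize * blockSize * ScalarPerVector, so every grid-stride step moves whole vectors. The concrete numbers below are made up for the example:

    #include <cassert>

    static int integer_least_multiple(int x, int m) { return ((x + m - 1) / m) * m; }

    int main()
    {
        const int gridSize = 720, blockSize = 256, ScalarPerVector = 8;
        const int loop_step = gridSize * blockSize * ScalarPerVector;

        const int m0  = 1000000;                             // flattened tensor length
        const int pad = integer_least_multiple(m0, loop_step) - m0;

        assert((m0 + pad) % loop_step == 0); // padded length is an exact number of loop steps
        return 0;
    }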
include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp

@@ -29,6 +29,7 @@
 #include "reduction_operator.hpp"
 #include "reduction_enums.hpp"
 #include "element_wise_operation.hpp"
+#include <tuple>

 namespace ck {
@@ -37,77 +38,69 @@ namespace ck {
 // The boolean member "indexable" are also provided in reduce_binary_operactor for
 // easier checking by the upper-layer codes in the kernels.

-template <typename T, ReduceTensorOp Op>
+template <ReduceTensorOp Op>
 struct reduce_binary_operator;

-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp::ADD>
+template <>
+struct reduce_binary_operator<ReduceTensorOp::ADD>
 {
-    using opType   = reduce::Add<T>;
-    using dataType = T;
+    using opType = reduce::Add;
     static constexpr bool indexable = false;
 };

-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp::MUL>
+template <>
+struct reduce_binary_operator<ReduceTensorOp::MUL>
 {
-    using opType   = reduce::Mul<T>;
-    using dataType = T;
+    using opType = reduce::Mul;
    static constexpr bool indexable = false;
 };

-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp::MIN>
+template <>
+struct reduce_binary_operator<ReduceTensorOp::MIN>
 {
-    using opType   = reduce::Min<T>;
-    using dataType = T;
+    using opType = reduce::Min;
     static constexpr bool indexable = true;
 };

-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp::MAX>
+template <>
+struct reduce_binary_operator<ReduceTensorOp::MAX>
 {
-    using opType   = reduce::Max<T>;
-    using dataType = T;
+    using opType = reduce::Max;
     static constexpr bool indexable = true;
 };

-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp::AMAX>
+template <>
+struct reduce_binary_operator<ReduceTensorOp::AMAX>
 {
-    using opType   = reduce::AMax<T>;
-    using dataType = T;
+    using opType = reduce::AMax;
     static constexpr bool indexable = true;
 };

-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp::AVG>
+template <>
+struct reduce_binary_operator<ReduceTensorOp::AVG>
 {
-    using opType   = reduce::Add<T>;
-    using dataType = T;
+    using opType = reduce::Add;
     static constexpr bool indexable = false;
 };

-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp::NORM1>
+template <>
+struct reduce_binary_operator<ReduceTensorOp::NORM1>
 {
-    using opType   = reduce::Add<T>;
-    using dataType = T;
+    using opType = reduce::Add;
     static constexpr bool indexable = false;
 };

-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp::NORM2>
+template <>
+struct reduce_binary_operator<ReduceTensorOp::NORM2>
 {
-    using opType   = reduce::Add<T>;
-    using dataType = T;
+    using opType = reduce::Add;
     static constexpr bool indexable = false;
 };
@@ -115,53 +108,101 @@ struct reduce_binary_operator<T, ReduceTensorOp::NORM2>
 // The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary
 // functor classes.
 // The two unary functors are called before and afer the Reduction is executed respectively

-template <typename T, ReduceTensorOp Op, bool IsFirstReduce, bool IsLastReduce>
+template <ReduceTensorOp Op, bool IsFirstReduce, bool IsLastReduce>
 struct reduce_unary_operator
 {
-    using InElementwiseOperation  = tensor_operation::element_wise::UnaryIdentic<T, T>;
-    using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
+    using InElementwiseOperation  = tensor_operation::element_wise::PassThrough;
+    using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
+
+    static std::tuple<InElementwiseOperation, AccElementwiseOperation>
+    GetElementwiseOperator(int32_t reduceLength)
+    {
+        (void)reduceLength;
+        return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
+    };
 };

-template <typename T, bool IsFirstReduce>
-struct reduce_unary_operator<T, ReduceTensorOp::AVG, IsFirstReduce, true>
+template <bool IsFirstReduce>
+struct reduce_unary_operator<ReduceTensorOp::AVG, IsFirstReduce, true>
 {
-    using InElementwiseOperation  = tensor_operation::element_wise::UnaryIdentic<T, T>;
-    using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T, true>;
+    using InElementwiseOperation  = tensor_operation::element_wise::PassThrough;
+    using AccElementwiseOperation = tensor_operation::element_wise::UnaryDivide;
+
+    static std::tuple<InElementwiseOperation, AccElementwiseOperation>
+    GetElementwiseOperator(int32_t reduceLength)
+    {
+        return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{reduceLength});
+    };
 };

-template <typename T, bool IsLastReduce>
-struct reduce_unary_operator<T, ReduceTensorOp::NORM1, true, IsLastReduce>
+template <bool IsLastReduce>
+struct reduce_unary_operator<ReduceTensorOp::NORM1, true, IsLastReduce>
 {
-    using InElementwiseOperation  = tensor_operation::element_wise::UnaryAbs<T, T>;
-    using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
+    using InElementwiseOperation  = tensor_operation::element_wise::UnaryAbs;
+    using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
+
+    static std::tuple<InElementwiseOperation, AccElementwiseOperation>
+    GetElementwiseOperator(int32_t reduceLength)
+    {
+        (void)reduceLength;
+        return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
+    };
 };

-template <typename T, bool IsLastReduce>
-struct reduce_unary_operator<T, ReduceTensorOp::AMAX, true, IsLastReduce>
+template <bool IsLastReduce>
+struct reduce_unary_operator<ReduceTensorOp::AMAX, true, IsLastReduce>
 {
-    using InElementwiseOperation  = tensor_operation::element_wise::UnaryAbs<T, T>;
-    using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
+    using InElementwiseOperation  = tensor_operation::element_wise::UnaryAbs;
+    using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
+
+    static std::tuple<InElementwiseOperation, AccElementwiseOperation>
+    GetElementwiseOperator(int32_t reduceLength)
+    {
+        (void)reduceLength;
+        return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
+    };
 };

-template <typename T>
-struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, false>
+template <>
+struct reduce_unary_operator<ReduceTensorOp::NORM2, true, false>
 {
-    using InElementwiseOperation  = tensor_operation::element_wise::UnarySquare<T, T>;
-    using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
+    using InElementwiseOperation  = tensor_operation::element_wise::UnarySquare;
+    using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
+
+    static std::tuple<InElementwiseOperation, AccElementwiseOperation>
+    GetElementwiseOperator(int32_t reduceLength)
+    {
+        (void)reduceLength;
+        return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
+    };
 };

-template <typename T>
-struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, true>
+template <>
+struct reduce_unary_operator<ReduceTensorOp::NORM2, true, true>
 {
-    using InElementwiseOperation  = tensor_operation::element_wise::UnarySquare<T, T>;
-    using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
+    using InElementwiseOperation  = tensor_operation::element_wise::UnarySquare;
+    using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt;
+
+    static std::tuple<InElementwiseOperation, AccElementwiseOperation>
+    GetElementwiseOperator(int32_t reduceLength)
+    {
+        (void)reduceLength;
+        return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
+    };
 };

-template <typename T>
-struct reduce_unary_operator<T, ReduceTensorOp::NORM2, false, true>
+template <>
+struct reduce_unary_operator<ReduceTensorOp::NORM2, false, true>
 {
-    using InElementwiseOperation  = tensor_operation::element_wise::UnaryIdentic<T, T>;
-    using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
+    using InElementwiseOperation  = tensor_operation::element_wise::PassThrough;
+    using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt;
+
+    static std::tuple<InElementwiseOperation, AccElementwiseOperation>
+    GetElementwiseOperator(int32_t reduceLength)
+    {
+        (void)reduceLength;
+        return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
+    };
 };

 } // end of namespace ck
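The reworked reduce_unary_operator now hands back both functors through GetElementwiseOperator(reduceLength), which is what the pool2d device op above consumes via std::tie. A hypothetical host-only mirror of the AVG case; PassThrough and UnaryDivide here are local stand-ins, not the ck headers:

    #include <cassert>
    #include <tuple>

    struct PassThrough { void operator()(float& y, const float& x) const { y = x; } };

    struct UnaryDivide
    {
        UnaryDivide(int divider) : divider_(divider) {}
        void operator()(float& y, const float& x) const { y = x / static_cast<float>(divider_); }
        int divider_;
    };

    static std::tuple<PassThrough, UnaryDivide> get_avg_ops(int reduceLength)
    {
        // pre-op leaves inputs alone; post-op divides the accumulated sum by the window size
        return std::make_tuple(PassThrough{}, UnaryDivide{reduceLength});
    }

    int main()
    {
        auto ops = get_avg_ops(4);

        float acc = 0.0f;
        for(float x : {1.0f, 2.0f, 3.0f, 6.0f})
        {
            float t = 0.0f;
            std::get<0>(ops)(t, x); // InElementwiseOperation
            acc += t;
        }

        float avg = 0.0f;
        std::get<1>(ops)(avg, acc); // AccElementwiseOperation
        assert(avg == 3.0f);        // (1 + 2 + 3 + 6) / 4

        return 0;
    }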
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp

@@ -28,100 +28,189 @@
 namespace ck {
 namespace tensor_operation {
-namespace binary_element_wise {
-
-template <typename Y, typename X1, typename X2>
-struct Add;
-
-template <>
-struct Add<double, double, double>
-{
-    __host__ __device__ constexpr void operator()(double& dst, const double& src1, const double& src2) const
-    {
-        dst = src1 + src2;
-    }
-};
-
-template <>
-struct Add<float, float, float>
-{
-    __host__ __device__ constexpr void operator()(float& dst, const float& src1, const float& src2) const
-    {
-        dst = src1 + src2;
-    }
-};
-
-template <>
-struct Add<half_t, half_t, half_t>
-{
-    __host__ __device__ constexpr void operator()(half_t& dst, const half_t& src1, const half_t& src2) const
-    {
-        dst = src1 + src2;
-    }
-};
-
-template <>
-struct Add<bhalf_t, bhalf_t, bhalf_t>
-{
-    __host__ __device__ constexpr void operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
-    {
-        const float x1 = ck::type_convert<float>(src1);
-        const float x2 = ck::type_convert<float>(src2);
-        const float y  = x1 + x2;
-        dst            = ck::type_convert<bhalf_t>(y);
-    }
-};
-
-template <typename Y, typename X1, typename X2>
-struct Substract;
-
-template <>
-struct Substract<double, double, double>
-{
-    __host__ __device__ constexpr void operator()(double& dst, const double& src1, const double& src2) const
-    {
-        dst = src1 - src2;
-    }
-};
-
-template <>
-struct Substract<float, float, float>
-{
-    __host__ __device__ constexpr void operator()(float& dst, const float& src1, const float& src2) const
-    {
-        dst = src1 - src2;
-    }
-};
-
-template <>
-struct Substract<half_t, half_t, half_t>
-{
-    __host__ __device__ constexpr void operator()(half_t& dst, const half_t& src1, const half_t& src2) const
-    {
-        dst = src1 - src2;
-    }
-};
-
-template <>
-struct Substract<bhalf_t, bhalf_t, bhalf_t>
-{
-    __host__ __device__ constexpr void operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
-    {
-        const float x1 = ck::type_convert<float>(src1);
-        const float x2 = ck::type_convert<float>(src2);
-        const float y  = x1 - x2;
-        dst            = ck::type_convert<bhalf_t>(y);
-    }
-};
-
-} // namespace binary_element_wise
+namespace element_wise {
+
+struct Add
+{
+    template <typename T>
+    __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<float>(float& y, const float& x0, const float& x1) const
+    {
+        y = x0 + x1;
+    };
+
+    template <>
+    __host__ __device__ constexpr void operator()<double>(double& y, const double& x0, const double& x1) const
+    {
+        y = x0 + x1;
+    };
+
+    // Question: should half_t be supported ?
+    template <>
+    __host__ __device__ constexpr void operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
+    {
+        y = x0 + x1;
+    };
+
+    // Question: should bhalf_t be supported ?
+    template <>
+    __host__ __device__ constexpr void operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
+    {
+        const float x1_tmp = ck::type_convert<float>(x0);
+        const float x2_tmp = ck::type_convert<float>(x1);
+        const float y_tmp  = x1_tmp + x2_tmp;
+        y                  = ck::type_convert<bhalf_t>(y_tmp);
+    }
+};
+
+struct Subtract
+{
+    template <typename T>
+    __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<float>(float& y, const float& x0, const float& x1) const
+    {
+        y = x0 - x1;
+    };
+
+    template <>
+    __host__ __device__ constexpr void operator()<double>(double& y, const double& x0, const double& x1) const
+    {
+        y = x0 - x1;
+    };
+
+    // Question: should half_t be supported ?
+    template <>
+    __host__ __device__ constexpr void operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
+    {
+        y = x0 - x1;
+    };
+
+    // Question: should bhalf_t be supported ?
+    template <>
+    __host__ __device__ constexpr void operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
+    {
+        const float x1_tmp = ck::type_convert<float>(x0);
+        const float x2_tmp = ck::type_convert<float>(x1);
+        const float y_tmp  = x1_tmp - x2_tmp;
+        y                  = ck::type_convert<bhalf_t>(y_tmp);
+    }
+};
+
+struct AlphaBetaAdd
+{
+    AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
+
+    template <typename T>
+    __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<float>(float& y, const float& x0, const float& x1) const
+    {
+        y = alpha_ * x0 + beta_ * x1;
+    };
+
+    template <>
+    __host__ __device__ constexpr void operator()<double>(double& y, const double& x0, const double& x1) const
+    {
+        y = static_cast<double>(alpha_) * x0 + static_cast<double>(beta_) * x1;
+    };
+
+    // Question: should half_t be supported ?
+    template <>
+    __host__ __device__ constexpr void operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
+    {
+        y = static_cast<half_t>(alpha_ * static_cast<float>(x0) + beta_ * static_cast<float>(x1));
+    };
+
+    float alpha_;
+    float beta_;
+};
+
+struct AddRelu
+{
+    template <typename T>
+    __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<float>(float& y, const float& x0, const float& x1) const
+    {
+        const float a = x0 + x1;
+        y             = a > 0.0f ? a : 0.0f;
+    };
+
+    template <>
+    __host__ __device__ constexpr void operator()<double>(double& y, const double& x0, const double& x1) const
+    {
+        const double a = x0 + x1;
+        y              = a > 0.0 ? a : 0.0;
+    };
+
+    // Question: should half_t be supported ?
+    template <>
+    __host__ __device__ constexpr void operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
+    {
+        const half_t a = x0 + x1;
+        y              = a > static_cast<half_t>(0.0f) ? a : static_cast<half_t>(0.0f);
+    };
+};
+
+struct AddHardswish
+{
+    template <typename T>
+    __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<float>(float& y, const float& x0, const float& x1) const
+    {
+        float a = x0 + x1;
+        float b = a + float{3};
+        float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
+        y       = c;
+    };
+
+    template <>
+    __host__ __device__ constexpr void operator()<double>(double& y, const double& x0, const double& x1) const
+    {
+        double a = x0 + x1;
+        double b = a + 3.0;
+        double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667;
+        y        = c;
+    };
+
+    // Question: should half_t be supported ?
+    template <>
+    __host__ __device__ constexpr void operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
+    {
+        float a = x0 + x1;
+        float b = a + 3.0f;
+        float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
+        y       = c;
+    };
+};
+
+} // namespace element_wise
 } // namespace tensor_operation
 } // namespace ck
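The header above switches from one struct per type to a single struct whose templated operator() is only declared and then explicitly specialized per supported type, routing bhalf_t arithmetic through float. A stand-alone mirror of that pattern; bhalf_t below is a bare uint16_t stand-in with truncating conversions, not ck::bhalf_t, and the out-of-class specializations are used for strict standard conformance:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    using bhalf_t = std::uint16_t; // stand-in for ck::bhalf_t

    static float bf16_to_float(bhalf_t h)
    {
        std::uint32_t u = static_cast<std::uint32_t>(h) << 16;
        float f = 0.0f;
        std::memcpy(&f, &u, sizeof(f));
        return f;
    }

    static bhalf_t float_to_bf16(float f)
    {
        std::uint32_t u = 0;
        std::memcpy(&u, &f, sizeof(u));
        return static_cast<bhalf_t>(u >> 16); // truncating, for illustration only
    }

    struct Add
    {
        // declared only: an unsupported T fails to link instead of silently compiling
        template <typename T>
        void operator()(T& y, const T& x0, const T& x1) const;
    };

    template <>
    void Add::operator()<float>(float& y, const float& x0, const float& x1) const { y = x0 + x1; }

    template <>
    void Add::operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
    {
        // bhalf_t math goes through float, as in the header above
        y = float_to_bf16(bf16_to_float(x0) + bf16_to_float(x1));
    }

    int main()
    {
        float yf = 0.0f;
        Add{}(yf, 1.5f, 2.5f);
        assert(yf == 4.0f);

        bhalf_t yb = 0;
        Add{}(yb, float_to_bf16(1.0f), float_to_bf16(2.0f));
        assert(bf16_to_float(yb) == 3.0f);
        return 0;
    }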
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
View file @
6805df0e
#pragma once
#include "data_type.hpp"
#include "math_v2.hpp"
#include "unary_element_wise_operation.hpp"
#include "binary_element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
element_wise
{
struct
PassThrough
{
__host__
__device__
void
operator
()(
float
&
y
,
const
float
&
x
)
const
{
y
=
x
;
}
__host__
__device__
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
void
operator
()(
double
&
y
,
const
double
&
x
)
const
{
y
=
x
;
}
};
struct
FastGelu
{
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__
__device__
void
operator
()(
float
&
y
,
const
float
&
x
)
const
{
const
float
u
=
float
(
2
)
*
x
*
(
float
(
0.035677
)
*
x
*
x
+
float
(
0.797885
));
const
float
emu
=
exp
(
-
u
);
const
float
cdf
=
float
(
0.5
)
+
float
(
0.5
)
*
(
float
(
2
)
/
(
float
(
1
)
+
emu
)
-
float
(
1
));
y
=
x
*
cdf
;
}
};
struct
Add
{
__host__
__device__
constexpr
void
operator
()(
float
&
y
,
const
float
&
x0
,
const
float
&
x1
)
const
{
y
=
x0
+
x1
;
}
__host__
__device__
constexpr
void
operator
()(
half_t
&
y
,
const
half_t
&
x0
,
const
half_t
&
x1
)
const
{
// FIXME - Use float (acc type) bias in the future.
y
=
x0
+
x1
;
}
};
struct
AlphaBetaAdd
{
AlphaBetaAdd
(
float
alpha
,
float
beta
)
:
alpha_
(
alpha
),
beta_
(
beta
)
{}
__host__
__device__
constexpr
void
operator
()(
float
&
y
,
const
float
&
x0
,
const
float
&
x1
)
const
{
y
=
alpha_
*
x0
+
beta_
*
x1
;
}
__host__
__device__
constexpr
void
operator
()(
half_t
&
y
,
const
half_t
&
x0
,
const
half_t
&
x1
)
const
{
// FIXME - Let x0 be acc type
y
=
static_cast
<
half_t
>
(
alpha_
*
static_cast
<
float
>
(
x0
)
+
beta_
*
static_cast
<
float
>
(
x1
));
}
float
alpha_
;
float
beta_
;
};
struct
AddRelu
{
__host__
__device__
constexpr
void
operator
()(
float
&
y
,
const
float
&
x0
,
const
float
&
x1
)
const
{
const
float
a
=
x0
+
x1
;
y
=
a
>
0
?
a
:
0
;
}
__host__
__device__
constexpr
void
operator
()(
half_t
&
y
,
const
half_t
&
x0
,
const
half_t
&
x1
)
const
{
const
half_t
a
=
x0
+
x1
;
y
=
a
>
0
?
a
:
0
;
}
};
struct
AddHardswish
{
__host__
__device__
constexpr
void
operator
()(
float
&
y
,
const
float
&
x0
,
const
float
&
x1
)
const
{
float
a
=
x0
+
x1
;
float
b
=
a
+
float
{
3
};
float
c
=
(
b
>
0
)
*
(
b
>
float
{
6
}
?
float
{
6
}
:
b
)
*
a
*
float
{
0.166667
};
y
=
c
;
}
__host__
__device__
constexpr
void
operator
()(
half_t
&
y
,
const
half_t
&
x0
,
const
half_t
&
x1
)
const
{
float
a
=
x0
+
x1
;
float
b
=
a
+
float
{
3
};
float
c
=
(
b
>
0
)
*
(
b
>
float
{
6
}
?
float
{
6
}
:
b
)
*
a
*
float
{
0.166667
};
y
=
c
;
}
};
struct
AddReluAdd
{
__host__
__device__
constexpr
void
...
...
@@ -162,8 +65,14 @@ struct AddHardswishAdd
// E = FastGelu(C + D0 + D1)
struct
AddAddFastGelu
{
__host__
__device__
void
operator
()(
ck
::
half_t
&
e
,
const
float
&
c
,
const
ck
::
half_t
&
d0
,
const
ck
::
half_t
&
d1
)
const
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
>
__host__
__device__
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
,
const
D1
&
d1
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
float
,
half_t
,
half_t
>
(
half_t
&
e
,
const
float
&
c
,
const
half_t
&
d0
,
const
half_t
&
d1
)
const
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
...
...
@@ -177,209 +86,67 @@ struct AddAddFastGelu
        const float y = fast_gelu(c + float(d0) + float(d1));

-        e = ck::type_convert<ck::half_t>(y);
+        e = type_convert<half_t>(y);
    }
};

struct Normalize
{
-    Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {}
+    Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}

-    __host__ __device__ constexpr void operator()(float& y,
-                                                  const float& x,
-                                                  const float& mean,
-                                                  const float& mean_square,
-                                                  const float& gamma,
-                                                  const float& beta) const
-    {
-        float variance = mean_square - (mean * mean);
-        y = ((x - mean) / sqrtf(variance + epsilon_)) * gamma + beta;
-    }
+    template <typename T>
+    __host__ __device__ constexpr void operator()(
+        T& y, const T& x, const T& mean, const T& mean_square, const T& gamma, const T& beta) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<float>(float& y,
+                                                         const float& x,
+                                                         const float& mean,
+                                                         const float& mean_square,
+                                                         const float& gamma,
+                                                         const float& beta) const
+    {
+        using ck::math::sqrt;
+
+        float variance = mean_square - (mean * mean);
+        y = ((x - mean) / sqrt(variance + static_cast<float>(epsilon_))) * gamma + beta;
+    };
+
+    template <>
+    __host__ __device__ constexpr void operator()<double>(double& y,
+                                                          const double& x,
+                                                          const double& mean,
+                                                          const double& mean_square,
+                                                          const double& gamma,
+                                                          const double& beta) const
+    {
+        using ck::math::sqrt;
+
+        double variance = mean_square - (mean * mean);
+        y = ((x - mean) / sqrt(variance + epsilon_)) * gamma + beta;
+    };

-    float epsilon_;
-};
+    double epsilon_;
+};

-// Unary operators are usually called element-wisely before/after the reduction is executed on the
-// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
-template <typename Y, typename X, bool HasDividing = false>
-struct UnaryIdentic;
-
-template <>
-struct UnaryIdentic<float, float, false>
-{
-    __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
-
-    __host__ __device__ void operator()(float& y, const float& x) const { y = x; };
-};
-
-template <>
-struct UnaryIdentic<float, float, true>
-{
-    __host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
-
-    __host__ __device__ void operator()(float& y, const float& x) const
-    {
-        y = x / type_convert<float>(divider_);
-    };
-
-    int32_t divider_ = 1;
-};
-
-template <>
-struct UnaryIdentic<half_t, half_t, false>
-{
-    __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
-
-    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; };
-};
-
-template <>
-struct UnaryIdentic<double, double, false>
-{
-    __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
-
-    __host__ __device__ void operator()(double& y, const double& x) const { y = x; };
-};
-
-template <>
-struct UnaryIdentic<double, double, true>
-{
-    __host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
-
-    __host__ __device__ void operator()(double& y, const double& x) const
-    {
-        y = x / type_convert<double>(divider_);
-    };
-
-    int32_t divider_ = 1;
-};
-
-template <>
-struct UnaryIdentic<int32_t, int32_t, false>
-{
-    __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
-
-    __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; };
-};
-
-template <>
-struct UnaryIdentic<int32_t, int32_t, true>
-{
-    __host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
-
-    __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x / divider_; };
-
-    int32_t divider_ = 1;
-};
-
-template <>
-struct UnaryIdentic<int8_t, int8_t, false>
-{
-    __host__ __device__ UnaryIdentic(const int8_t divider = 1) { (void)divider; };
-
-    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; };
-};
-
-template <typename Y, typename X, bool HasDividing = false>
-struct UnarySquare;
-
-template <>
-struct UnarySquare<float, float, false>
-{
-    __host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; };
-
-    __host__ __device__ void operator()(float& y, const float& x) const { y = x * x; };
-};
-
-template <>
-struct UnarySquare<float, float, true>
-{
-    __host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; };
-
-    __host__ __device__ void operator()(float& y, const float& x) const
-    {
-        y = x * x / type_convert<float>(divider_);
-    };
-
-    int32_t divider_ = 1;
-};
-
-template <>
-struct UnarySquare<double, double, false>
-{
-    __host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; };
-
-    __host__ __device__ void operator()(double& y, const double& x) const { y = x * x; };
-};
-
-template <>
-struct UnarySquare<double, double, true>
-{
-    __host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; };
-
-    __host__ __device__ void operator()(double& y, const double& x) const
-    {
-        y = x * x / type_convert<double>(divider_);
-    };
-
-    int32_t divider_ = 1;
-};
-
-template <typename Y, typename X>
-struct UnaryAbs;
-
-template <>
-struct UnaryAbs<float, float>
-{
-    __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
-
-    __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
-};
-
-template <>
-struct UnaryAbs<half_t, half_t>
-{
-    __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
-
-    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
-};
-
-template <>
-struct UnaryAbs<double, double>
-{
-    __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
-
-    __host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
-};
-
-template <>
-struct UnaryAbs<int8_t, int8_t>
-{
-    __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
-
-    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
-};
-
-template <typename Y, typename X>
-struct UnarySqrt;
+template <typename Y, typename X>
+struct UnaryTypeConvert;

-template <>
-struct UnarySqrt<float, float>
-{
-    __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
-
-    __host__ __device__ void operator()(float& y, const float& x) const
-    {
-        y = ck::math::sqrt(x);
-    };
-};
+template <>
+struct UnaryTypeConvert<float, ck::bhalf_t>
+{
+    __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
+    {
+        y = ck::type_convert<float, ck::bhalf_t>(x);
+    };
+};

-template <>
-struct UnarySqrt<double, double>
-{
-    __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
-
-    __host__ __device__ void operator()(double& y, const double& x) const
-    {
-        y = ck::math::sqrt(x);
-    };
-};
+template <>
+struct UnaryTypeConvert<ck::bhalf_t, float>
+{
+    __host__ __device__ void operator()(ck::bhalf_t& y, float& x) const
+    {
+        y = ck::type_convert<ck::bhalf_t, float>(x);
+    };
+};
...
...
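For reference, the templated Normalize above implements the standard normalization identity: given a mean E[x] and a mean of squares E[x²], the variance is Var[x] = E[x²] − (E[x])², and the output is y = (x − mean)/sqrt(Var + ε)·γ + β. A host-only sketch of the same arithmetic, standalone and not CK code:

#include <cmath>
#include <cstdio>
#include <vector>

// Host-side reference of the arithmetic done by element_wise::Normalize:
// the variance is recovered from the mean and the mean of squares.
float normalize_one(float x, float mean, float mean_square, float gamma, float beta,
                    float epsilon = 1e-4f)
{
    const float variance = mean_square - mean * mean; // Var[x] = E[x^2] - (E[x])^2
    return (x - mean) / std::sqrt(variance + epsilon) * gamma + beta;
}

int main()
{
    const std::vector<float> x{1.0f, 2.0f, 3.0f, 4.0f};
    float mean = 0.0f, mean_square = 0.0f;
    for(float v : x)
    {
        mean += v / x.size();
        mean_square += v * v / x.size();
    }
    for(float v : x)
        std::printf("%f ", normalize_one(v, mean, mean_square, /*gamma=*/1.0f, /*beta=*/0.0f));
    std::printf("\n");
    return 0;
}

Carrying mean and mean-of-squares (rather than mean and variance) is what lets the GEMM+reduce kernels produce both statistics as plain reductions and defer the variance computation to this element-wise epilogue.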
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
0 → 100644
View file @
6805df0e
#pragma once
#include "data_type.hpp"
#include "math_v2.hpp"

namespace ck {
namespace tensor_operation {
namespace element_wise {

struct PassThrough
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, bhalf_t>::value ||
                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");

        y = x;
    };
};

struct UnaryDivide
{
    __host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider){};

    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = x / type_convert<T>(divider_);
    };

    int32_t divider_ = 1;
};

struct UnarySquare
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value,
                      "Data type is not supported by this operation!");

        y = x * x;
    };
};

struct UnaryAbs
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::abs(x);
    };
};

struct UnarySqrt
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::sqrt(x);
    };
};

struct Relu
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");
        y = x > 0 ? x : 0;
    }

    template <>
    __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
    {
        float x_f32 = ck::type_convert<float>(x);
        float y_f32 = x_f32 > 0 ? x_f32 : 0;
        y           = ck::type_convert<bhalf_t>(y_f32);
    }
};

// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
struct FastGelu
{
    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const;

    template <>
    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
    {
        const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
        const float emu = exp(-u);
        const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));

        y = x * cdf;
    }
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
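The FastGelu body above is the sigmoid rewrite of the tanh approximation quoted in its comment: with u = 2·sqrt(2/π)·(x + 0.044715·x³) ≈ 2x·(0.035677·x² + 0.797885), tanh(u/2) = 2/(1 + e^(−u)) − 1, so y = x·(0.5 + 0.5·tanh(u/2)). A host-only check of the two forms against each other, standalone and not CK code:

#include <cmath>
#include <cstdio>

// tanh form quoted in the header comment.
float gelu_tanh(float x)
{
    return 0.5f * x * (1.0f + std::tanh(0.797885f * x + 0.044715f * 0.797885f * x * x * x));
}

// Sigmoid rewrite used by FastGelu: tanh(u/2) == 2/(1 + exp(-u)) - 1.
float gelu_sigmoid(float x)
{
    const float u   = 2.0f * x * (0.035677f * x * x + 0.797885f);
    const float cdf = 0.5f + 0.5f * (2.0f / (1.0f + std::exp(-u)) - 1.0f);
    return x * cdf;
}

int main()
{
    for(float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f})
        std::printf("x=% .2f  tanh=% .6f  sigmoid=% .6f\n", x, gelu_tanh(x), gelu_sigmoid(x));
    return 0;
}

The two functions agree to floating-point rounding; the sigmoid form avoids a tanh call on the device.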
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp
View file @
6805df0e
...
...
@@ -171,15 +171,15 @@ struct GridwiseReduction_mk_to_m_multiblock
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
    {
-        const auto identityVal = ReduceOperation::GetIdentityValue();
+        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

        // LDS
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

-        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_value_global,
-            in_grid_desc_m_k.GetElementSpaceSize(),
-            type_convert<InDataType>(identityVal));
+        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_value_global,
+            in_grid_desc_m_k.GetElementSpaceSize(),
+            ReduceOperation::template GetIdentityValue<InDataType>());

        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
...
...
@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
        __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
        __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];

-        const auto identityVal = ReduceOperation::GetIdentityValue();
+        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

-        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_value_global,
-            in_grid_desc_m_k.GetElementSpaceSize(),
-            type_convert<InDataType>(identityVal));
+        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_value_global,
+            in_grid_desc_m_k.GetElementSpaceSize(),
+            ReduceOperation::template GetIdentityValue<InDataType>());
        const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());

        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...
...
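The change in both reduction kernels is the same: the identity element is now requested directly in whichever type it is needed in (AccDataType for the accumulator, InDataType for the out-of-bounds fill of the input buffer) via a templated GetIdentityValue<T>(), instead of being computed once and pushed through type_convert. A minimal sketch of what such a reduce-operation struct can look like; this is illustrative only, the real ReduceOperation types live in CK's reduction-operator headers:

#include <cstdio>

// Illustrative reduce op with a type-parameterized identity, mirroring the
// ReduceOperation::template GetIdentityValue<T>() calls in the diff above.
struct ReduceAdd
{
    template <typename T>
    static constexpr T GetIdentityValue()
    {
        return T{0}; // additive identity, produced in whatever type the caller needs
    }

    template <typename T>
    static void Reduce(T& acc, T v)
    {
        acc += v;
    }
};

int main()
{
    // The same op hands out the identity as double (accumulator type) or as
    // float (input/fill type) without an extra type_convert step.
    double acc = ReduceAdd::GetIdentityValue<double>();
    float fill = ReduceAdd::GetIdentityValue<float>();

    for(float v : {1.0f, 2.0f, 3.0f})
        ReduceAdd::Reduce(acc, static_cast<double>(v));

    std::printf("acc=%f fill=%f\n", acc, fill);
    return 0;
}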
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp
View file @
6805df0e
...
...
@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise
                                                                   ReduceOperation,
                                                                   PropagateNan>;

-        const auto identityVal = ReduceOperation::GetIdentityValue();
+        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

-        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_value_global,
-            in_grid_desc_m_k.GetElementSpaceSize(),
-            type_convert<InDataType>(identityVal));
+        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_value_global,
+            in_grid_desc_m_k.GetElementSpaceSize(),
+            ReduceOperation::template GetIdentityValue<InDataType>());

        auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
...
...
@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise
        (void)acc_elementwise_op;

-        const auto identityVal = ReduceOperation::GetIdentityValue();
+        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

-        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_value_global,
-            in_grid_desc_m_k.GetElementSpaceSize(),
-            type_convert<InDataType>(identityVal));
+        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_value_global,
+            in_grid_desc_m_k.GetElementSpaceSize(),
+            ReduceOperation::template GetIdentityValue<InDataType>());
        const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
0 → 100644
View file @
6805df0e
This diff is collapsed.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
View file @
6805df0e
...
...
@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename DxsInElementwiseOperation,
-          typename DxsAccElementwiseOperation,
+          typename DxsReduceAccElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...
...
@@ -41,7 +41,7 @@ __global__ void
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const DxsInElementwiseOperation dxs_in_element_op,
-            const DxsAccElementwiseOperation dxs_out_element_op,
+            const DxsReduceAccElementwiseOperation dxs_out_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
...
@@ -96,7 +96,7 @@ template <typename FloatAB,
          typename CElementwiseOperation,
          typename DxsReduceOperation,
          typename DxsInElementwiseOperation,
-          typename DxsAccElementwiseOperation,
+          typename DxsReduceAccElementwiseOperation,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename DGlobalMemoryDataOperation,
          typename AGridDesc_AK0_M_AK1,
...
...
@@ -329,7 +329,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                  const BElementwiseOperation& b_element_op,
                  const CElementwiseOperation& c_element_op,
                  const DxsInElementwiseOperation& dxs_in_element_op,
-                  const DxsAccElementwiseOperation& dxs_out_element_op,
+                  const DxsReduceAccElementwiseOperation& dxs_out_element_op,
                  const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                  const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
...
...
@@ -816,7 +816,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                                                              false>;

        // Global write Gemm shuffle + reduction
-        const auto d_identityVal = DReduceOperation::GetIdentityValue();
+        const auto d_identityVal =
+            DReduceOperation::template GetIdentityValue<FloatReduceAcc>();

        static_for<0, mreduce_per_thread, 1>{}([&](auto I) { d_thread_buf(I) = d_identityVal; });
...
...
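The initialization static_for<0, mreduce_per_thread, 1>{}([&](auto I) { d_thread_buf(I) = d_identityVal; }) in the last hunk uses CK's compile-time loop: the lambda is invoked once per index, with the index carried as a compile-time constant so the thread-buffer accesses can be resolved statically. A standalone sketch of the idea using std::integral_constant; this is an illustration, not CK's own static_for, which lives in its utility headers:

#include <array>
#include <cstdio>
#include <utility>

// A minimal static_for: calls f(std::integral_constant<int, I>{}) for
// I = 0..N-1, fully unrolled at compile time.
template <int N, typename F, int... Is>
void static_for_impl(F&& f, std::integer_sequence<int, Is...>)
{
    (f(std::integral_constant<int, Is>{}), ...);
}

template <int N, typename F>
void static_for(F&& f)
{
    static_for_impl<N>(std::forward<F>(f), std::make_integer_sequence<int, N>{});
}

int main()
{
    constexpr int mreduce_per_thread = 4;
    std::array<float, mreduce_per_thread> d_thread_buf{};
    const float d_identityVal = 0.0f; // e.g. the additive identity

    // Mirrors the per-thread buffer fill in the hunk above.
    static_for<mreduce_per_thread>([&](auto I) { d_thread_buf[I.value] = d_identityVal; });

    for(float v : d_thread_buf)
        std::printf("%f ", v);
    std::printf("\n");
    return 0;
}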
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
View file @
6805df0e
...
...
@@ -791,8 +791,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
        constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
            GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

+        void* p_shared = static_cast<void*>(p_shared_block);
+
        auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatC*>(p_shared_block),
+            static_cast<FloatC*>(p_shared),
            c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        static_assert(M1 == MWave, "");
...
...
include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp
View file @
6805df0e
...
...
@@ -37,7 +37,7 @@ __global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffe
{
-    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<DataType, DataType>;
+    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    constexpr auto I0 = Number<0>{};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp
0 → 100644
View file @
6805df0e
#pragma once

#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwiseUEltwise,
          typename ADataType,
          typename BDataType,
          typename GridDesc_M0,
          typename ElementwiseFunctor>
__global__ void kernel_unary_elementwise_1d(const ADataType* __restrict__ p_a_global,
                                            BDataType* __restrict__ p_b_global,
                                            const GridDesc_M0 a_grid_desc_m0,
                                            const GridDesc_M0 b_grid_desc_m0,
                                            const ElementwiseFunctor functor)
{
    GridwiseUEltwise::Run(p_a_global, p_b_global, a_grid_desc_m0, b_grid_desc_m0, functor);
}

template <typename ADataType,
          typename BDataType,
          typename GridDesc_M0,
          typename ElementwiseFunctor,
          index_t ScalarPerVector>
struct GridwiseUnaryElementwise_1D
{
    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_desc_m0 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;

    static __device__ auto CalculateElementwiseIndex()
    {
        const index_t global_thread_id = get_thread_global_1d_id();
        return make_multi_index(global_thread_id * ScalarPerVector);
    }

    __host__ __device__ static constexpr bool CheckValidity(const GridDesc_M0 a_grid_desc_m0,
                                                            const GridDesc_M0 b_grid_desc_m0)
    {
        return a_grid_desc_m0.GetLength(I0) == b_grid_desc_m0.GetLength(I0);
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(const index_t tensor_size)
    {
        const index_t grid_size = math::integer_divide_ceil(tensor_size, 256 * ScalarPerVector);
        return grid_size;
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               BDataType* __restrict__ p_b_global,
                               const GridDesc_M0 a_grid_desc_m0,
                               const GridDesc_M0 b_grid_desc_m0,
                               const ElementwiseFunctor functor)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m0.GetElementSpaceSize());
        auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m0.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, ScalarPerVector, true> a_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, ScalarPerVector, true> b_thread_buf;

        const auto thread_store_global_offset = CalculateElementwiseIndex();

        auto a_global_load =
            ThreadwiseTensorSliceTransfer_v2<ADataType,
                                             ADataType,
                                             GridDesc_M0,
                                             decltype(thread_desc_m0),
                                             Sequence<ScalarPerVector>, // SliceLengths
                                             Sequence<0>,               // DimAccessOrder
                                             0,                         // SrcVectorDim
                                             ScalarPerVector,
                                             1, // SrcScalarStrideInVector
                                             false>{a_grid_desc_m0, thread_store_global_offset};

        auto b_global_write =
            ThreadwiseTensorSliceTransfer_v1r3<BDataType,
                                               BDataType,
                                               decltype(thread_desc_m0),
                                               GridDesc_M0,
                                               PassThrough,
                                               Sequence<ScalarPerVector>, // SliceLengths
                                               Sequence<0>,               // DimAccessOrder
                                               0,                         // DstVectorDim
                                               ScalarPerVector,
                                               InMemoryDataOperationEnum::Set,
                                               1, // DstScalarStrideInVector
                                               false>{
                b_grid_desc_m0, thread_store_global_offset, PassThrough{}};

        const index_t blockSize    = get_block_size();
        const index_t blockPerGrid = get_grid_size();
        const auto m0              = b_grid_desc_m0.GetLength(I0);
        const index_t loop_step    = blockPerGrid * blockSize * ScalarPerVector;
        const auto loop_step_index = make_multi_index(loop_step);

        index_t num_iter = m0 / (loop_step);
        do
        {
            // read and process ScalarPerVector elements
            a_global_load.Run(
                a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);

            static_for<0, ScalarPerVector, 1>{}([&](auto m) {
                constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
                functor(b_thread_buf(Number<offset>{}), a_thread_buf(Number<offset>{}));
            });

            b_global_write.Run(thread_desc_m0,
                               make_tuple(I0), // SrcSliceOriginIdx
                               b_thread_buf,
                               b_grid_desc_m0,
                               b_global_buf);

            a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
            b_global_write.MoveDstSliceWindow(b_grid_desc_m0, loop_step_index);
        } while(--num_iter);
    }
};

} // namespace ck
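The indexing in GridwiseUnaryElementwise_1D is a plain grid-stride loop over vectors: each thread starts at global_thread_id * ScalarPerVector, processes ScalarPerVector contiguous elements, then advances by blockPerGrid * blockSize * ScalarPerVector until m0 elements are covered (note that, as written, num_iter = m0 / loop_step assumes the tensor size is a multiple of the loop step). A host-side sketch of the same index arithmetic, with hypothetical sizes chosen for illustration:

#include <cstdio>

int main()
{
    // Hypothetical launch parameters mirroring CalculateGridSize(): 256 threads
    // per block, ScalarPerVector elements per thread per iteration.
    const int ScalarPerVector = 4;
    const int blockSize       = 256;
    const int tensor_size     = 256 * 1024; // assumed multiple of the loop step
    const int blockPerGrid    = (tensor_size + blockSize * ScalarPerVector - 1) /
                                (blockSize * ScalarPerVector);

    const int loop_step = blockPerGrid * blockSize * ScalarPerVector;
    const int num_iter  = tensor_size / loop_step;

    // First elements touched by (block 0, thread 1) in iteration 0:
    const int global_thread_id = 1;
    const int start            = global_thread_id * ScalarPerVector;
    std::printf("grid=%d loop_step=%d num_iter=%d thread1 covers [%d, %d)\n",
                blockPerGrid, loop_step, num_iter, start, start + ScalarPerVector);
    return 0;
}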