gaoqiong / composable_kernel

Commit 2b27d5fc, authored Jul 01, 2022 by Chao Liu

    Merge remote-tracking branch 'origin/develop' into rosenrodt/gemm-layernorm

Parents: f689a155, fa9a0a5c
Changes: 137 files in the merge; showing 20 changed files with 1067 additions and 210 deletions (+1067 -210)
Changed files:

  include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp            +84   -84
  include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp                                +82   -61
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp              +17   -2
  include/ck/utility/math.hpp                                                              +2    -0
  include/ck/utility/reduction_functions_accumulate.hpp                                    +1    -1
  library/CMakeLists.txt                                                                   +1    -1
  library/include/ck/library/host_tensor/host_tensor.hpp                                   +1    -5
  library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp     +6    -6
  library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp             +7    -6
  library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp          +2    -5
  library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp       +1    -0
  library/include/ck/library/tensor_operation_instance/gpu/device_batched_gemm_instance.hpp          +203  -0
  library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp           +49   -0
  library/include/ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp +93   -0
  library/include/ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp                  +286  -0
  library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp  +84   -0
  library/include/ck/library/tensor_operation_instance/gpu/device_gemm_splitk_instance.hpp           +124  -0
  library/include/ck/library/utility/check_err.hpp                                         +2    -2
  library/src/host_tensor/host_tensor.cpp                                                  +0    -22
  library/src/tensor_operation_instance/gpu/CMakeLists.txt                                 +22   -15
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp (+84 -84)

@@ -21,16 +21,16 @@ namespace ck {
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
-          typename DPtrsGlobal,
+          typename ReducePtrsGlobal,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation,
+          typename ReduceInElementwiseOperations,
+          typename ReduceAccElementwiseOperations,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-          typename DGridDescriptor_MBlock_MPerBlock,
+          typename ReduceGridDescriptor_MBlock_MPerBlock,
           typename Block2CTileMap,
           bool HasMainKBlockLoop>
 __global__ void
@@ -41,17 +41,17 @@ __global__ void
         const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
-        DPtrsGlobal p_ds_grid,
+        ReducePtrsGlobal p_reduces_grid,
         const AElementwiseOperation a_element_op,
         const BElementwiseOperation b_element_op,
         const CElementwiseOperation c_element_op,
-        const DxsInElementwiseOperation dxs_in_element_op,
-        const DxsReduceAccElementwiseOperation dxs_out_element_op,
+        const ReduceInElementwiseOperations reduce_in_element_ops,
+        const ReduceAccElementwiseOperations reduce_out_element_ops,
         const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
         const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
         const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
+        const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
         const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
@@ -60,32 +60,32 @@ __global__ void
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
-                                                  p_ds_grid,
+                                                  p_reduces_grid,
                                                   p_shared,
                                                   a_element_op,
                                                   b_element_op,
                                                   c_element_op,
-                                                  dxs_in_element_op,
-                                                  dxs_out_element_op,
+                                                  reduce_in_element_ops,
+                                                  reduce_out_element_ops,
                                                   a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                  d_grid_desc_mblock_mperblock,
+                                                  reduce_grid_desc_mblock_mperblock,
                                                   block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
     ignore = p_c_grid;
-    ignore = p_ds_grid;
+    ignore = p_reduces_grid;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
-    ignore = dxs_in_element_op;
-    ignore = dxs_out_element_op;
+    ignore = reduce_in_element_ops;
+    ignore = reduce_out_element_ops;
     ignore = a_grid_desc_ak0_m_ak1;
     ignore = b_grid_desc_bk0_n_bk1;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
-    ignore = d_grid_desc_mblock_mperblock;
+    ignore = reduce_grid_desc_mblock_mperblock;
     ignore = block_2_ctile_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
@@ -95,19 +95,19 @@ template <typename FloatAB,
           typename FloatCShuffle,
           typename FloatC,
           typename FloatReduceAcc,
-          typename DPtrsGlobal,
+          typename ReducePtrsGlobal,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          typename DxsReduceOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation,
+          typename ReduceOperations,
+          typename ReduceInElementwiseOperations,
+          typename ReduceAccElementwiseOperations,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          typename DGlobalMemoryDataOperation,
+          typename ReduceGlobalMemoryDataOperation,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDesc_M_N,
-          typename DGridDesc_M,
+          typename ReduceGridDesc_M,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
           index_t MPerBlock,
@@ -293,18 +293,18 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     }

     __host__ __device__ static constexpr auto
-    MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m)
+    MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
     {
         const auto M      = d_grid_desc_m.GetLength(I0);
         const auto MBlock = M / MPerBlock;

-        const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor(
+        const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
             d_grid_desc_m,
             make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
             make_tuple(Sequence<0>{}),
             make_tuple(Sequence<0, 1>{}));

-        return d_grid_desc_mblock_mperblock;
+        return reduce_grid_desc_mblock_mperblock;
     }

     // return block_id to C matrix tile idx (m0, n0) mapping
@@ -318,29 +318,30 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
-    using DGridDescriptor_MBlock_MPerBlock =
-        remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>;
+    using ReduceGridDescriptor_MBlock_MPerBlock =
+        remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;

     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

     template <bool HasMainKBlockLoop, typename Block2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
-                               DPtrsGlobal p_ds_grid,
+                               ReducePtrsGlobal p_reduces_grid,
                                void* __restrict__ p_shared,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
                                const CElementwiseOperation& c_element_op,
-                               const DxsInElementwiseOperation& dxs_in_element_op,
-                               const DxsReduceAccElementwiseOperation& dxs_out_element_op,
+                               const ReduceInElementwiseOperations& reduce_in_element_ops,
+                               const ReduceAccElementwiseOperations& reduce_out_element_ops,
                                const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                                const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                                const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
-                               const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
+                               const ReduceGridDescriptor_MBlock_MPerBlock&
+                                   reduce_grid_desc_mblock_mperblock,
                                const Block2CTileMap& block_2_ctile_map)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -706,12 +707,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             make_naive_tensor_descriptor_packed(
                 make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));

-        // VGPR d_reduce_thread_desc_mperblock
-        constexpr auto d_reduce_thread_desc_mperblock =
+        // VGPR reduce_thread_desc_mperblock
+        constexpr auto reduce_thread_desc_mperblock =
             make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));

-        // VGPR d_reduce_thread_desc_mblock_mperblock
-        constexpr auto d_reduce_thread_desc_mblock_mperblock =
+        // VGPR reduce_thread_desc_mblock_mperblock
+        constexpr auto reduce_thread_desc_mblock_mperblock =
             make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));

         auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
@@ -740,29 +741,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};

-        auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple(
+        auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
             [&](auto I) {
-                auto p_d_grid         = p_ds_grid[I];
-                auto d_out_element_op = dxs_out_element_op[I];
+                auto p_reduce_grid         = p_reduces_grid[I];
+                auto reduce_acc_element_op = reduce_out_element_ops[I];

                 return ThreadwiseTensorSliceTransfer_v1r3<
                     FloatReduceAcc,
-                    remove_pointer_t<decltype(p_d_grid)>,
-                    decltype(d_reduce_thread_desc_mblock_mperblock),
-                    decltype(d_grid_desc_mblock_mperblock),
-                    decltype(d_out_element_op),
+                    remove_pointer_t<decltype(p_reduce_grid)>,
+                    decltype(reduce_thread_desc_mblock_mperblock),
+                    decltype(reduce_grid_desc_mblock_mperblock),
+                    decltype(reduce_acc_element_op),
                     Sequence<1, mreduce_per_thread>,
                     Sequence<0, 1>,
                     1,
                     CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
-                    DGlobalMemoryDataOperation::At(I),
+                    ReduceGlobalMemoryDataOperation::At(I),
                     1,
-                    false>{d_grid_desc_mblock_mperblock,
+                    false>{reduce_grid_desc_mblock_mperblock,
                            make_multi_index(block_work_idx[I0],                  // mblock
                                             c_reduce_thread_data_idx_begin[I0]), // mperblock
-                           d_out_element_op};
+                           reduce_acc_element_op};
             },
-            Number<p_ds_grid.Size()>{});
+            Number<p_reduces_grid.Size()>{});

         constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
@@ -797,35 +798,35 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                     make_tuple(I0, I0),
                     c_reduce_thread_buf);

-                static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
-                    auto& p_d_grid = p_ds_grid[In];
+                static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
+                    auto& p_reduce_grid = p_reduces_grid[In];

-                    auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                        p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
+                    auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                        p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());

-                    auto d_thread_buf =
+                    auto reduce_thread_buf =
                         make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
-                            d_reduce_thread_desc_mperblock.GetElementSpaceSize());
+                            reduce_thread_desc_mperblock.GetElementSpaceSize());

-                    auto& d_in_element_op = dxs_in_element_op[In];
+                    auto& reduce_in_element_op = reduce_in_element_ops[In];

-                    auto& d_reduce_thread_copy_vgpr_to_global =
-                        dxs_reduce_thread_copy_vgpr_to_global(In);
+                    auto& reduce_thread_copy_vgpr_to_global =
+                        reduce_tuple_thread_copy_vgpr_to_global(In);

-                    using DReduceOperation = remove_cvref_t<decltype(DxsReduceOperation{}[In])>;
+                    using ReduceOperation = remove_cvref_t<decltype(ReduceOperations{}[In])>;

                     using ThreadwiseReduce =
                         ThreadwiseReduction<FloatReduceAcc,
                                             decltype(c_reduce_thread_desc_mperblock_nperblock),
-                                            decltype(d_reduce_thread_desc_mperblock),
-                                            DReduceOperation,
+                                            decltype(reduce_thread_desc_mperblock),
+                                            ReduceOperation,
                                             false>;

                     // Global write Gemm shuffle + reduction
-                    const auto d_identityVal =
-                        DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
+                    const auto reduce_identityVal =
+                        ReduceOperation::template GetIdentityValue<FloatReduceAcc>();

                     static_for<0, mreduce_per_thread, 1>{}(
-                        [&](auto I) { d_thread_buf(I) = d_identityVal; });
+                        [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });

                     // reduce in VGPR
                     static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
@@ -834,26 +835,25 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                             Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
                                 make_tuple(im, in))>{};

-                            d_in_element_op(c_reduce_thread_buf(offset),
-                                            c_reduce_thread_buf(offset));
+                            reduce_in_element_op(c_reduce_thread_buf(offset),
+                                                 c_reduce_thread_buf(offset));
                         });
                     });

-                    ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf);
+                    ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);

                     // copy from VGPR to Global
-                    d_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
-                                                            make_tuple(I0, I0),
-                                                            d_thread_buf,
-                                                            d_grid_desc_mblock_mperblock,
-                                                            d_grid_buf);
+                    reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
+                                                          make_tuple(I0, I0),
+                                                          reduce_thread_buf,
+                                                          reduce_grid_desc_mblock_mperblock,
+                                                          reduce_grid_buf);

                     if constexpr(access_id < num_access - 1)
                     {
                         constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

-                        d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                            d_grid_desc_mblock_mperblock,
+                        reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                            reduce_grid_desc_mblock_mperblock,
                             make_tuple(c_global_step[I0], c_global_step[I1]));
                     }
                 });
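Note on the renamed MakeReduceGridDescriptor_MBlock_MPerBlock: it splits the 1-D reduction output of length M into a 2-D (MBlock, MPerBlock) view via an unmerge transform, so each workgroup writes its own MPerBlock-sized slice. The snippet below is a plain C++ illustration of the equivalent index arithmetic only; the struct and function names are hypothetical, not CK's API.

#include <cassert>
#include <cstdio>

// Stand-in for the unmerge transform: a linear index m over the reduction output
// maps to a (mblock, mperblock) pair and back.
struct MBlockMPerBlockView
{
    int MPerBlock;

    // linear index m -> (mblock, mperblock)
    void to_2d(int m, int& mblock, int& mperblock) const
    {
        mblock    = m / MPerBlock;
        mperblock = m % MPerBlock;
    }

    // (mblock, mperblock) -> linear index m
    int to_1d(int mblock, int mperblock) const { return mblock * MPerBlock + mperblock; }
};

int main()
{
    const MBlockMPerBlockView view{256}; // assume MPerBlock = 256 for illustration

    int mblock = 0, mperblock = 0;
    view.to_2d(1000, mblock, mperblock);
    assert(view.to_1d(mblock, mperblock) == 1000);
    std::printf("m=1000 -> (mblock=%d, mperblock=%d)\n", mblock, mperblock);
    return 0;
}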
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp (+82 -61)

@@ -49,7 +49,8 @@ template <typename InDataType,
           index_t KThreadSliceSize,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
-          index_t OutDstVectorSize>
+          index_t OutDstVectorSize,
+          bool SweepOnce>
 struct GridwiseSoftmax_mk_to_mk
 {
     static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
@@ -75,19 +76,6 @@ struct GridwiseSoftmax_mk_to_mk
     using ThreadReduceDstDesc_M =
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

-    using BlockwiseMaxReduce = PartitionedBlockwiseReduction<AccDataType,
-                                                             BlockSize,
-                                                             ThreadClusterLengths_M_K,
-                                                             ThreadClusterArrangeOrder,
-                                                             reduce::Max,
-                                                             false>; // PropagateNan
-
-    using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType,
-                                                    ThreadReduceSrcDesc_M_K,
-                                                    ThreadReduceDstDesc_M,
-                                                    reduce::Max,
-                                                    false>; // PropagateNan
-
     using PassThroughOp = tensor_operation::element_wise::PassThrough;

     static constexpr auto I0 = Number<0>{};
@@ -105,6 +93,11 @@ struct GridwiseSoftmax_mk_to_mk
                                AccDataType beta,
                                OutDataType* const __restrict__ p_out_value_global)
     {
+        if constexpr(SweepOnce)
+        {
+            num_k_block_tile_iteration = 1;
+        }
+
         // LDS
         __shared__ AccDataType p_reduce_work_buffer[BlockSize];
@@ -149,6 +142,20 @@ struct GridwiseSoftmax_mk_to_mk
         constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
             make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

+        // Normally, 0 as invalid element value is adequate since 0 makes no contribution to
+        // accumulated result. However, in stable softmax, all values 0s or not are subtracted by
+        // another value_max. As numbers become non-zero, effectively it allows invalid values to
+        // slip through and contribute to the accumulated result.
+        //
+        // The trick here is leveraging the fact that many math functions (add, sub, exp, ...)
+        // propagate NaNs when operands have NaNs involved. By initialiing invalid element value
+        // with NaN, an invalid value doing math manipulations is still NaN, which in turn can still
+        // be identified as an invalid value. We can then discard the invalid values which
+        // originally failed the bound check during accumulation. This allows to ignore values that
+        // failed bound check even after multiple math manipulations.
+        //
+        // NOTE: reset coordinate after every step because the same threadwise copy will sweep
+        // through global memory 3 times back and forth
         auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
                                                                     AccDataType,
                                                                     GridDesc_M_K,
@@ -158,7 +165,8 @@ struct GridwiseSoftmax_mk_to_mk
                                                                     InSrcVectorDim,
                                                                     InSrcVectorSize,
                                                                     1,
-                                                                    false>(
+                                                                    true /* ResetCoordAfterRun */,
+                                                                    true /* InvalidElementAsNaN */>(
             in_grid_desc_m_k,
             make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                              block_local_id * reduceSizePerBlock +
@@ -198,21 +206,39 @@ struct GridwiseSoftmax_mk_to_mk
                              block_local_id * reduceSizePerBlock +
                                  thread_k_cluster_id * KThreadSliceSize),
             PassThroughOp{});

-        constexpr auto in_thread_copy_fwd_step = make_multi_index(0, K_BlockTileSize);
-        constexpr auto in_thread_copy_bwd_step = make_multi_index(0, -K_BlockTileSize);
+        constexpr auto in_thread_copy_fwd_step =
+            make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
+        constexpr auto in_thread_copy_bwd_step =
+            make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);

         ///
         /// max(x)
         ///
-        const auto in_global_val_buf_oob_non_zero = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_value_global,
-            in_grid_desc_m_k.GetElementSpaceSize(),
-            reduce::Max::template GetIdentityValue<InDataType>());
+        using BlockwiseMaxReduce = PartitionedBlockwiseReduction<
+            AccDataType,
+            BlockSize,
+            ThreadClusterLengths_M_K,
+            ThreadClusterArrangeOrder,
+            reduce::Max,
+            false, // param ignored
+            detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
+
+        using ThreadwiseMaxReduce = ThreadwiseReduction<
+            AccDataType,
+            ThreadReduceSrcDesc_M_K,
+            ThreadReduceDstDesc_M,
+            reduce::Max,
+            false, // param ignored
+            detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
+
+        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_value_global, in_grid_desc_m_k.GetElementSpaceSize());

         index_t reducedTiles = 0;
         do
         {
             threadwise_src_load.Run(in_grid_desc_m_k,
-                                    in_global_val_buf_oob_non_zero,
+                                    in_global_val_buf,
                                     thread_buffer_desc,
                                     make_tuple(I0, I0),
                                     in_thread_buf);
@@ -232,26 +258,6 @@ struct GridwiseSoftmax_mk_to_mk
         ///
         /// sum(exp(x - max(x)))
         ///
-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
-        });
-
-        // Normally, 0 as invalid element value is adequate since 0 makes no contribution to
-        // accumulated result. However, in stable softmax, all values 0s or not are subtracted by
-        // another value_max. As numbers become non-zero, effectively it allows invalid values to
-        // slip through and contribute to the accumulated result.
-        //
-        // The trick here is leveraging the fact that many math functions (add, sub, exp, ...)
-        // propagate NaNs when operands have NaNs involved. By initialiing invalid element value
-        // with NaN, an invalid value doing math manipulations is still NaN, which in turn can still
-        // be identified as an invalid value. We can then discard the invalid values which
-        // originally failed the bound check during accumulation. This allows to ignore values that
-        // failed bound check even after multiple math manipulations.
-        const auto in_global_val_buf_oob_nan = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_value_global,
-            in_grid_desc_m_k.GetElementSpaceSize(),
-            NumericLimits<InDataType>::QuietNaN());
-
         using BlockwiseSumReduce = PartitionedBlockwiseReduction<AccDataType,
                                                                  BlockSize,
@@ -272,22 +278,25 @@ struct GridwiseSoftmax_mk_to_mk
         reducedTiles = 0;
         do
         {
-            threadwise_src_load.Run(in_grid_desc_m_k,
-                                    in_global_val_buf_oob_nan,
-                                    thread_buffer_desc,
-                                    make_tuple(I0, I0),
-                                    in_thread_buf);
+            if constexpr(!SweepOnce)
+            {
+                threadwise_src_load.Run(in_grid_desc_m_k,
+                                        in_global_val_buf,
+                                        thread_buffer_desc,
+                                        make_tuple(I0, I0),
+                                        in_thread_buf);
+            }

             // do element-wise pre-reduction operation
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                 static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                     constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
-                    in_thread_buf(Number<offset>{}) =
+                    out_thread_buf(Number<offset>{}) =
                         math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
                 });
             });
-            ThreadwiseSumReduce::Reduce(in_thread_buf, accu_value_buf);
+            ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf);

             threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
@@ -309,11 +318,14 @@ struct GridwiseSoftmax_mk_to_mk
         {
             do
             {
-                threadwise_src_load.Run(in_grid_desc_m_k,
-                                        in_global_val_buf_oob_nan,
-                                        thread_buffer_desc,
-                                        make_tuple(I0, I0),
-                                        in_thread_buf);
+                if constexpr(!SweepOnce)
+                {
+                    threadwise_src_load.Run(in_grid_desc_m_k,
+                                            in_global_val_buf,
+                                            thread_buffer_desc,
+                                            make_tuple(I0, I0),
+                                            in_thread_buf);
+                }

                 static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                     // out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
@@ -340,18 +352,27 @@ struct GridwiseSoftmax_mk_to_mk
         }
         else
         {
+            StaticBuffer<AddressSpaceEnum::Vgpr,
+                         AccDataType,
+                         MThreadSliceSize * KThreadSliceSize,
+                         true>
+                in_prior_dst_buf;
             do
             {
-                threadwise_src_load.Run(in_grid_desc_m_k,
-                                        in_global_val_buf_oob_nan,
-                                        thread_buffer_desc,
-                                        make_tuple(I0, I0),
-                                        in_thread_buf);
+                if constexpr(!SweepOnce)
+                {
+                    threadwise_src_load.Run(in_grid_desc_m_k,
+                                            in_global_val_buf,
+                                            thread_buffer_desc,
+                                            make_tuple(I0, I0),
+                                            in_thread_buf);
+                }
+
                 threadwise_dst_load.Run(out_grid_desc_m_k,
                                         out_global_val_buf,
                                         thread_buffer_desc,
                                         make_tuple(I0, I0),
-                                        out_thread_buf);
+                                        in_prior_dst_buf);

                 static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                     // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
                     static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
@@ -360,7 +381,7 @@ struct GridwiseSoftmax_mk_to_mk
                         out_thread_buf(Number<offset>{}) =
                             alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
                                 accu_value_buf(iM) +
-                            beta * out_thread_buf(Number<offset>{});
+                            beta * in_prior_dst_buf(Number<offset>{});
                     });
                 });
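For reference, the kernel above computes stable softmax in three sweeps (max, sum of exponentials, normalization blended with a beta-scaled prior output), per the formulas in its own comments; with the new SweepOnce parameter the input tile is loaded once and reused across the sweeps. The snippet below is a minimal scalar host-side sketch of the same math, not the GPU kernel itself.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Scalar sketch of the per-row math:
//   out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
void softmax_row(const std::vector<float>& x,
                 std::vector<float>& out, // holds prior_out on entry
                 float alpha,
                 float beta)
{
    // sweep 1: max(x) for numerical stability
    const float x_max = *std::max_element(x.begin(), x.end());

    // sweep 2: sum(exp(x - max(x)))
    float sum = 0.f;
    for(float v : x)
        sum += std::exp(v - x_max);

    // sweep 3: normalize and blend with the prior output
    for(std::size_t i = 0; i < x.size(); ++i)
        out[i] = alpha * std::exp(x[i] - x_max) / sum + beta * out[i];
}

int main()
{
    std::vector<float> x{1.f, 2.f, 3.f};
    std::vector<float> out{0.f, 0.f, 0.f};
    softmax_row(x, out, /*alpha=*/1.f, /*beta=*/0.f);
    std::printf("%f %f %f\n", out[0], out[1], out[2]);
    return 0;
}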
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp (+17 -2)

@@ -236,9 +236,14 @@ template <typename SrcData,
           index_t SrcScalarPerVector,
           index_t SrcScalarStrideInVector,
           bool SrcResetCoordinateAfterRun,
+          bool InvalidElementAsNaN = false,
           typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
 struct ThreadwiseTensorSliceTransfer_v2
 {
+    static_assert((InvalidElementAsNaN && !std::is_integral<DstData>::value) ||
+                      (!InvalidElementAsNaN),
+                  "Filling invalid element as NaN is only for floating point types");
+
     static constexpr index_t nDim = SliceLengths::Size();

     using Index = MultiIndex<nDim>;
@@ -318,8 +323,18 @@ struct ThreadwiseTensorSliceTransfer_v2
                     dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
                                              i * src_scalar_step_in_vector);

-                dst_buf(Number<dst_offset>{}) =
-                    type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
+                if constexpr(InvalidElementAsNaN)
+                {
+                    dst_buf(Number<dst_offset>{}) =
+                        is_src_valid
+                            ? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
+                            : NumericLimits<DstData>::QuietNaN();
+                }
+                else
+                {
+                    dst_buf(Number<dst_offset>{}) =
+                        type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
+                }
             });

             if constexpr(idx_1d.value != num_access - 1)
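The new InvalidElementAsNaN flag works together with AccumulateWithNanIgnore (changed further below) to implement the trick described in the softmax comments: out-of-bounds reads become quiet NaN rather than 0, and NaN contributions are skipped during reduction, so they cannot pollute the result even after intermediate math. The snippet below is a plain C++ illustration of that idea with hypothetical names, not CK's API.

#include <cmath>
#include <cstdio>
#include <limits>

// Out-of-bounds loads return quiet NaN instead of 0.
float load_or_nan(const float* data, int i, int n)
{
    return (i >= 0 && i < n) ? data[i] : std::numeric_limits<float>::quiet_NaN();
}

// The accumulator skips NaN values, mirroring AccumulateWithNanIgnore, so invalid
// elements can survive any amount of math (e.g. x - max) without affecting the max.
float max_ignoring_nan(const float* data, int padded_len, int valid_len)
{
    float acc = -std::numeric_limits<float>::infinity(); // identity value for max
    for(int i = 0; i < padded_len; ++i)
    {
        const float v = load_or_nan(data, i, valid_len) - 1.0f; // some math on the value
        if(!std::isnan(v))                                      // ignore invalid elements
            acc = std::fmax(acc, v);
    }
    return acc;
}

int main()
{
    const float data[3] = {0.5f, 2.0f, -1.0f};
    // pretend the tile is padded to 8 elements; only 3 are valid
    std::printf("max = %f\n", max_ignoring_nan(data, 8, 3)); // prints 1.0
    return 0;
}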
include/ck/utility/math.hpp (+2 -0)

@@ -148,6 +148,8 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
 template <typename T>
 __device__ T exp(T x);

+// TODO: add f16 support using v_exp_f16
+
 template <>
 __device__ float exp<float>(float x)
 {
include/ck/utility/reduction_functions_accumulate.hpp (+1 -1)

@@ -17,7 +17,7 @@ struct AccumulateWithNanIgnore
 {
     __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
     {
-        if(!isnan(currVal))
+        if(!ck::math::isnan(currVal))
         {
             ReduceOperation{}(accuVal, currVal);
         }
library/CMakeLists.txt (+1 -1)

-add_subdirectory(src/host_tensor)
 add_subdirectory(src/tensor_operation_instance/gpu)
+add_subdirectory(src/host_tensor)
 add_subdirectory(src/utility)
library/include/ck/library/host_tensor/host_tensor.hpp (+1 -5)

@@ -382,13 +382,8 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
 {
 }

-void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout);
-
 #if 1
 // FIXME: remove
-void bf16_to_f32_(const Tensor<ck::bhalf_t>& src, Tensor<float>& dst);
-#endif
-
 template <typename T>
 float check_error(const Tensor<T>& ref, const Tensor<T>& result)
 {
@@ -434,3 +429,4 @@ float check_error(const Tensor<T>& ref, const Tensor<T>& result)
     return linf_error;
 }
+#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp (+6 -6)

@@ -62,20 +62,20 @@ struct ReferenceBatchedGemm : public device::BaseOperator
                 for(int k = 0; k < K; ++k)
                 {
-                    float v_a;
-                    float v_b;
+                    ADataType v_a;
+                    BDataType v_b;

-                    arg.a_element_op_(v_a, static_cast<const float>(arg.a_g_m_k_(g, m, k)));
-                    arg.b_element_op_(v_b, static_cast<const float>(arg.b_g_k_n_(g, k, n)));
+                    arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
+                    arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));

-                    v_acc += v_a * v_b;
+                    v_acc += ck::type_convert<float>(v_a) * ck::type_convert<float>(v_b);
                 }

                 float v_c;

                 arg.c_element_op_(v_c, v_acc);

-                arg.c_g_m_n_(g, m, n) = v_c;
+                arg.c_g_m_n_(g, m, n) = ck::type_convert<CDataType>(v_c);
             };

             make_ParallelTensorFunctor(f_gmk_gkn_gmn,
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp (+7 -6)

@@ -63,20 +63,21 @@ struct ReferenceGemm : public device::BaseOperator
                 for(int k = 0; k < K; ++k)
                 {
-                    AccDataType v_a;
-                    AccDataType v_b;
+                    ADataType v_a;
+                    BDataType v_b;

-                    arg.a_element_op_(v_a, static_cast<const AccDataType>(arg.a_m_k_(m, k)));
-                    arg.b_element_op_(v_b, static_cast<const AccDataType>(arg.b_k_n_(k, n)));
+                    arg.a_element_op_(v_a, arg.a_m_k_(m, k));
+                    arg.b_element_op_(v_b, arg.b_k_n_(k, n));

-                    v_acc += v_a * v_b;
+                    v_acc +=
+                        ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
                 }

                 AccDataType v_c;

                 arg.c_element_op_(v_c, v_acc);

-                arg.c_m_n_(m, n) = v_c;
+                arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
             };

             make_ParallelTensorFunctor(
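Both reference GEMMs now apply the element-wise operations on values in their native A/B types, convert to the accumulation type only for the multiply-accumulate, and convert the final result to the C type. The snippet below is a standalone sketch of that accumulation pattern; the concrete types (float inputs, double accumulator) are chosen only for illustration.

#include <cstdio>
#include <vector>

// Sketch of the reference-GEMM accumulation pattern after this change:
// element-wise ops would see A/B in their native types, the product is accumulated
// in a wider AccDataType, and the result is converted to CDataType once at the end.
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
void reference_gemm(const std::vector<ADataType>& a, // M x K, row-major
                    const std::vector<BDataType>& b, // K x N, row-major
                    std::vector<CDataType>& c,       // M x N, row-major
                    int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            AccDataType v_acc = 0;
            for(int k = 0; k < K; ++k)
            {
                const ADataType v_a = a[m * K + k]; // an element-wise op would act here
                const BDataType v_b = b[k * N + n];
                v_acc += static_cast<AccDataType>(v_a) * static_cast<AccDataType>(v_b);
            }
            c[m * N + n] = static_cast<CDataType>(v_acc); // convert once, at the end
        }
}

int main()
{
    const int M = 2, N = 2, K = 2;
    std::vector<float> a{1, 2, 3, 4}, b{5, 6, 7, 8};
    std::vector<float> c(M * N);
    reference_gemm<float, float, double, float>(a, b, c, M, N, K);
    std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]); // 19 22 43 50
    return 0;
}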
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp (+2 -5)

@@ -26,12 +26,11 @@ struct ReferenceSoftmax : public device::BaseOperator
                  Tensor<OutDataType>& out,
                  AccDataType alpha,
                  AccDataType beta,
-                 const index_t rank,
                  const std::vector<index_t> sm_reduce_dims)
             : in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims)
         {
             // std::cout << "debug: scalar dims: ";
-            for(int i = 0; i < rank; i++)
+            for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++)
             {
                 if(std::find(sm_reduce_dims.begin(), sm_reduce_dims.end(), i) ==
                    sm_reduce_dims.end())
@@ -47,7 +46,6 @@ struct ReferenceSoftmax : public device::BaseOperator
         Tensor<OutDataType>& out_;
         AccDataType alpha_;
         AccDataType beta_;
-        index_t rank_;
         std::vector<index_t> sm_reduce_dims_;
         std::vector<index_t> sm_scalar_dims_; // dim after internal max/sum reduction
     };
@@ -136,10 +134,9 @@ struct ReferenceSoftmax : public device::BaseOperator
                          Tensor<OutDataType>& out,
                          AccDataType alpha,
                          AccDataType beta,
-                         const index_t rank,
                          const std::vector<index_t> sm_reduce_dims)
     {
-        return Argument{in, out, alpha, beta, rank, sm_reduce_dims};
+        return Argument{in, out, alpha, beta, sm_reduce_dims};
     }

     static auto MakeInvoker() { return Invoker{}; }
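The reference softmax no longer takes a separate rank argument; it derives the non-reduced ("scalar") dimensions by scanning every dimension of the input tensor's descriptor and keeping those not listed in sm_reduce_dims. The snippet below sketches that scan in plain C++ with hypothetical names, outside the CK class.

#include <algorithm>
#include <cstdio>
#include <vector>

// Derive the non-reduced dimensions from the tensor's own dimension count
// instead of a separately passed rank.
std::vector<int> scalar_dims(std::size_t num_dims, const std::vector<int>& reduce_dims)
{
    std::vector<int> result;
    for(std::size_t i = 0; i < num_dims; ++i)
        if(std::find(reduce_dims.begin(), reduce_dims.end(), static_cast<int>(i)) ==
           reduce_dims.end())
            result.push_back(static_cast<int>(i));
    return result;
}

int main()
{
    // a rank-3 tensor reduced over its last dimension keeps dims {0, 1}
    for(int d : scalar_dims(3, {2}))
        std::printf("%d ", d); // prints: 0 1
    std::printf("\n");
    return 0;
}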
library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp (+1 -0)

@@ -4,6 +4,7 @@
 #pragma once

 #include <vector>
+#include "ck/utility/functional2.hpp"

 namespace ck {
 namespace tensor_operation {
library/include/ck/library/tensor_operation_instance/gpu/device_batched_gemm_instance.hpp (new file, +203)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {

using DeviceBatchedGemmNoOpPtr = ck::tensor_operation::device::DeviceBatchedGemmPtr<
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>;

void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(std::vector<DeviceBatchedGemmNoOpPtr>&);

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
auto get_device_batched_gemm_instances()
{
    std::vector<DeviceBatchedGemmNoOpPtr> op_ptrs;

    if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
                 is_same<CDataType, float>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(op_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                      is_same<CDataType, half_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, bhalf_t>::value && is_same<BDataType, bhalf_t>::value &&
                      is_same<CDataType, bhalf_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(op_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, int8_t>::value && is_same<BDataType, int8_t>::value &&
                      is_same<CDataType, int8_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(op_ptrs);
        }
    }

    return op_ptrs;
}

} // namespace device_batched_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
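A minimal usage sketch of the factory above (assuming the CK headers are on the include path): the caller fixes data types and layouts at compile time and receives every pre-built batched-GEMM instance matching that combination; what the caller then does with the instances (profiling, dispatch) is outside this diff.

#include <cstdio>

#include "ck/library/tensor_operation_instance/gpu/device_batched_gemm_instance.hpp"

int main()
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    // f16 batched GEMM with A in (g, m, k), B in (g, n, k), C in (g, m, n) layout
    auto op_ptrs = ck::tensor_operation::device::device_batched_gemm_instance::
        get_device_batched_gemm_instances<ck::half_t, ck::half_t, ck::half_t, Row, Col, Row>();

    std::printf("found %zu batched GEMM instances\n", op_ptrs.size());
    return 0;
}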
library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp (new file, +49)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

using Normalize = ck::tensor_operation::element_wise::Normalize;

using DeviceNormalizeFromMeanMeanSquarePtr =
    ck::tensor_operation::device::DeviceElementwisePtr<5, 1, 2, Normalize>;

void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(
    std::vector<DeviceNormalizeFromMeanMeanSquarePtr>& instances);

template <typename InputType,
          typename MeanType,
          typename MeanSquareType,
          typename GammaDataType,
          typename BetaDataType,
          typename OutputType>
auto get_device_normalize_from_mean_meansquare_instances()
{
    std::vector<DeviceNormalizeFromMeanMeanSquarePtr> op_ptrs;

    if constexpr(is_same<InputType, half_t>::value && is_same<MeanType, float>::value &&
                 is_same<MeanSquareType, float>::value && is_same<GammaDataType, half_t>::value &&
                 is_same<BetaDataType, half_t>::value && is_same<OutputType, half_t>::value)
    {
        ck::tensor_operation::device::
            add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(op_ptrs);
    }

    return op_ptrs;
}

} // namespace device
} // namespace tensor_operation
} // namespace ck
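For context on "normalize from mean and mean-square": this header only declares the instances; the exact element-wise formula lives in the Normalize functor and is not part of this diff. The snippet below is a hedged scalar sketch of what such a normalization typically computes from precomputed per-row moments; the function name and epsilon value are illustrative assumptions.

#include <cmath>
#include <cstdio>

// Assumed formula: with E[x] and E[x^2] precomputed per row,
//   var(x) = E[x^2] - E[x]^2
//   y      = gamma * (x - E[x]) / sqrt(var(x) + eps) + beta
float normalize_from_moments(float x, float mean, float mean_square,
                             float gamma, float beta, float eps = 1e-5f)
{
    const float variance = mean_square - mean * mean;
    return gamma * (x - mean) / std::sqrt(variance + eps) + beta;
}

int main()
{
    // x = 2 with row statistics E[x] = 1, E[x^2] = 2  ->  var = 1
    std::printf("%f\n", normalize_from_moments(2.f, 1.f, 2.f, /*gamma=*/1.f, /*beta=*/0.f));
    return 0;
}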
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp (new file, +93)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using DeviceGemmAddAddFastGeluPtr = ck::tensor_operation::device::DeviceGemmMultipleDPtr<
    2,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::AddAddFastGelu>;

void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
    std::vector<DeviceGemmAddAddFastGeluPtr>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
    std::vector<DeviceGemmAddAddFastGeluPtr>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
    std::vector<DeviceGemmAddAddFastGeluPtr>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
    std::vector<DeviceGemmAddAddFastGeluPtr>&);

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename D0DataType,
          typename D1DataType,
          typename EDataType,
          typename ALayout,
          typename BLayout,
          typename D0Layout,
          typename D1Layout,
          typename ELayout>
auto get_device_gemm_add_add_fastgelu_instances()
{
    std::vector<DeviceGemmAddAddFastGeluPtr> op_ptrs;

    if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                 is_same_v<EDataType, half_t>)
    {
        if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> &&
                     is_same_v<BLayout, tensor_layout::gemm::RowMajor> &&
                     is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
                    op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> &&
                          is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
                          is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
                    op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
                          is_same_v<BLayout, tensor_layout::gemm::RowMajor> &&
                          is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
                    op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
                          is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
                          is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
                    op_ptrs);
        }
    }

    return op_ptrs;
}

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
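All of these new factory headers share the same compile-time dispatch idiom: tag types select, via if constexpr and is_same/is_same_v, which registration function populates the instance vector, so branches for unselected type and layout combinations are never instantiated. The snippet below is a generic, self-contained sketch of that idiom; the tag types and fill functions are illustrative, not CK's.

#include <cstdio>
#include <type_traits>
#include <vector>

struct RowMajor {};
struct ColumnMajor {};

// Stand-ins for the add_device_*_instances registration functions.
void add_row_row_instances(std::vector<const char*>& v) { v.push_back("mk_kn_mn"); }
void add_row_col_instances(std::vector<const char*>& v) { v.push_back("mk_nk_mn"); }

template <typename ALayout, typename BLayout>
auto get_instances()
{
    std::vector<const char*> instances;

    // only the matching branch is compiled into this instantiation
    if constexpr(std::is_same_v<ALayout, RowMajor> && std::is_same_v<BLayout, RowMajor>)
        add_row_row_instances(instances);
    else if constexpr(std::is_same_v<ALayout, RowMajor> && std::is_same_v<BLayout, ColumnMajor>)
        add_row_col_instances(instances);

    return instances;
}

int main()
{
    for(const char* name : get_instances<RowMajor, ColumnMajor>())
        std::printf("%s\n", name); // prints: mk_nk_mn
    return 0;
}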
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using DeviceGemmNoOpPtr =
    ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough>;

void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
auto get_device_gemm_instances()
{
    std::vector<DeviceGemmNoOpPtr> op_ptrs;

    if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
                 is_same<CDataType, float>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                      is_same<CDataType, half_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
                      is_same<BDataType, ck::bhalf_t>::value &&
                      is_same<CDataType, ck::bhalf_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, int8_t>::value && is_same<BDataType, int8_t>::value &&
                      is_same<CDataType, int8_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs);
        }
    }

    return op_ptrs;
}

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
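A hedged sketch of how this factory is typically consumed, mirroring the pattern used elsewhere in CK's client code: select the instance list by data type and layout at compile time, then ask each instance whether it supports the concrete problem size. The MakeArgumentPointer / IsSupportedArgument / GetTypeString signatures below are assumptions based on the DeviceGemm interface in device_gemm.hpp; they are not declared in this header, and the null data pointers are only acceptable because nothing is actually launched.

// Hypothetical usage sketch: count which f32 (row-major A, column-major B,
// row-major C) GEMM instances can handle a given M, N, K. Pointer arguments are
// left null because only IsSupportedArgument() is exercised here; a real run
// would pass device buffers and launch through MakeInvokerPointer()->Run(...).
#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp"

int main()
{
    using namespace ck::tensor_operation;
    using PassThrough = element_wise::PassThrough;
    namespace layout  = ck::tensor_layout::gemm;

    const ck::index_t M = 3840, N = 4096, K = 4096;

    auto op_ptrs = device::device_gemm_instance::
        get_device_gemm_instances<float, float, float,
                                  layout::RowMajor, layout::ColumnMajor, layout::RowMajor>();

    int supported = 0;
    for(auto& op : op_ptrs)
    {
        // Argument order assumed from DeviceGemm: (p_a, p_b, p_c, M, N, K,
        // StrideA, StrideB, StrideC, a_op, b_op, c_op).
        auto arg = op->MakeArgumentPointer(nullptr, nullptr, nullptr,
                                           M, N, K,
                                           /*StrideA=*/K, /*StrideB=*/K, /*StrideC=*/N,
                                           PassThrough{}, PassThrough{}, PassThrough{});

        if(op->IsSupportedArgument(arg.get()))
        {
            ++supported;
            std::cout << op->GetTypeString() << '\n'; // assumed BaseOperator method
        }
    }
    std::cout << supported << " of " << op_ptrs.size() << " instances support this problem\n";
    return 0;
}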
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using DeviceGemmAddAddMeanSquareMeanPtr = ck::tensor_operation::device::DeviceGemmReducePtr<1, 2>;

void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
    std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
    std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
    std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
    std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
auto get_device_gemm_add_add_mean_squaremean_instances()
{
    std::vector<DeviceGemmAddAddMeanSquareMeanPtr> op_ptrs;

    if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                 is_same<CDataType, half_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
                    op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
                    op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
                    op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
                    op_ptrs);
        }
    }

    return op_ptrs;
}

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_splitk_instance.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using DeviceGemmSplitKNoOpPtr = ck::tensor_operation::device::DeviceGemmSplitKPtr<
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>;

void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmSplitKNoOpPtr>&);

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmSplitKNoOpPtr>&);

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
auto get_device_gemm_splitk_instances()
{
    std::vector<DeviceGemmSplitKNoOpPtr> op_ptrs;

    if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
                 is_same<CDataType, float>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                      is_same<CDataType, half_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
        }
    }

    return op_ptrs;
}

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/utility/check_err.hpp

@@ -159,7 +159,7 @@ check_err(const std::vector<T>& out,
           const std::vector<T>& ref,
           const std::string& msg = "Error: Incorrect results!",
           double                 = 0,
-          double                 = 0)
+          double atol            = 0)
 {
     if(out.size() != ref.size())
     {
@@ -179,7 +179,7 @@ check_err(const std::vector<T>& out,
         int64_t r = ref[i];
         err       = std::abs(o - r);
-        if(err > 0)
+        if(err > atol)
         {
             max_err = err > max_err ? err : max_err;
             err_count++;
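The second hunk is the functional part of this change: the integer specialization of check_err now compares each element-wise error against the caller-supplied absolute tolerance instead of the literal 0, so small integer deviations can be accepted. A hedged usage sketch (assuming check_err lives in ck::utils and returns bool, as this header's other overloads do):

// Hypothetical sketch of the new tolerance: with atol = 1 the single
// off-by-one element no longer counts as an error.
#include <cstdint>
#include <vector>
#include "ck/library/utility/check_err.hpp"

int main()
{
    std::vector<int32_t> ref{1, 2, 3, 4};
    std::vector<int32_t> out{1, 2, 3, 5}; // last element differs by 1

    // Before this change the trailing tolerance parameters were accepted but
    // the comparison was hard-coded to err > 0; now err > atol is used.
    bool strict  = ck::utils::check_err(out, ref, "strict check");        // reports a mismatch, returns false
    bool relaxed = ck::utils::check_err(out, ref, "relaxed check", 0, 1); // passes with atol = 1

    return (!strict && relaxed) ? 0 : 1;
}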
library/src/host_tensor/host_tensor.cpp

@@ -54,25 +54,3 @@ std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
     return os;
 }
-
-void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os)
-{
-    os << "dim " << desc.GetNumOfDimension() << ", ";
-
-    os << "lengths {";
-    LogRange(os, desc.GetLengths(), ", ");
-    os << "}, ";
-
-    os << "strides {";
-    LogRange(os, desc.GetStrides(), ", ");
-    os << "}" << std::endl;
-}
-
-#if 1
-// FIXME: remove
-void bf16_to_f32_(const Tensor<ck::bhalf_t>& src, Tensor<float>& dst)
-{
-    for(std::size_t i = 0; i < src.mData.size(); ++i)
-        dst.mData[i] = ck::type_convert<float>(src.mData[i]);
-}
-#endif
library/src/tensor_operation_instance/gpu/CMakeLists.txt

@@ -5,44 +5,51 @@ function(add_instance_library INSTANCE_NAME)
    set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endfunction(add_instance_library INSTANCE_NAME)

add_subdirectory(elementwise)
add_subdirectory(gemm)
add_subdirectory(gemm_splitk)
add_subdirectory(gemm_bias2d)
add_subdirectory(gemm_bias_relu)
add_subdirectory(gemm_bias_relu_add)
add_subdirectory(gemm_reduce)
add_subdirectory(gemm_bias_add_reduce)
add_subdirectory(gemm_add_add_fastgelu)
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(grouped_gemm)
add_subdirectory(conv1d_fwd)
add_subdirectory(conv2d_fwd)
add_subdirectory(conv3d_fwd)
add_subdirectory(conv2d_fwd_bias_relu)
add_subdirectory(conv2d_fwd_bias_relu_add)
add_subdirectory(conv2d_bwd_data)
add_subdirectory(convnd_bwd_data)
add_subdirectory(conv2d_bwd_weight)
add_subdirectory(normalization)
add_subdirectory(reduce)

add_library(device_operations STATIC
    $<TARGET_OBJECTS:device_conv1d_fwd_instance>
    $<TARGET_OBJECTS:device_batched_gemm_instance>
    $<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
    $<TARGET_OBJECTS:device_conv2d_fwd_instance>
    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
    $<TARGET_OBJECTS:device_gemm_instance>
    $<TARGET_OBJECTS:device_gemm_splitk_instance>
    $<TARGET_OBJECTS:device_gemm_bias_relu_instance>
    $<TARGET_OBJECTS:device_gemm_bias_relu_add_instance>
    $<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
    $<TARGET_OBJECTS:device_gemm_bias2d_instance>
    $<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
    $<TARGET_OBJECTS:device_convnd_bwd_data_instance>
    $<TARGET_OBJECTS:device_grouped_gemm_instance>
    $<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
    $<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
    $<TARGET_OBJECTS:device_conv3d_fwd_instance>
    $<TARGET_OBJECTS:device_elementwise_instance>
    $<TARGET_OBJECTS:device_reduce_instance>
)
add_library(composablekernels::device_operations ALIAS device_operations)

@@ -67,8 +74,8 @@ target_include_directories(device_operations PUBLIC
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/thread>
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/element>
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/host_tensor>
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance>
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu>
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/reduce>
)