Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5d015452
Commit
5d015452
authored
Jul 06, 2022
by
Chaitanya Inumella
Browse files
Rebased the hipTENSOR development branch with the contraction branch
parents
b7fa6bb1
ed3feb4d
Changes
425
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
939 additions
and
525 deletions
+939
-525
library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r3.hpp
...brary/obselete_driver_offline/driver_gemm_xdlops_v2r3.hpp
+0
-220
library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r4.hpp
...brary/obselete_driver_offline/driver_gemm_xdlops_v2r4.hpp
+0
-213
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
...reference_tensor_operation/cpu/reference_batched_gemm.hpp
+13
-11
library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp
...ibrary/reference_tensor_operation/cpu/reference_cgemm.hpp
+7
-27
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp
...e_tensor_operation/cpu/reference_conv_backward_weight.hpp
+6
-2
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
...eference_tensor_operation/cpu/reference_conv_bwd_data.hpp
+10
-8
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
...ary/reference_tensor_operation/cpu/reference_conv_fwd.hpp
+5
-3
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp
...nsor_operation/cpu/reference_conv_fwd_bias_activation.hpp
+7
-5
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp
..._operation/cpu/reference_conv_fwd_bias_activation_add.hpp
+7
-5
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+14
-8
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp
...reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp
+9
-7
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp
...e_tensor_operation/cpu/reference_gemm_bias_activation.hpp
+8
-5
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp
...nsor_operation/cpu/reference_gemm_bias_activation_add.hpp
+8
-5
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp
...ference_tensor_operation/cpu/reference_gemm_layernorm.hpp
+236
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp
...rary/reference_tensor_operation/cpu/reference_softmax.hpp
+164
-0
library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp
...library/reference_tensor_operation/gpu/naive_conv_fwd.hpp
+3
-0
library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp
...nsor_operation_instance/add_device_operation_instance.hpp
+15
-6
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+40
-0
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp
...ck/library/tensor_operation_instance/gpu/batched_gemm.hpp
+259
-0
library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp
...ry/tensor_operation_instance/gpu/contraction_bilinear.hpp
+128
-0
No files found.
Too many changes to show.
To preserve performance only
425 of 425+
files are displayed.
Plain diff
Email patch
library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r3.hpp
deleted
100644 → 0
View file @
b7fa6bb1
#ifndef DRIVER_GEMM_XDLOPS_V2R3_HPP
#define DRIVER_GEMM_XDLOPS_V2R3_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "element_wise_operation.hpp"
/// Host-side driver for the v2r3 XDLOPS gridwise GEMM.
///
/// Steps performed, in order:
///   1. prints the lengths of the A, B and C grid descriptors to std::cout;
///   2. throws std::runtime_error when GridwiseGemm::CheckValidity rejects the
///      descriptors / M01 / N01 combination;
///   3. derives the C descriptor, the block-to-C-tile map and the grid size from
///      the GridwiseGemm helpers;
///   4. instantiates kernel_gemm_xdlops_v2r3 with its last template argument set
///      to the result of GridwiseGemm::CalculateHasMainK0BlockLoop(K0) and hands
///      it to launch_and_time_kernel.
///
/// @param p_a_grid, p_b_grid, p_c_grid  grid pointers forwarded to the kernel.
/// @param a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n
///                                      descriptors forwarded to the kernel / helpers.
/// @param M01, N01                      forwarded to CheckValidity and
///                                      MakeDefaultBlock2CTileMap.
/// @param nrepeat                       forwarded to launch_and_time_kernel.
/// The five unnamed "hack" parameters are accepted but never read here.
///
/// @return whatever launch_and_time_kernel returns, converted to float
///         (presumably the averaged kernel time -- confirm against that helper).
template <ck::index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K,
          typename CMNGridDesc,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t K1,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadSliceLengths_K0_M_K1,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K0_N_K1,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks,
          typename BGridStepHacks,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks,
          bool CAccessOrderMRepeatNRepeat,
          bool ABlockLdsAddExtraM,
          bool BBlockLdsAddExtraN>
__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
                                       const FloatAB* p_b_grid,
                                       FloatC* p_c_grid,
                                       const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                                       const BGridDesc_K0_N_K& b_grid_desc_k0_n_k1,
                                       const CMNGridDesc& c_grid_desc_m_n,
                                       ck::index_t M01,
                                       ck::index_t N01,
                                       AGridStepHacks,
                                       BGridStepHacks,
                                       CGridStepHacks,
                                       AGridMoveSliceWindowStepHacks,
                                       BGridMoveSliceWindowStepHacks,
                                       ck::index_t nrepeat)
{
    using namespace ck;

    // Compile-time dimension indices used with GetLength().
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    // The same pass-through functor type is used for the A, B and C element-wise slots.
    using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;

    using Gemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
                                                         FloatAB,
                                                         FloatAcc,
                                                         FloatC,
                                                         CGlobalMemoryDataOperation,
                                                         AGridDesc_K0_M_K1,
                                                         BGridDesc_K0_N_K,
                                                         CMNGridDesc,
                                                         PassThroughOp,
                                                         PassThroughOp,
                                                         PassThroughOp,
                                                         MPerBlock,
                                                         NPerBlock,
                                                         KPerBlock,
                                                         MPerXDL,
                                                         NPerXDL,
                                                         K1,
                                                         MRepeat,
                                                         NRepeat,
                                                         ABlockTransferThreadClusterLengths_K0_M_K1,
                                                         ABlockTransferThreadClusterArrangeOrder,
                                                         ABlockTransferSrcAccessOrder,
                                                         ABlockTransferSrcVectorDim,
                                                         ABlockTransferSrcScalarPerVector,
                                                         ABlockTransferDstScalarPerVector_K1,
                                                         AThreadTransferSrcResetCoordinateAfterRun,
                                                         ABlockLdsAddExtraM,
                                                         BBlockTransferThreadClusterLengths_K0_N_K1,
                                                         BBlockTransferThreadClusterArrangeOrder,
                                                         BBlockTransferSrcAccessOrder,
                                                         BBlockTransferSrcVectorDim,
                                                         BBlockTransferSrcScalarPerVector,
                                                         BBlockTransferDstScalarPerVector_K1,
                                                         BThreadTransferSrcResetCoordinateAfterRun,
                                                         BBlockLdsAddExtraN,
                                                         CThreadTransferSrcDstAccessOrder,
                                                         CThreadTransferSrcDstVectorDim,
                                                         CThreadTransferDstScalarPerVector>;

    // Diagnostic dump of the problem sizes (one flushed line per descriptor).
    std::cout << "a_grid_desc_k0_m_k1{" << a_grid_desc_k0_m_k1.GetLength(I0) << ", "
              << a_grid_desc_k0_m_k1.GetLength(I1) << ", " << a_grid_desc_k0_m_k1.GetLength(I2)
              << "}" << std::endl;

    std::cout << "b_grid_desc_k0_n_k1{" << b_grid_desc_k0_n_k1.GetLength(I0) << ", "
              << b_grid_desc_k0_n_k1.GetLength(I1) << ", " << b_grid_desc_k0_n_k1.GetLength(I2)
              << "}" << std::endl;

    std::cout << "c_grid_desc_m_n{ " << c_grid_desc_m_n.GetLength(I0) << ", "
              << c_grid_desc_m_n.GetLength(I1) << "}" << std::endl;

    // Refuse to launch with a configuration the gridwise GEMM does not accept.
    const bool config_ok =
        Gemm::CheckValidity(a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n, M01, N01);
    if(!config_ok)
    {
        throw std::runtime_error(
            "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
    }

    // Kernel-side views of C and of the block -> C-tile assignment.
    const auto c_tiled_grid_desc =
        Gemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
    const auto tile_map = Gemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, M01, N01);

    // Descriptor / map types as the kernel template expects them.
    using ADesc    = remove_reference_t<AGridDesc_K0_M_K1>;
    using BDesc    = remove_reference_t<BGridDesc_K0_N_K>;
    using CDesc    = remove_reference_t<decltype(c_tiled_grid_desc)>;
    using CTileMap = remove_reference_t<decltype(tile_map)>;

    const index_t num_blocks = Gemm::CalculateGridSize(c_grid_desc_m_n);

    const auto k0_length     = a_grid_desc_k0_m_k1.GetLength(I0);
    const bool has_main_loop = Gemm::CalculateHasMainK0BlockLoop(k0_length);

    auto pass_through = PassThroughOp{};

    // The main-loop flag is a template argument, so each value needs its own instantiation.
    if(has_main_loop)
    {
        const auto gemm_kernel = kernel_gemm_xdlops_v2r3<Gemm,
                                                         FloatAB,
                                                         FloatC,
                                                         ADesc,
                                                         BDesc,
                                                         CDesc,
                                                         PassThroughOp,
                                                         PassThroughOp,
                                                         PassThroughOp,
                                                         CTileMap,
                                                         true>;

        return launch_and_time_kernel(gemm_kernel,
                                      nrepeat,
                                      dim3(num_blocks),
                                      dim3(BlockSize),
                                      0,
                                      p_a_grid,
                                      p_b_grid,
                                      p_c_grid,
                                      a_grid_desc_k0_m_k1,
                                      b_grid_desc_k0_n_k1,
                                      c_tiled_grid_desc,
                                      pass_through,
                                      pass_through,
                                      pass_through,
                                      tile_map);
    }

    const auto gemm_kernel = kernel_gemm_xdlops_v2r3<Gemm,
                                                     FloatAB,
                                                     FloatC,
                                                     ADesc,
                                                     BDesc,
                                                     CDesc,
                                                     PassThroughOp,
                                                     PassThroughOp,
                                                     PassThroughOp,
                                                     CTileMap,
                                                     false>;

    return launch_and_time_kernel(gemm_kernel,
                                  nrepeat,
                                  dim3(num_blocks),
                                  dim3(BlockSize),
                                  0,
                                  p_a_grid,
                                  p_b_grid,
                                  p_c_grid,
                                  a_grid_desc_k0_m_k1,
                                  b_grid_desc_k0_n_k1,
                                  c_tiled_grid_desc,
                                  pass_through,
                                  pass_through,
                                  pass_through,
                                  tile_map);
}
#endif
library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r4.hpp
deleted
100644 → 0
View file @
b7fa6bb1
#ifndef DRIVER_GEMM_XDLOPS_V2R4
#define DRIVER_GEMM_XDLOPS_V2R4
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4.hpp"
/// Host-side driver for the v2r4 XDLOPS gridwise GEMM, whose A/B descriptors carry
/// a leading dimension that is read here as KBatch.
///
/// Steps performed, in order:
///   1. prints the lengths of the A, B and C grid descriptors to std::cout;
///   2. throws std::runtime_error when GridwiseGemm::CheckValidity rejects the
///      descriptors / M01 / N01 combination;
///   3. derives the C descriptor, the C block-cluster adaptor and the grid size
///      (the latter two also depend on KBatch = length of A's dimension 0) and
///      prints the grid size;
///   4. instantiates kernel_gemm_xdlops_v2r4 with its last template argument set
///      to the result of GridwiseGemm::CalculateHasMainK0BlockLoop(K0), where K0
///      is the length of A's dimension 1, and hands it to launch_and_time_kernel.
///
/// @param p_a_grid, p_b_grid, p_c_grid  grid pointers forwarded to the kernel.
/// @param a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc
///                                      descriptors forwarded to the kernel / helpers.
/// @param M01, N01                      forwarded to CheckValidity and
///                                      MakeCBlockClusterAdaptor.
/// @param nrepeat                       forwarded to launch_and_time_kernel.
/// The five unnamed "hack" parameters are accepted but never read here.
///
/// @return whatever launch_and_time_kernel returns, converted to float
///         (presumably the averaged kernel time -- confirm against that helper).
template <ck::index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename ABK0MK1GridDesc,
          typename BBK0NK1GridDesc,
          typename CMNGridDesc,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t K1,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadSliceLengths_K0_M_K1,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K0_N_K1,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks,
          typename BGridStepHacks,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks,
          bool CAccessOrderMRepeatNRepeat,
          bool ABlockLdsAddExtraM,
          bool BBlockLdsAddExtraN>
__host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
                                       const FloatAB* p_b_grid,
                                       FloatC* p_c_grid,
                                       const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
                                       const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
                                       const CMNGridDesc& c_m_n_grid_desc,
                                       ck::index_t M01,
                                       ck::index_t N01,
                                       AGridStepHacks,
                                       BGridStepHacks,
                                       CGridStepHacks,
                                       AGridMoveSliceWindowStepHacks,
                                       BGridMoveSliceWindowStepHacks,
                                       ck::index_t nrepeat)
{
    using namespace ck;

    // Compile-time dimension indices used with GetLength().
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    using Gemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4<BlockSize,
                                                           FloatAB,
                                                           FloatAcc,
                                                           FloatC,
                                                           CGlobalMemoryDataOperation,
                                                           ABK0MK1GridDesc,
                                                           BBK0NK1GridDesc,
                                                           CMNGridDesc,
                                                           MPerBlock,
                                                           NPerBlock,
                                                           KPerBlock,
                                                           MPerXDL,
                                                           NPerXDL,
                                                           K1,
                                                           MRepeat,
                                                           NRepeat,
                                                           ABlockTransferThreadSliceLengths_K0_M_K1,
                                                           ABlockTransferThreadClusterLengths_K0_M_K1,
                                                           ABlockTransferThreadClusterArrangeOrder,
                                                           ABlockTransferSrcAccessOrder,
                                                           ABlockTransferSrcVectorDim,
                                                           ABlockTransferSrcScalarPerVector,
                                                           ABlockTransferDstScalarPerVector_K1,
                                                           AThreadTransferSrcResetCoordinateAfterRun,
                                                           BBlockTransferThreadSliceLengths_K0_N_K1,
                                                           BBlockTransferThreadClusterLengths_K0_N_K1,
                                                           BBlockTransferThreadClusterArrangeOrder,
                                                           BBlockTransferSrcAccessOrder,
                                                           BBlockTransferSrcVectorDim,
                                                           BBlockTransferSrcScalarPerVector,
                                                           BBlockTransferDstScalarPerVector_K1,
                                                           BThreadTransferSrcResetCoordinateAfterRun,
                                                           CThreadTransferSrcDstAccessOrder,
                                                           CThreadTransferSrcDstVectorDim,
                                                           CThreadTransferDstScalarPerVector,
                                                           AGridStepHacks,
                                                           BGridStepHacks,
                                                           CGridStepHacks,
                                                           AGridMoveSliceWindowStepHacks,
                                                           BGridMoveSliceWindowStepHacks,
                                                           CAccessOrderMRepeatNRepeat,
                                                           ABlockLdsAddExtraM,
                                                           BBlockLdsAddExtraN>;

    // Diagnostic dump of the problem sizes (one flushed line per descriptor).
    std::cout << "a_b_k0_m_k1_grid_desc{" << a_b_k0_m_k1_grid_desc.GetLength(I0) << ", "
              << a_b_k0_m_k1_grid_desc.GetLength(I1) << ", "
              << a_b_k0_m_k1_grid_desc.GetLength(I2) << ", "
              << a_b_k0_m_k1_grid_desc.GetLength(I3) << "}" << std::endl;

    std::cout << "b_b_k0_n_k1_grid_desc{" << b_b_k0_n_k1_grid_desc.GetLength(I0) << ", "
              << b_b_k0_n_k1_grid_desc.GetLength(I1) << ", "
              << b_b_k0_n_k1_grid_desc.GetLength(I2) << ", "
              << b_b_k0_n_k1_grid_desc.GetLength(I3) << "}" << std::endl;

    std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
              << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;

    // Refuse to launch with a configuration the gridwise GEMM does not accept.
    const bool config_ok = Gemm::CheckValidity(
        a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01);
    if(!config_ok)
    {
        throw std::runtime_error(
            "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r4 has invalid setting");
    }

    const auto c_tiled_grid_desc = Gemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);

    // Dimension 0 of the A descriptor is the batch count along K.
    const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);

    const auto cluster_adaptor = Gemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01, KBatch);

    // Descriptor / adaptor types as the kernel template expects them.
    using ADesc    = remove_reference_t<ABK0MK1GridDesc>;
    using BDesc    = remove_reference_t<BBK0NK1GridDesc>;
    using CDesc    = remove_reference_t<decltype(c_tiled_grid_desc)>;
    using CCluster = remove_reference_t<decltype(cluster_adaptor)>;

    const index_t num_blocks = Gemm::CalculateGridSize(c_m_n_grid_desc, KBatch);

    std::cout << "gridSize : " << num_blocks << std::endl;

    const auto k0_length     = a_b_k0_m_k1_grid_desc.GetLength(I1);
    const bool has_main_loop = Gemm::CalculateHasMainK0BlockLoop(k0_length);

    // The main-loop flag is a template argument, so each value needs its own instantiation.
    if(has_main_loop)
    {
        const auto gemm_kernel =
            kernel_gemm_xdlops_v2r4<Gemm, FloatAB, FloatC, ADesc, BDesc, CDesc, CCluster, true>;

        return launch_and_time_kernel(gemm_kernel,
                                      nrepeat,
                                      dim3(num_blocks),
                                      dim3(BlockSize),
                                      0,
                                      p_a_grid,
                                      p_b_grid,
                                      p_c_grid,
                                      a_b_k0_m_k1_grid_desc,
                                      b_b_k0_n_k1_grid_desc,
                                      c_tiled_grid_desc,
                                      cluster_adaptor);
    }

    const auto gemm_kernel =
        kernel_gemm_xdlops_v2r4<Gemm, FloatAB, FloatC, ADesc, BDesc, CDesc, CCluster, false>;

    return launch_and_time_kernel(gemm_kernel,
                                  nrepeat,
                                  dim3(num_blocks),
                                  dim3(BlockSize),
                                  0,
                                  p_a_grid,
                                  p_b_grid,
                                  p_c_grid,
                                  a_b_k0_m_k1_grid_desc,
                                  b_b_k0_n_k1_grid_desc,
                                  c_tiled_grid_desc,
                                  cluster_adaptor);
}
#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
View file @
5d015452
#ifndef REFERENCE_BATCHED_GEMM_HPP
// SPDX-License-Identifier: MIT
#define REFERENCE_BATCHED_GEMM_HPP
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -59,20 +62,20 @@ struct ReferenceBatchedGemm : public device::BaseOperator
...
@@ -59,20 +62,20 @@ struct ReferenceBatchedGemm : public device::BaseOperator
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
float
v_a
;
ADataType
v_a
;
float
v_b
;
BDataType
v_b
;
arg
.
a_element_op_
(
v_a
,
static_cast
<
const
float
>
(
arg
.
a_g_m_k_
(
g
,
m
,
k
))
)
;
arg
.
a_element_op_
(
v_a
,
arg
.
a_g_m_k_
(
g
,
m
,
k
));
arg
.
b_element_op_
(
v_b
,
static_cast
<
const
float
>
(
arg
.
b_g_k_n_
(
g
,
k
,
n
))
)
;
arg
.
b_element_op_
(
v_b
,
arg
.
b_g_k_n_
(
g
,
k
,
n
));
v_acc
+=
v_a
*
v_b
;
v_acc
+=
ck
::
type_convert
<
float
>
(
v_a
)
*
ck
::
type_convert
<
float
>
(
v_b
)
;
}
}
float
v_c
;
float
v_c
;
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_g_m_n_
(
g
,
m
,
n
)
=
v_c
;
arg
.
c_g_m_n_
(
g
,
m
,
n
)
=
ck
::
type_convert
<
CDataType
>
(
v_c
)
;
};
};
make_ParallelTensorFunctor
(
f_gmk_gkn_gmn
,
make_ParallelTensorFunctor
(
f_gmk_gkn_gmn
,
...
@@ -132,4 +135,3 @@ struct ReferenceBatchedGemm : public device::BaseOperator
...
@@ -132,4 +135,3 @@ struct ReferenceBatchedGemm : public device::BaseOperator
}
// namespace host
}
// namespace host
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp
View file @
5d015452
/*******************************************************************************
// SPDX-License-Identifier: MIT
*
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#pragma once
#pragma once
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp
View file @
5d015452
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
View file @
5d015452
#ifndef REFERENCE_CONV_BWD_DATA_HPP
// SPDX-License-Identifier: MIT
#define REFERENCE_CONV_BWD_DATA_HPP
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -106,9 +110,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
...
@@ -106,9 +110,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
}
}
float
v_in
;
arg
.
in_element_op_
(
v_acc
,
v_acc
);
arg
.
in_element_op_
(
v_in
,
v_acc
);
arg
.
input_
(
n
,
c
,
wi
)
=
ck
::
type_convert
<
InDataType
>
(
v_acc
);
arg
.
input_
(
n
,
c
,
wi
)
=
ck
::
type_convert
<
InDataType
>
(
v_in
);
};
};
make_ParallelTensorFunctor
(
f_ncw
,
make_ParallelTensorFunctor
(
f_ncw
,
...
@@ -352,4 +355,3 @@ struct ReferenceConvBwdData : public device::BaseOperator
...
@@ -352,4 +355,3 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
// namespace host
}
// namespace host
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
View file @
5d015452
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <iostream>
#include <iostream>
#include <type_traits>
#include <type_traits>
#include <sstream>
#include <sstream>
#include "stream_config.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "host_tensor.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp
View file @
5d015452
#ifndef REFERENCE_CONV_FWD_BIAS_ACTIVATION_HPP
// SPDX-License-Identifier: MIT
#define REFERENCE_CONV_FWD_BIAS_ACTIVATION_HPP
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -187,4 +190,3 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
...
@@ -187,4 +190,3 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
}
// namespace host
}
// namespace host
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp
View file @
5d015452
#ifndef REFERENCE_CONV2D_FWD_BIAS_ACTIVATION_ADD_HPP
// SPDX-License-Identifier: MIT
#define REFERENCE_CONV2D_FWD_BIAS_ACTIVATION_ADD_HPP
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -195,4 +198,3 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
...
@@ -195,4 +198,3 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
}
// namespace host
}
// namespace host
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
5d015452
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -58,20 +63,21 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -58,20 +63,21 @@ struct ReferenceGemm : public device::BaseOperator
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
A
cc
DataType
v_a
;
ADataType
v_a
;
Acc
DataType
v_b
;
B
DataType
v_b
;
arg
.
a_element_op_
(
v_a
,
static_cast
<
const
AccDataType
>
(
arg
.
a_m_k_
(
m
,
k
))
)
;
arg
.
a_element_op_
(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
arg
.
b_element_op_
(
v_b
,
static_cast
<
const
AccDataType
>
(
arg
.
b_k_n_
(
k
,
n
))
)
;
arg
.
b_element_op_
(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
v_acc
+=
v_a
*
v_b
;
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
}
}
AccDataType
v_c
;
AccDataType
v_c
;
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_m_n_
(
m
,
n
)
=
v_c
;
arg
.
c_m_n_
(
m
,
n
)
=
ck
::
type_convert
<
CDataType
>
(
v_c
)
;
};
};
make_ParallelTensorFunctor
(
make_ParallelTensorFunctor
(
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp
View file @
5d015452
#ifndef REFERENCE_GEMM_BIAS_BIAS_2D_HPP
// SPDX-License-Identifier: MIT
#define REFERENCE_GEMM_BIAS_BIAS_2D_HPP
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -66,8 +69,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator
...
@@ -66,8 +69,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
arg
.
a_element_op_
(
a
,
arg
.
a_m_k_
(
m
,
k
));
arg
.
a_element_op_
(
a
,
ck
::
type_convert
<
AccDataType
>
(
arg
.
a_m_k_
(
m
,
k
))
)
;
arg
.
b_element_op_
(
b
,
arg
.
b_k_n_
(
k
,
n
));
arg
.
b_element_op_
(
b
,
ck
::
type_convert
<
AccDataType
>
(
arg
.
b_k_n_
(
k
,
n
))
)
;
acc
+=
a
*
b
;
acc
+=
a
*
b
;
}
}
...
@@ -131,4 +134,3 @@ struct ReferenceGemmBias2D : public device::BaseOperator
...
@@ -131,4 +134,3 @@ struct ReferenceGemmBias2D : public device::BaseOperator
}
// namespace host
}
// namespace host
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp
View file @
5d015452
#ifndef REFERENCE_GEMM_BIAS_ACTIVATION_HPP
// SPDX-License-Identifier: MIT
#define REFERENCE_GEMM_BIAS_ACTIVATION_HPP
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -134,4 +138,3 @@ struct ReferenceGemmBiasActivation : public device::BaseOperator
...
@@ -134,4 +138,3 @@ struct ReferenceGemmBiasActivation : public device::BaseOperator
}
// namespace host
}
// namespace host
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp
View file @
5d015452
#ifndef REFERENCE_GEMM_BIAS_ACTIVATION_ADD_HPP
// SPDX-License-Identifier: MIT
#define REFERENCE_GEMM_BIAS_ACTIVATION_ADD_HPP
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -142,4 +146,3 @@ struct ReferenceGemmBiasActivationAdd : public device::BaseOperator
...
@@ -142,4 +146,3 @@ struct ReferenceGemmBiasActivationAdd : public device::BaseOperator
}
// namespace host
}
// namespace host
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp
0 → 100644
View file @
5d015452
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
host
{
// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
C0DataType
,
typename
AccDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
CElementwiseOperation
>
struct
ReferenceGemmLayernorm
:
public
device
::
BaseOperator
{
// ReferenceGemm specialization used by this struct: AccDataType is supplied for both the
// third and fourth type arguments and the last functor slot is element_wise::PassThrough
// (presumably so the GEMM output stays in accumulator precision and is left untouched
// before the later layernorm stage -- confirm against ReferenceGemm's parameter order).
using
ReferenceGemmInstance
=
ReferenceGemm
<
ADataType
,
BDataType
,
AccDataType
,
AccDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
element_wise
::
PassThrough
>
;
/// Row-wise layer normalization followed by a per-column affine transform.
///
/// For every row i of `acc` the mean E[x] and the mean of squares E[x^2] are taken over
/// the N columns; each element then becomes
///     (x - E[x]) / sqrt(E[x^2] - E[x]^2 + epsilon)
/// and is finally scaled by gamma(j) and shifted by beta(j) for its column j.
/// All arithmetic is done on a ComputeDataType copy of `acc`; `acc` itself is not modified.
///
/// @param result   receives the normalized tensor converted via CopyAsType<OutDataType>().
/// @param acc      input tensor, MxN.
/// @param gamma    per-column scale; its first length must equal N (checked by assert).
/// @param beta     per-column shift; its first length must equal N (checked by assert).
/// @param epsilon  added to the variance before the square root.
template <typename InDataType, typename OutDataType, typename ComputeDataType>
static void RunLayernorm(Tensor<OutDataType>& result,
                         const Tensor<ComputeDataType>& acc, // MxN
                         const Tensor<InDataType>& gamma,    // 1xN
                         const Tensor<InDataType>& beta,     // 1xN
                         const InDataType epsilon = 1e-5)
{
    assert(acc.mDesc.GetLengths()[1] == gamma.mDesc.GetLengths()[0] &&
           acc.mDesc.GetLengths()[1] == beta.mDesc.GetLengths()[0]);

    const size_t M = acc.mDesc.GetLengths()[0];
    const size_t N = acc.mDesc.GetLengths()[1];

    // Per-row statistics: E[x^2] and E[x].
    Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
    Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));

    // Working copy that is normalized in place.
    Tensor<ComputeDataType> acc_layernorm(acc);

    // reduce N dim
    for(size_t row = 0; row < M; ++row)
    {
        ComputeDataType sum_acc_sq = 0;
        ComputeDataType sum_acc    = 0;

        for(size_t col = 0; col < N; ++col)
        {
            sum_acc_sq += acc_layernorm(row, col) * acc_layernorm(row, col);
            sum_acc += acc_layernorm(row, col);
        }

        avg_acc_sq(row) = sum_acc_sq / N;
        avg_acc(row)    = sum_acc / N;
    }

    // normalize: the result is stored back before the affine pass so that any rounding to
    // ComputeDataType happens between the two stages.
    const auto normalize = [&](auto& self, auto idx) {
        self(idx[0], idx[1]) =
            (self(idx[0], idx[1]) - avg_acc(idx[0])) /
            sqrt(avg_acc_sq(idx[0]) - avg_acc(idx[0]) * avg_acc(idx[0]) + epsilon);
    };
    acc_layernorm.ForEach(normalize);

    // affine
    const auto scale_and_shift = [&](auto& self, auto idx) {
        self(idx[0], idx[1]) = self(idx[0], idx[1]) * gamma(idx[1]) + beta(idx[1]);
    };
    acc_layernorm.ForEach(scale_and_shift);

    // cast
    result = acc_layernorm.template CopyAsType<OutDataType>();
}
// Argument
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
Tensor
<
ADataType
>&
a_m_k
,
const
Tensor
<
BDataType
>&
b_k_n
,
Tensor
<
CDataType
>&
c_m_n
,
const
Tensor
<
C0DataType
>&
c0_n_bias
,
// 1xN
const
Tensor
<
C0DataType
>&
c0_m_n_add
,
// MxN
const
Tensor
<
C0DataType
>&
c0_n_gamma
,
// 1xN
const
Tensor
<
C0DataType
>&
c0_n_beta
,
// 1xN
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
CElementwiseOperation
c_element_op
,
const
CDataType
epsilon
=
1e-5
)
:
a_m_k_
{
a_m_k
},
b_k_n_
{
b_k_n
},
c_m_n_
{
c_m_n
},
c0_n_bias_
{
c0_n_bias
},
c0_m_n_add_
{
c0_m_n_add
},
c0_n_gamma_
{
c0_n_gamma
},
c0_n_beta_
{
c0_n_beta
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
acc_element_op_
{
acc_element_op
},
c_element_op_
{
c_element_op
},
epsilon_
{
epsilon
}
{
}
const
Tensor
<
ADataType
>&
a_m_k_
;
const
Tensor
<
BDataType
>&
b_k_n_
;
Tensor
<
CDataType
>&
c_m_n_
;
const
Tensor
<
C0DataType
>&
c0_n_bias_
;
const
Tensor
<
C0DataType
>&
c0_m_n_add_
;
const
Tensor
<
C0DataType
>&
c0_n_gamma_
;
const
Tensor
<
C0DataType
>&
c0_n_beta_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
AccElementwiseOperation
acc_element_op_
;
CElementwiseOperation
c_element_op_
;
const
CDataType
epsilon_
;
};
// Invoker
struct
Invoker
:
public
device
::
BaseInvoker
{
// using Argument = ReferenceGemm::Argument;
float
Run
(
const
Argument
&
arg
)
{
Tensor
<
AccDataType
>
acc_m_n
(
arg
.
c_m_n_
.
mDesc
);
acc_m_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
AccDataType
>
{
0
});
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
arg
.
a_m_k_
,
arg
.
b_k_n_
,
acc_m_n
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
element_wise
::
PassThrough
{});
// gemm
ref_invoker
.
Run
(
ref_argument
);
// activation(acc + bias)
acc_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
AccDataType
out
;
arg
.
acc_element_op_
(
out
,
acc_m_n
(
idx
[
0
],
idx
[
1
])
+
arg
.
c0_n_bias_
(
idx
[
1
]));
self
(
idx
[
0
],
idx
[
1
])
=
out
;
});
// add from other layers
acc_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
[
0
],
idx
[
1
])
+=
arg
.
c0_m_n_add_
(
idx
[
0
],
idx
[
1
]);
});
// layernorm
RunLayernorm
(
arg
.
c_m_n_
,
acc_m_n
,
arg
.
c0_n_gamma_
,
arg
.
c0_n_beta_
);
// elementwise op
arg
.
c_m_n_
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
arg
.
c_element_op_
(
self
(
idx
[
0
],
idx
[
1
]),
self
(
idx
[
0
],
idx
[
1
]));
});
return
0
;
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
)
override
{
return
true
;
}
static
auto
MakeArgument
(
const
Tensor
<
ADataType
>&
a_m_k
,
const
Tensor
<
BDataType
>&
b_k_n
,
Tensor
<
CDataType
>&
c_m_n
,
const
Tensor
<
C0DataType
>&
c0_n_bias
,
// 1xN
const
Tensor
<
C0DataType
>&
c0_m_n_add
,
// 1xN
const
Tensor
<
C0DataType
>&
c0_n_gamma
,
// 1xN
const
Tensor
<
C0DataType
>&
c0_n_beta
,
// 1xN
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
CElementwiseOperation
c_element_op
,
const
CDataType
epsilon
=
1e-5
)
{
return
Argument
{
a_m_k
,
b_k_n
,
c_m_n
,
c0_n_bias
,
c0_m_n_add
,
c0_n_gamma
,
c0_n_beta
,
a_element_op
,
b_element_op
,
acc_element_op
,
c_element_op
,
epsilon
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
virtual
std
::
unique_ptr
<
device
::
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"ReferenceGemmLayernorm"
<<
std
::
endl
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace host
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp
0 → 100644
View file @
5d015452
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
host
{
// Host (CPU) reference softmax over a caller-selected set of dimensions:
//
//   out = alpha * exp(in - max(in)) / sum(exp(in - max(in))) + beta * out
//
// where max and sum are taken over the dimensions listed in sm_reduce_dims. Intermediate values
// are computed in AccDataType.
template <typename InDataType, typename OutDataType, typename AccDataType>
struct ReferenceSoftmax : public device::BaseOperator
{
    // Argument
    // Holds references to the input/output tensors (they must outlive the Argument), the
    // scaling factors, and the reduced / non-reduced dimension lists.
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<InDataType>& in,
                 Tensor<OutDataType>& out,
                 AccDataType alpha,
                 AccDataType beta,
                 const std::vector<index_t> sm_reduce_dims)
            : in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims)
        {
            // Every input dimension that is not reduced is a "scalar" dimension, kept in
            // ascending order.
            for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++)
            {
                // compare in index_t to avoid a signed/unsigned mix inside std::find
                if(std::find(sm_reduce_dims.begin(),
                             sm_reduce_dims.end(),
                             static_cast<index_t>(i)) == sm_reduce_dims.end())
                {
                    sm_scalar_dims_.push_back(i);
                }
            }
        }

        const Tensor<InDataType>& in_;
        Tensor<OutDataType>& out_;
        AccDataType alpha_;
        AccDataType beta_;
        std::vector<index_t> sm_reduce_dims_;
        std::vector<index_t> sm_scalar_dims_; // dim after internal max/sum reduction
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        // Computes the softmax on the host and writes it into arg.out_. Always returns 0.
        float Run(const Argument& arg)
        {
            // lengths of the max/sum buffers: one entry per non-reduced dimension
            std::vector<size_t> scalar_lengths;
            for(index_t dim : arg.sm_scalar_dims_)
            {
                scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]);
            }
            // When every dimension is reduced the max/sum collapse to a single value; give the
            // buffers lengths {1} so they stay rank-1 tensors instead of rank-0 ones.
            // NOTE(review): rank-0 host tensors are presumably not handled by the Tensor
            // helpers used below -- confirm against host_tensor.hpp.
            if(scalar_lengths.empty())
            {
                scalar_lengths.push_back(1);
            }

            Tensor<AccDataType> reduce_max(scalar_lengths);
            reduce_max.GenerateTensorValue(
                GeneratorTensor_1<AccDataType>{std::numeric_limits<AccDataType>::lowest()});
            Tensor<AccDataType> reduce_sum(scalar_lengths);
            reduce_sum.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});

            // Projects a full input index onto the non-reduced dimensions.
            auto to_sm_scalar_idx = [&](auto idx) {
                std::vector<size_t> sm_scalar_idx;
                for(index_t dim : arg.sm_scalar_dims_)
                {
                    sm_scalar_idx.push_back(idx[dim]);
                }
                // full reduction: address the single element of the {1}-shaped buffers
                if(sm_scalar_idx.empty())
                {
                    sm_scalar_idx.push_back(0);
                }
                return sm_scalar_idx;
            };

            // max(x) along the reduced dimensions
            arg.in_.ForEach([&](auto& self, auto idx) {
                const auto scalar_idx = to_sm_scalar_idx(idx);
                reduce_max(scalar_idx) =
                    std::max(reduce_max(scalar_idx), static_cast<AccDataType>(self(idx)));
            });

            Tensor<AccDataType> in_stable(arg.in_.mDesc);
            in_stable.ForEach([&](auto& self, auto idx) {
                // numerator = exp(x - max(x))
                self(idx) = std::exp(static_cast<AccDataType>(arg.in_(idx)) -
                                     reduce_max(to_sm_scalar_idx(idx)));
            });

            in_stable.ForEach([&](auto& self, auto idx) {
                // denominator = sum(exp(x - max(x)))
                reduce_sum(to_sm_scalar_idx(idx)) += self(idx);
            });

            // out = alpha * numerator / denominator + beta * out (previous content of out)
            arg.out_.ForEach([&](auto& self, auto idx) {
                self(idx) = arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
                            arg.beta_ * self(idx);
            });

            return 0;
        }

        // Type-erased entry point. p_arg is downcast to Argument and dereferenced without a
        // null check, so it must point to an Argument of this operator.
        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    // No argument validation is performed; every argument is reported as supported.
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    // Convenience factory forwarding all parameters to the Argument constructor.
    static auto MakeArgument(const Tensor<InDataType>& in,
                             Tensor<OutDataType>& out,
                             AccDataType alpha,
                             AccDataType beta,
                             const std::vector<index_t> sm_reduce_dims)
    {
        return Argument{in, out, alpha, beta, sm_reduce_dims};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // Returns the operator name followed by a newline.
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceSoftmax"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
}
// namespace host
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp
View file @
5d015452
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef NAIVE_CONV_FWD_HPP
#ifndef NAIVE_CONV_FWD_HPP
#define NAIVE_CONV_FWD_HPP
#define NAIVE_CONV_FWD_HPP
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp
→
library/include/ck/library/tensor_operation_instance/
add_
device_operation_instance.hpp
View file @
5d015452
#ifndef CK_DEVICE_OPERATION_INSTANCE_HPP
// SPDX-License-Identifier: MIT
#define CK_DEVICE_OPERATION_INSTANCE_HPP
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <stdlib.h>
#pragma once
#include <vector>
#include <type_traits>
#include "ck/utility/functional2.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
template
<
typename
OpInstance
,
typename
NewOpInstances
>
template
<
typename
BaseOp
,
typename
NewOpInstances
>
void
add_device_operation_instances
(
std
::
vector
<
std
::
unique_ptr
<
OpInstance
>>&
op_instances
,
void
add_device_operation_instances
(
std
::
vector
<
std
::
unique_ptr
<
BaseOp
>>&
op_instances
,
const
NewOpInstances
&
new_op_instances
)
const
NewOpInstances
&
new_op_instances
)
{
{
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
NewOpInstances
>
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
NewOpInstances
>
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -16,11 +22,14 @@ void add_device_operation_instances(std::vector<std::unique_ptr<OpInstance>>& op
...
@@ -16,11 +22,14 @@ void add_device_operation_instances(std::vector<std::unique_ptr<OpInstance>>& op
using
NewOpInstance
=
remove_cvref_t
<
decltype
(
new_op_instance
)
>
;
using
NewOpInstance
=
remove_cvref_t
<
decltype
(
new_op_instance
)
>
;
static_assert
(
std
::
is_base_of_v
<
BaseOp
,
NewOpInstance
>
,
"wrong! NewOpInstance should be derived from BaseOp"
);
op_instances
.
push_back
(
std
::
make_unique
<
NewOpInstance
>
(
new_op_instance
));
op_instances
.
push_back
(
std
::
make_unique
<
NewOpInstance
>
(
new_op_instance
));
});
});
}
}
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
0 → 100644
View file @
5d015452
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Short names for the types that recur in device-operation instance declarations.

// GEMM tensor layouts
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

// element-wise operations
using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
using Scale          = ck::tensor_operation::element_wise::Scale;
using Bilinear       = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;

// scalar data types
using F64  = double;
using F32  = float;
using F16  = ck::half_t;
using BF16 = ck::bhalf_t;

// tuples of data types
using EMPTY_TUPLE   = ck::Tuple<>;
using F16_TUPLE     = ck::Tuple<F16>;
using F16_F16_TUPLE = ck::Tuple<F16, F16>;
using F32_TUPLE     = ck::Tuple<F32>;

// Primary template, declared only: it has no definition here, so it is usable only through
// specializations for a concrete device operation type.
template <typename DeviceOp>
struct DeviceOperationInstanceFactory;
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp
0 → 100644
View file @
5d015452
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Declarations of the batched-GEMM instance registration functions; no definitions appear in
// this header. Each function takes, by non-const reference, a vector of owning pointers to
// DeviceBatchedGemm with one fixed combination of layouts and data types; all element-wise
// operations are PassThrough.
// NOTE(review): presumably each function appends its instances to `instances` -- confirm
// against the definitions.
//
// As the template arguments below show, the name suffix tracks the layouts:
//   gkm -> A is Col, gmk -> A is Row, gkn -> B is Row, gnk -> B is Col, and C is always Row.

// bf16 x bf16 -> bf16
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
        instances);

// f16 x f16 -> f16
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

// f32 x f32 -> f32
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

// int8 x int8 -> int8
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemm<Col,
                                                  Row,
                                                  Row,
                                                  int8_t,
                                                  int8_t,
                                                  int8_t,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances);

void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemm<Col,
                                                  Col,
                                                  Row,
                                                  int8_t,
                                                  int8_t,
                                                  int8_t,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances);

void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemm<Row,
                                                  Row,
                                                  Row,
                                                  int8_t,
                                                  int8_t,
                                                  int8_t,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances);

void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemm<Row,
                                                  Col,
                                                  Row,
                                                  int8_t,
                                                  int8_t,
                                                  int8_t,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances);
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceBatchedGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>
{
using
DeviceOp
=
DeviceBatchedGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
is_same_v
<
CDataType
,
float
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
bhalf_t
>
&&
is_same_v
<
BDataType
,
bhalf_t
>
&&
is_same_v
<
CDataType
,
bhalf_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
is_same_v
<
CDataType
,
int8_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp
0 → 100644
View file @
5d015452
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Declarations of the contraction + bilinear instance registration functions; no definitions
// appear in this header. All four share the same parameter type: a vector of owning pointers to
// DeviceContractionMultipleD<2, 2, 2, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough,
// Bilinear>, taken by non-const reference.
// NOTE(review): presumably each function appends its instances to `instances`, and the
// kknn / knnn / mknn / mnnn suffixes select a memory-layout variant -- confirm against the
// definitions.
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance(
    std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
                                                           2,
                                                           2,
                                                           F32,
                                                           F32,
                                                           F32_TUPLE,
                                                           F32,
                                                           PassThrough,
                                                           PassThrough,
                                                           Bilinear>>>& instances);

void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance(
    std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
                                                           2,
                                                           2,
                                                           F32,
                                                           F32,
                                                           F32_TUPLE,
                                                           F32,
                                                           PassThrough,
                                                           PassThrough,
                                                           Bilinear>>>& instances);

void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance(
    std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
                                                           2,
                                                           2,
                                                           F32,
                                                           F32,
                                                           F32_TUPLE,
                                                           F32,
                                                           PassThrough,
                                                           PassThrough,
                                                           Bilinear>>>& instances);

void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance(
    std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
                                                           2,
                                                           2,
                                                           F32,
                                                           F32,
                                                           F32_TUPLE,
                                                           F32,
                                                           PassThrough,
                                                           PassThrough,
                                                           Bilinear>>>& instances);
// Contraction + Bilinear
// Instance factory for a tensor contraction with a single D tensor combined through the
// Bilinear element-wise operation (A and B use PassThrough). Only all-float data with
// NumDimM == NumDimN == NumDimK == 2 has registered instances; any other combination yields
// an empty vector.
template <index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          typename ADataType,
          typename BDataType,
          typename DDataType,
          typename EDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceContractionMultipleD<NumDimM,
                                                             NumDimN,
                                                             NumDimK,
                                                             ADataType,
                                                             BDataType,
                                                             ck::Tuple<DDataType>,
                                                             EDataType,
                                                             PassThrough,
                                                             PassThrough,
                                                             Bilinear>>
{
    using DeviceOp = DeviceContractionMultipleD<NumDimM,
                                                NumDimN,
                                                NumDimK,
                                                ADataType,
                                                BDataType,
                                                ck::Tuple<DDataType>,
                                                EDataType,
                                                PassThrough,
                                                PassThrough,
                                                Bilinear>;

    static auto GetInstances()
    {
        constexpr bool all_f32 = is_same_v<ADataType, F32> && is_same_v<BDataType, F32> &&
                                 is_same_v<DDataType, F32> && is_same_v<EDataType, F32>;
        constexpr bool is_m2_n2_k2 = NumDimM == 2 && NumDimN == 2 && NumDimK == 2;

        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(all_f32 && is_m2_n2_k2)
        {
            // registration order is kept: knnn, kknn, mnnn, mknn
            add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance(
                op_ptrs);
            add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance(
                op_ptrs);
            add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance(
                op_ptrs);
            add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance(
                op_ptrs);
        }

        return op_ptrs;
    }
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
…
13
14
15
16
17
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment