Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c4411860
Commit
c4411860
authored
Aug 30, 2023
by
Jing Zhang
Browse files
Merge remote-tracking branch 'origin/develop' into grouped_gemm_dev_args_splitk
parents
b93575ca
38ada109
Changes
27
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1405 additions
and
31 deletions
+1405
-31
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+3
-0
example/01_gemm/gemm_xdl_bf16_rtn.cpp
example/01_gemm/gemm_xdl_bf16_rtn.cpp
+39
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
...pl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
+5
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
...ion/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
+6
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
...ation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
+4
-0
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
...evice/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+4
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
...n/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
+9
-3
include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp
...pl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp
+4
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
...vice_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+2
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp
.../impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp
+4
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
...sor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
+4
-0
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
...k/tensor_operation/gpu/element/element_wise_operation.hpp
+45
-0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+6
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+19
-22
include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp
...grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp
+1076
-0
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
+2
-2
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+2
-2
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+3
-0
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp
...brary/tensor_operation_instance/gpu/gemm_multiply_add.hpp
+161
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt
...r_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt
+7
-0
No files found.
example/01_gemm/CMakeLists.txt
View file @
c4411860
...
...
@@ -40,6 +40,9 @@ endif()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_xdl_bf16 gemm_xdl_bf16.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16
)
add_example_executable
(
example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16_rtn
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
...
...
example/01_gemm/gemm_xdl_bf16_rtn.cpp
0 → 100644
View file @
c4411860
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
using
ADataType
=
ck
::
bhalf_t
;
using
BDataType
=
ck
::
bhalf_t
;
using
CDataType
=
ck
::
bhalf_t
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
#include "run_gemm_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
View file @
c4411860
...
...
@@ -543,9 +543,13 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
EGridDesc_G_M_N
e_grid_desc_g_m_n_
;
};
using
ComputeDataType
=
ADataType
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
View file @
c4411860
...
...
@@ -331,8 +331,13 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
EGridDesc_G0_G1_M_N
e_grid_desc_g0_g1_m_n_
;
};
using
ComputeDataType
=
ADataType
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
CShuffleDataType
,
ck
::
Tuple
<>
,
// DsDataType,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
View file @
c4411860
...
...
@@ -324,8 +324,12 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
index_t
BatchStrideE_
;
};
using
ComputeDataType
=
ADataType
;
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
BDataType
,
ComputeDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
c4411860
...
...
@@ -310,9 +310,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({{}},
{{}}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
({},
{}));
using
ComputeDataType
=
ADataType
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
BDataType
,
ComputeDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
c4411860
...
...
@@ -20,7 +20,8 @@
namespace
ck
{
template
<
typename
GridwiseGemm
,
typename
ABDataType
,
typename
ADataType
,
typename
BDataType
,
typename
DsPointer
,
typename
EDataType
,
typename
AElementwiseOperation
,
...
...
@@ -36,8 +37,8 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_multiple_d_xdl_cshuffle
(
const
A
B
DataType
*
__restrict__
p_a_grid
,
const
A
BDataType
*
__restrict__
p_b_grid
,
kernel_gemm_multiple_d_xdl_cshuffle
(
const
ADataType
*
__restrict__
p_a_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
const
AElementwiseOperation
a_element_op
,
...
...
@@ -242,9 +243,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
ComputeDataType
=
EDataType
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
BDataType
,
ComputeDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
...
...
@@ -442,6 +447,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
const
auto
kernel
=
kernel_gemm_multiple_d_xdl_cshuffle
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
BDataType
,
// TODO: distiguish A/B datatype
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
AElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp
View file @
c4411860
...
...
@@ -355,9 +355,13 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({{}},
{{}}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
({},
{}));
using
ComputeDataType
=
ADataType
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
BDataType
,
ComputeDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
View file @
c4411860
...
...
@@ -355,6 +355,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ABDataType
,
// TODO: distinguish A/B datatype
ABDataType
,
// TODO: distinguish A/B datatype
ABDataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CShuffleDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp
View file @
c4411860
...
...
@@ -367,9 +367,13 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{}))
>
;
using
ComputeDataType
=
ADataType
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
BDataType
,
ComputeDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
View file @
c4411860
...
...
@@ -228,9 +228,13 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
ComputeDataType
=
ADataType
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
BDataType
,
ComputeDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
...
...
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
View file @
c4411860
...
...
@@ -195,6 +195,51 @@ struct AddMultiply
}
};
// C = A * B
// E = C x D0 + D1
struct
MultiplyAdd
{
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
>
__host__
__device__
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
,
const
D1
&
d1
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
half_t
,
half_t
,
half_t
>
(
half_t
&
e
,
const
half_t
&
c
,
const
half_t
&
d0
,
const
half_t
&
d1
)
const
{
const
half_t
y
=
(
c
*
d0
)
+
d1
;
e
=
y
;
}
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
float
,
half_t
,
half_t
>
(
half_t
&
e
,
const
float
&
c
,
const
half_t
&
d0
,
const
half_t
&
d1
)
const
{
const
half_t
y
=
type_convert
<
half_t
>
(
c
)
*
d0
+
d1
;
e
=
y
;
}
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
,
half_t
,
half_t
>
(
float
&
e
,
const
float
&
c
,
const
half_t
&
d0
,
const
half_t
&
d1
)
const
{
const
float
y
=
c
*
d0
+
d1
;
e
=
y
;
}
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
float
,
float
,
float
>
(
half_t
&
e
,
const
float
&
c
,
const
float
&
d0
,
const
float
&
d1
)
const
{
const
float
y
=
c
*
d0
+
d1
;
e
=
y
;
}
};
// E = FastGelu(C + D0 + D1)
struct
AddAddFastGelu
{
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
c4411860
...
...
@@ -39,6 +39,12 @@ struct PassThrough
y
=
x
;
}
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
float
>
(
half_t
&
y
,
const
float
&
x
)
const
{
y
=
type_convert
<
half_t
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
bhalf_t
>
(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
c4411860
...
...
@@ -29,7 +29,9 @@ namespace ck {
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template
<
typename
ABDataType
,
// FIXME: don't assume A/B have same datatype
template
<
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType_
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
...
...
@@ -97,15 +99,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// ABDataTypeAdjusted -> ABDataType throughout this file
#if CK_WORKAROUND_DENORM_FIX
using
AB
DataType
Adjusted
=
conditional_t
<
is_same_v
<
AB
DataType
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
AB
DataType
>
;
using
Compute
DataType
=
conditional_t
<
is_same_v
<
Compute
DataType
_
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
Compute
DataType
_
>
;
#else
using
AB
DataType
Adjusted
=
AB
DataType
;
using
Compute
DataType
=
Compute
DataType
_
;
#endif
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
...
...
@@ -175,7 +173,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
return
math
::
max
((
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
sizeof
(
AB
DataType
),
sizeof
(
Compute
DataType
),
c_block_size
*
sizeof
(
CShuffleDataType
));
}
...
...
@@ -318,8 +316,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// check tensor size: cannot be larger than 2GB each
constexpr
long_index_t
TwoGB
=
(
long_index_t
{
1
}
<<
31
);
if
(
!
(
a_grid_desc_m_k
.
GetElementSpaceSize
()
*
sizeof
(
A
B
DataType
)
<=
TwoGB
&&
b_grid_desc_n_k
.
GetElementSpaceSize
()
*
sizeof
(
A
BDataType
)
<=
TwoGB
&&
if
(
!
(
a_grid_desc_m_k
.
GetElementSpaceSize
()
*
sizeof
(
ADataType
)
<=
TwoGB
&&
b_grid_desc_n_k
.
GetElementSpaceSize
()
*
sizeof
(
BDataType
)
<=
TwoGB
&&
e_grid_desc_m_n
.
GetElementSpaceSize
()
*
sizeof
(
EDataType
)
<=
TwoGB
))
{
return
false
;
...
...
@@ -431,8 +429,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
typename
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2ETileMap
>
__device__
static
void
Run
(
const
A
B
DataType
*
__restrict__
p_a_grid
,
const
A
BDataType
*
__restrict__
p_b_grid
,
__device__
static
void
Run
(
const
ADataType
*
__restrict__
p_a_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
DsGridPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
void
*
__restrict__
p_shared
,
...
...
@@ -501,8 +499,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
Sequence
<
AK0PerBlock
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
A
B
DataType
,
AB
DataType
Adjusted
,
ADataType
,
Compute
DataType
,
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
ABlockTransferSrcAccessOrder
,
...
...
@@ -532,8 +530,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
Sequence
<
BK0PerBlock
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
A
BDataType
,
AB
DataType
Adjusted
,
BDataType
,
Compute
DataType
,
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
...
...
@@ -563,11 +561,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
AB
DataType
Adjusted
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
MfmaSelector
<
Compute
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
AB
DataType
Adjusted
,
Compute
DataType
,
AccDataType
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
...
@@ -585,11 +583,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ABDataTypeAdjusted
*>
(
p_shared
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
static_cast
<
ComputeDataType
*>
(
p_shared
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
AB
DataType
Adjusted
*>
(
p_shared
)
+
a_block_space_size_aligned
,
static_cast
<
Compute
DataType
*>
(
p_shared
)
+
a_block_space_size_aligned
,
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1
,
0
,
0
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp
0 → 100644
View file @
c4411860
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
View file @
c4411860
...
...
@@ -104,13 +104,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1
// apply pointwise operation
static_for
<
0
,
ScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
Src
Data
v
;
Dst
Data
v
;
// apply element-wise operation
element_op_
(
v
,
src_vector_container
.
template
AsType
<
SrcData
>()[
i
]);
// apply type convert
dst_vector_container
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
v
)
;
dst_vector_container
.
template
AsType
<
DstData
>()(
i
)
=
v
;
});
const
bool
is_dst_valid
=
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
c4411860
...
...
@@ -92,11 +92,11 @@ struct ReferenceGemm : public device::BaseOperator
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
}
Acc
DataType
v_c
;
C
DataType
v_c
;
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_m_n_
(
m
,
n
)
=
ck
::
type_convert
<
CDataType
>
(
v_c
)
;
arg
.
c_m_n_
(
m
,
n
)
=
v_c
;
};
make_ParallelTensorFunctor
(
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
c4411860
...
...
@@ -32,6 +32,8 @@ using F32_Tuple = ck::Tuple<F32>;
using
I32_Tuple
=
ck
::
Tuple
<
I32
>
;
using
I32_F32_Tuple
=
ck
::
Tuple
<
I32
,
F32
>
;
using
F32_F32_Tuple
=
ck
::
Tuple
<
F32
,
F32
>
;
// GEMM layout
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -95,6 +97,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using
AddReluAdd
=
ck
::
tensor_operation
::
element_wise
::
AddReluAdd
;
using
FastGelu
=
ck
::
tensor_operation
::
element_wise
::
FastGelu
;
using
AddMultiply
=
ck
::
tensor_operation
::
element_wise
::
AddMultiply
;
using
MultiplyAdd
=
ck
::
tensor_operation
::
element_wise
::
MultiplyAdd
;
using
ScaleAdd
=
ck
::
tensor_operation
::
element_wise
::
ScaleAdd
;
using
Gelu
=
ck
::
tensor_operation
::
element_wise
::
Gelu
;
using
Swish
=
ck
::
tensor_operation
::
element_wise
::
Swish
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp
0 → 100644
View file @
c4411860
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Row
,
Row_Row_Tuple
,
Row
,
F16
,
F16
,
F16_F16_Tuple
,
F16
,
PassThrough
,
PassThrough
,
MultiplyAdd
>>>&
);
void
add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Row_Row_Tuple
,
Row
,
F16
,
F16
,
F16_F16_Tuple
,
F16
,
PassThrough
,
PassThrough
,
MultiplyAdd
>>>&
);
void
add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Row
,
Row_Row_Tuple
,
Row
,
F16
,
F8
,
F32_F32_Tuple
,
F16
,
PassThrough
,
PassThrough
,
MultiplyAdd
>>>&
);
void
add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Row_Row_Tuple
,
Row
,
F16
,
F8
,
F32_F32_Tuple
,
F16
,
PassThrough
,
PassThrough
,
MultiplyAdd
>>>&
);
// GEMM + Multiply + Add
template
<
typename
ALayout
,
typename
BLayout
,
typename
D0Layout
,
typename
D1Layout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
D0DataType
,
typename
D1DataType
,
typename
EDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
ck
::
Tuple
<
D0Layout
,
D1Layout
>
,
ELayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<
D0DataType
,
D1DataType
>
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyAdd
>>
{
using
DeviceOp
=
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
ck
::
Tuple
<
D0Layout
,
D1Layout
>
,
ELayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<
D0DataType
,
D1DataType
>
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyAdd
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
D0DataType
,
half_t
>
&&
is_same_v
<
D1DataType
,
half_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
D1Layout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
D1Layout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances
(
op_ptrs
);
}
}
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
is_same_v
<
D0DataType
,
float
>
&&
is_same_v
<
D1DataType
,
float
>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
D1Layout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
D1Layout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt
0 → 100644
View file @
c4411860
add_instance_library
(
device_gemm_multiply_add_instance
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment