gaoqiong / composable_kernel_ROCM · Commits

Commit 9f1b4276, authored Apr 04, 2024 by Jakub Piasecki

    resolved conflicts

Parents: 711857c4, c7010716
Changes: 198

Showing 20 changed files with 1892 additions and 670 deletions (+1892 / -670)
Changed files:

    example/64_fpAintB_gemm/CMakeLists.txt                                                            (+3   / -5)
    example/CMakeLists.txt                                                                            (+36  / -0)
    include/ck/ck.hpp                                                                                 (+8   / -6)
    include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp                   (+6   / -4)
    include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (+9   / -6)
    include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp   (+6   / -4)
    include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp                  (+150 / -69)
    include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp                         (+104 / -0)
    include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp                       (+103 / -0)
    include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                          (+239 / -9)
    include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp                 (+12  / -4)
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp                  (+22  / -17)
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp                    (+5   / -2)
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp             (+186 / -91)
    include/ck/utility/math_v2.hpp                                                                    (+555 / -1)
    library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp               (+110 / -0)
    library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp                                 (+133 / -448)
    library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp                        (+4   / -4)
    library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc                              (+167 / -0)
    library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc                            (+34  / -0)
example/64_fpAintB_gemm/CMakeLists.txt

-if(GPU_TARGETS MATCHES "gfx11")
-    add_custom_target(example_fpAintB_gemm_wmma)
-    add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
-    add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
-endif()
+add_custom_target(example_fpAintB_gemm_wmma)
+add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
+add_example_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
example/CMakeLists.txt

@@ -5,6 +5,12 @@ include_directories(BEFORE

 add_custom_target(examples)

+function(add_example_dependencies EXAMPLE_NAME FILE_NAME)
+    if(FILE_NAME)
+        add_dependencies(EXAMPLE_NAME FILE_NAME)
+    endif()
+endfunction(add_example_dependencies EXAMPLE_NAME)
+
 function(add_example_executable EXAMPLE_NAME FILE_NAME)
     message("adding example ${EXAMPLE_NAME}")
     set(result 1)

@@ -38,12 +44,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)

             endif()
         endforeach()
     endif()
     #Do not build any DL examples if DL_KERNELS not set
     foreach(source IN LISTS FILE_NAME)
         if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
             message("removing dl example ${source}")
             list(REMOVE_ITEM FILE_NAME "${source}")
         endif()
     endforeach()
+    #Do not build any XDL examples if gfx9 targets are not on the list
+    foreach(source IN LISTS FILE_NAME)
+        if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
+            message("removing xdl example ${source}")
+            list(REMOVE_ITEM FILE_NAME "${source}")
+        endif()
+    endforeach()
+    #Do not build any WMMA examples if gfx11 targets are not on the list
+    foreach(source IN LISTS FILE_NAME)
+        if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
+            message("removing wmma example ${source}")
+            list(REMOVE_ITEM FILE_NAME "${source}")
+        endif()
+    endforeach()
+    #only continue if there are some source files left on the list
     if(FILE_NAME)
         add_executable(${EXAMPLE_NAME} ${FILE_NAME})

@@ -97,12 +118,27 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)

    (the same XDL/WMMA filtering block and "#only continue if there are some source files
     left on the list" guard are added here, inside add_example_executable_no_testing,
     before its add_executable(${EXAMPLE_NAME} ${FILE_NAME}) call)
include/ck/ck.hpp

@@ -45,6 +45,10 @@

 #endif

+// define general macros for various architectures
+#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
+    defined(__gfx942__)
+#define __gfx9__
+#endif
 #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
 #define __gfx94__
 #endif

@@ -62,8 +66,7 @@

 // buffer resource
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_BUFFER_RESOURCE_3RD_DWORD -1
-#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx94__)
+#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx9__)
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
 #elif defined(__gfx103__)
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000

@@ -75,8 +78,7 @@

 #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
 #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
 #define CK_USE_AMD_V_MAC_F32
-#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \
-    defined(__gfx94__) // for GPU code
+#elif defined(__gfx906__) || defined(__gfx9__) || defined(__gfx103__) // for GPU code
 #define CK_USE_AMD_V_FMAC_F32
 #define CK_USE_AMD_V_DOT2_F32_F16
 #define CK_USE_AMD_V_DOT4_I32_I8

@@ -89,7 +91,7 @@

 // MFMA instruction
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_USE_AMD_MFMA
-#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
+#elif defined(__gfx9__) // for GPU code
 #define CK_USE_AMD_MFMA
 #endif

@@ -120,7 +122,7 @@

 // buffer atomic add: floating point
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
-#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
+#elif defined(__gfx9__) // for GPU code
 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
 #else // for GPU code
 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
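A minimal standalone sketch of the macro-consolidation pattern these hunks apply: derive one family macro from the per-chip macros once, then let every later feature switch test only the family macro. The __MY_GFX* and MY_USE_MFMA names below are invented for illustration; only __gfx9__, __gfx94__ and the CK_* switches above come from the actual header.

#include <cstdio>

// A per-chip macro would normally come from the compiler (e.g. -D__MY_GFX908__).
#define __MY_GFX908__ 1

// Family macro derived once from the per-chip macros, mirroring how ck.hpp now derives __gfx9__.
#if defined(__MY_GFX908__) || defined(__MY_GFX90A__) || defined(__MY_GFX940__)
#define __MY_GFX9__
#endif

// Later feature switches test only the family macro instead of repeating the chip list.
#if defined(__MY_GFX9__)
#define MY_USE_MFMA 1
#else
#define MY_USE_MFMA 0
#endif

int main() { std::printf("MFMA path enabled: %d\n", MY_USE_MFMA); }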
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -40,7 +40,8 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());

  * \tparam AElementwiseOperation A elementwise operation.
  * \tparam BElementwiseOperation B elementwise operation.
  * \tparam CDEElementwiseOperation CDE elementwise operation.
- * \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
+ * \tparam AComputeType Compute data type for A tensor (default: ADataType, first if tuple passed).
+ * \tparam BComputeType Compute data type for B tensor (default: AComputeType).
  */
 template <index_t NDimSpatial,
           typename ALayout,

@@ -54,12 +55,13 @@ template <index_t NDimSpatial,

           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation,
-          typename ComputeType =
-              decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
-                                      Number<0>,
-                                      ADataType>())> // ComputeType is InputType by default (first
-                                                     // in tuple for MultiAB), unpack if tuple was
-                                                     // passed
+          typename AComputeType =
+              decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
+                                      Number<0>,
+                                      ADataType>()), // AComputeType is InputType by default (first
+                                                     // in tuple for MultiAB), unpack if tuple was
+                                                     // passed
+          typename BComputeType = AComputeType>
 struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
 {
     static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -254,13 +254,14 @@ template <index_t NDimSpatial,

           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
-          typename ComputeDataType =
+          typename AComputeDataType =
               decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                                       Number<0>,
                                       ADataType>()), // ComputeType is InputType by default (first
                                                      // in tuple for MultiAB), unpack if tuple was
                                                      // passed
+          typename BComputeDataType = AComputeDataType,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
                                              ALayout,

@@ -274,7 +275,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle

                                              AElementwiseOperation,
                                              BElementwiseOperation,
                                              CDEElementwiseOperation,
-                                             ComputeDataType>
+                                             AComputeDataType,
+                                             BComputeDataType>
 {
     using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;

@@ -386,7 +388,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle

     using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;

 #define GridwiseGemmTemplateParameters                                                          \
-    GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType,   \
+    GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType,  \
         EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation,       \
         InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
         KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave,                        \

@@ -399,7 +401,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle

         BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN,                           \
         CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,                           \
         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,                       \
-        CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
+        CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1,              \
+        BComputeDataType

     // Use appropriate gridwise gemm
     using GridwiseGemm = std::conditional_t<isMultiA || isMultiB,
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -75,13 +75,14 @@ template <index_t NDimSpatial,

           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
-          typename ComputeDataType =
+          typename AComputeDataType =
               decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                                       Number<0>,
                                       ADataType>()), // ComputeType is InputType by default (first
                                                      // in tuple for MultiAB), unpack if tuple was
                                                      // passed
+          typename BComputeDataType = AComputeDataType,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
 using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
     NDimSpatial,
     ALayout,

@@ -128,7 +129,8 @@ using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipl

     CShuffleNXdlPerWavePerShuffle,
     CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
     CDEBlockTransferScalarPerVector_NPerBlock,
-    ComputeDataType,
+    AComputeDataType,
+    BComputeDataType,
     LoopSched>;

 } // namespace device
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp

@@ -23,6 +23,7 @@ namespace device {

 template <typename GridwiseGemm,
           typename GemmDesc,
           GemmSpecialization GemmSpec,
+          bool Zeroing,
           typename ALayout,
           typename BLayout,
           typename DsLayout,

@@ -106,33 +107,63 @@ __global__ void

         const auto block_2_etile_map =
             GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);

-        auto barrier_count_finished =
-            barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
-
-        GridwiseGemm::template Run<HasMainKBlockLoop,
-                                   EGlobalMemoryDataOperation,
-                                   GemmSpec,
-                                   ALayout, BLayout, DsLayout, ELayout>(
-            gemm_desc_ptr[group_id].p_a_grid,
-            gemm_desc_ptr[group_id].p_b_grid,
-            p_ds_grid_,
-            gemm_desc_ptr[group_id].p_e_grid,
-            p_shared,
-            barrier_count_finished,
-            a_element_op, b_element_op, c_element_op,
-            M, N, K, StrideA, StrideB, StrideDs, StrideE,
-            KBatch,
-            block_2_etile_map);
+        if constexpr(Zeroing)
+        {
+            auto barrier_count_finished =
+                barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
+
+            GridwiseGemm::template RunWithZeroing<HasMainKBlockLoop,
+                                                  EGlobalMemoryDataOperation,
+                                                  GemmSpec,
+                                                  ALayout, BLayout, DsLayout, ELayout>(
+                gemm_desc_ptr[group_id].p_a_grid,
+                gemm_desc_ptr[group_id].p_b_grid,
+                p_ds_grid_,
+                gemm_desc_ptr[group_id].p_e_grid,
+                p_shared,
+                barrier_count_finished,
+                a_element_op, b_element_op, c_element_op,
+                M, N, K, StrideA, StrideB, StrideDs, StrideE,
+                KBatch,
+                block_2_etile_map);
+        }
+        else
+        {
+            GridwiseGemm::template Run<HasMainKBlockLoop,
+                                       EGlobalMemoryDataOperation,
+                                       GemmSpec,
+                                       ALayout, BLayout, DsLayout, ELayout>(
+                gemm_desc_ptr[group_id].p_a_grid,
+                gemm_desc_ptr[group_id].p_b_grid,
+                p_ds_grid_,
+                gemm_desc_ptr[group_id].p_e_grid,
+                p_shared,
+                nullptr,
+                a_element_op, b_element_op, c_element_op,
+                M, N, K, StrideA, StrideB, StrideDs, StrideE,
+                KBatch,
+                block_2_etile_map);
+        }

         id_off += grid_size_grp;
         id_local += grid_size_grp;

@@ -193,8 +224,11 @@ template <typename ALayout,

           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
-          typename ComputeType = ADataType,
-          LoopScheduler LoopSched = make_default_loop_scheduler()>
+          PipelineVersion PipelineVer = PipelineVersion::v1,
+          LoopScheduler LoopSched = make_default_loop_scheduler(),
+          typename ComputeType = ADataType,
+          typename ALDSType = ComputeType,
+          typename BLDSType = ComputeType>
 struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                                                                         BLayout,
                                                                         DsLayout,

@@ -215,11 +249,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,

     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};

+    using AComputeType = ComputeType;
+    using BComputeType = ComputeType;
+
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
         BDataType,
-        ComputeType,
+        AComputeType,
+        BComputeType,
         AccDataType,
         CShuffleDataType,
         DsDataType,

@@ -258,7 +296,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,

         CShuffleNXdlPerWavePerShuffle,
         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CDEBlockTransferScalarPerVector_NPerBlock,
-        LoopSched>;
+        LoopSched,
+        PipelineVer,
+        ALDSType,
+        BLDSType>;

     template <typename UnderlyingBlockToCTileMap>
     struct OffsettedBlockToCTileMapMLoops

@@ -613,45 +654,85 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,

         float ave_time = 0;

         auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
-            const auto kernel =
-                kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
-                                                 GroupedGemmKernelArgument<NumDTensor>,
-                                                 GemmSpec,
-                                                 ALayout, BLayout, DsLayout, ELayout,
-                                                 DsDataType,
-                                                 Block2ETileMap,
-                                                 GroupedGemmBlock2ETileMap,
-                                                 AElementwiseOperation,
-                                                 BElementwiseOperation,
-                                                 CDEElementwiseOperation,
-                                                 e_global_memory_operation_,
-                                                 has_main_k_block_loop_>;
-
-            return launch_and_time_kernel(
-                stream_config, kernel, dim3(arg.grid_size_), dim3(BlockSize), 0,
-                cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
-                reinterpret_cast<uint32_t*>(arg.p_workspace_),
-                arg.barrier_size_grp_,
-                arg.gemm_desc_kernel_arg_.size(),
-                arg.grid_size_grp_,
-                arg.k_batch_,
-                arg.a_element_op_, arg.b_element_op_, arg.c_element_op_);
+            if(arg.k_batch_ == 1)
+            {
+                const auto kernel =
+                    kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
+                                                     GroupedGemmKernelArgument<NumDTensor>,
+                                                     GemmSpec,
+                                                     false,
+                                                     ALayout, BLayout, DsLayout, ELayout,
+                                                     DsDataType,
+                                                     Block2ETileMap,
+                                                     GroupedGemmBlock2ETileMap,
+                                                     AElementwiseOperation,
+                                                     BElementwiseOperation,
+                                                     CDEElementwiseOperation,
+                                                     e_global_memory_operation_,
+                                                     has_main_k_block_loop_>;
+
+                return launch_and_time_kernel(
+                    stream_config, kernel, dim3(arg.grid_size_), dim3(BlockSize), 0,
+                    cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
+                    nullptr,
+                    arg.barrier_size_grp_,
+                    arg.gemm_desc_kernel_arg_.size(),
+                    arg.grid_size_grp_,
+                    arg.k_batch_,
+                    arg.a_element_op_, arg.b_element_op_, arg.c_element_op_);
+            }
+            else
+            {
+                const auto kernel =
+                    kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
+                                                     GroupedGemmKernelArgument<NumDTensor>,
+                                                     GemmSpec,
+                                                     true,
+                                                     ALayout, BLayout, DsLayout, ELayout,
+                                                     DsDataType,
+                                                     Block2ETileMap,
+                                                     GroupedGemmBlock2ETileMap,
+                                                     AElementwiseOperation,
+                                                     BElementwiseOperation,
+                                                     CDEElementwiseOperation,
+                                                     e_global_memory_operation_,
+                                                     has_main_k_block_loop_>;
+
+                return launch_and_time_kernel(
+                    stream_config, kernel, dim3(arg.grid_size_), dim3(BlockSize), 0,
+                    cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
+                    reinterpret_cast<uint32_t*>(arg.p_workspace_),
+                    arg.barrier_size_grp_,
+                    arg.gemm_desc_kernel_arg_.size(),
+                    arg.grid_size_grp_,
+                    arg.k_batch_,
+                    arg.a_element_op_, arg.b_element_op_, arg.c_element_op_);
+            }
         };

         constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
         constexpr auto Set       = InMemoryDataOperationEnum::Set;

         // For bf16 datatype only kbatch = 1 scenario is supported. This condition is
         // enforced in IsSupportedArgument function
         if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
         {
             if(has_main_k_block_loop)

@@ -719,12 +800,12 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,

         bool supported = true;

         // If we use padding we do not support vector loads for dimensions not divisible by
         // vector load size.
         if constexpr(GemmSpec != GemmSpecialization::Default)
         {
             // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1}
             // layout, thus we have to adapt it to the {M,K} or {N,K} layout.
             const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
             const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
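The pattern this file uses is a bool non-type template parameter (Zeroing) on the kernel, with the host launcher choosing the false instantiation when k_batch_ == 1 so the workspace/barrier logic is compiled out entirely. Below is a minimal standalone sketch of that dispatch; run_tile and its arguments are invented for illustration and are not CK APIs.

#include <cstdint>
#include <cstdio>

// Stand-in for the kernel body: with if constexpr the Zeroing=false instantiation
// never touches the workspace pointer, so the launcher may safely pass nullptr.
template <bool Zeroing>
void run_tile(std::uint32_t* barrier_counter)
{
    if constexpr(Zeroing)
    {
        ++(*barrier_counter); // only compiled into the Zeroing=true instantiation
        std::printf("zeroing path, counter=%u\n", static_cast<unsigned>(*barrier_counter));
    }
    else
    {
        std::printf("plain path, no workspace needed\n");
    }
}

int main()
{
    std::uint32_t counter = 0;
    const int k_batch     = 1;

    // Host-side selection mirroring the launcher: k_batch == 1 picks Zeroing=false
    // plus a null workspace, anything else picks Zeroing=true plus a real counter.
    if(k_batch == 1) { run_tile<false>(nullptr); }
    else             { run_tile<true>(&counter); }
}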
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp

@@ -92,6 +92,110 @@ struct Add

     };
 };

+struct Max
+{
+    template <typename Y, typename X0, typename X1>
+    __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
+    {
+        const Y x0_converted = type_convert<Y>(x0);
+        const Y x1_converted = type_convert<Y>(x1);
+        y = ck::math::max(x0_converted, x1_converted);
+    }
+};
+
+struct Min
+{
+    template <typename Y, typename X0, typename X1>
+    __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
+    {
+        const Y x0_converted = type_convert<Y>(x0);
+        const Y x1_converted = type_convert<Y>(x1);
+        y = ck::math::min(x0_converted, x1_converted);
+    }
+};
+
+struct Multiply
+{
+    template <typename Y, typename X0, typename X1>
+    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<float>(float& y, const float& x0, const float& x1) const
+    {
+        y = x0 * x1;
+    };
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<double>(double& y, const double& x0, const double& x1) const
+    {
+        y = x0 * x1;
+    };
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<float>(float& y, const float& x0, const half_t& x1) const
+    {
+        y = x0 * type_convert<half_t>(x1);
+    };
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t>(half_t& y, const float& x0, const float& x1) const
+    {
+        y = type_convert<half_t>(x0 * x1);
+    };
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
+    {
+        y = type_convert<half_t>(x0) * x1;
+    };
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
+    {
+        y = x0 * x1;
+    };
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
+    {
+        const float x1_tmp = ck::type_convert<float>(x1);
+        y                  = x0 * x1_tmp;
+    }
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
+    {
+        const float x1_tmp = ck::type_convert<float>(x0);
+        const float x2_tmp = ck::type_convert<float>(x1);
+        const float y_tmp  = x1_tmp * x2_tmp;
+        y                  = ck::type_convert<bhalf_t>(y_tmp);
+    }
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
+    {
+        const float x2_tmp = ck::type_convert<float>(x1);
+        const float y_tmp  = x0 * x2_tmp;
+        y                  = ck::type_convert<bhalf_t>(y_tmp);
+    }
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
+    {
+        y = x0 * x1;
+    };
+};
+
 struct ScaleAdd
 {
     __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}
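For orientation, here is a minimal standalone sketch of how binary functors of this shape are applied as an elementwise epilogue (e[i] = op(c[i], d[i])). MaxOp, MultiplyOp and apply below are simplified local stand-ins, not the CK types; the real functors live in ck::tensor_operation::element_wise and add type_convert / half_t / bhalf_t handling as shown above.

#include <algorithm>
#include <cstdio>

struct MaxOp
{
    template <typename Y, typename X0, typename X1>
    void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        y = std::max<Y>(static_cast<Y>(x0), static_cast<Y>(x1));
    }
};

struct MultiplyOp
{
    template <typename Y, typename X0, typename X1>
    void operator()(Y& y, const X0& x0, const X1& x1) const { y = x0 * x1; }
};

// CDE-style elementwise epilogue: combine the GEMM result c with an extra tensor d into e.
template <typename Op>
void apply(const float* c, const float* d, float* e, int n, Op op)
{
    for(int i = 0; i < n; ++i) { op(e[i], c[i], d[i]); }
}

int main()
{
    const float c[3] = {1.f, -2.f, 3.f};
    const float d[3] = {0.5f, 4.f, -1.f};
    float e[3];

    apply(c, d, e, 3, MultiplyOp{});
    std::printf("multiply: %g %g %g\n", e[0], e[1], e[2]);

    apply(c, d, e, 3, MaxOp{});
    std::printf("max:      %g %g %g\n", e[0], e[1], e[2]);
}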
include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {
namespace tensor_operation {
namespace element_wise {

// y = UnaryOp0(UnaryOp1(...(x)))
template <typename... UnaryOpsSet>
struct UnaryCombinedOp
{
    __host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {}

    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const
    {
        // Execute first unary op to copy data to y
        unary_ops_.At(Number<0>{})(y, x);
        static_for<1, Tuple<UnaryOpsSet...>::Size(), 1>{}([&](auto i) { unary_ops_.At(i)(y, y); });
    };

    Tuple<UnaryOpsSet...> unary_ops_;
};

// y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1))
template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
struct BinaryWithUnaryCombinedOp
{
    __host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op,
                                                  UnaryOp0 unary_op0,
                                                  UnaryOp1 unary_op1)
        : binary_op_(binary_op), unary_op0_(unary_op0), unary_op1_(unary_op1)
    {
    }

    template <typename Y, typename X0, typename X1>
    __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        Y unary_x0_tmp_result;
        Y unary_x1_tmp_result;

        unary_op0_(unary_x0_tmp_result, x0);
        unary_op1_(unary_x1_tmp_result, x1);
        binary_op_(y, unary_x0_tmp_result, unary_x1_tmp_result);
    };

    private:
    BinaryOp binary_op_;
    UnaryOp0 unary_op0_;
    UnaryOp1 unary_op1_;
};

// y = BinaryOp0(BinaryOp1(UnaryOp0(x0), UnaryOp1(x1)), UnaryOp2(x2))
template <typename BinaryOp0,
          typename BinaryOp1,
          typename UnaryOp0,
          typename UnaryOp1,
          typename UnaryOp2>
struct TrinaryWithUnaryCombinedOp
{
    __host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0,
                                                   BinaryOp0 binary_op1,
                                                   UnaryOp0 unary_op0,
                                                   UnaryOp1 unary_op1,
                                                   UnaryOp2 unary_op2)
        : binary_op0_(binary_op0),
          binary_op1_(binary_op1),
          unary_op0_(unary_op0),
          unary_op1_(unary_op1),
          unary_op2_(unary_op2)
    {
    }

    template <typename Y, typename X0, typename X1, typename X2>
    __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1, const X2& x2) const
    {
        Y unary_x0_tmp_result;
        Y unary_x1_tmp_result;
        Y unary_x2_tmp_result;

        unary_op0_(unary_x0_tmp_result, x0);
        unary_op1_(unary_x1_tmp_result, x1);
        unary_op2_(unary_x2_tmp_result, x2);
        binary_op0_(unary_x0_tmp_result, unary_x0_tmp_result, unary_x1_tmp_result);
        binary_op1_(y, unary_x0_tmp_result, unary_x2_tmp_result);
    };

    private:
    BinaryOp0 binary_op0_{};
    BinaryOp1 binary_op1_{};
    UnaryOp0 unary_op0_{};
    UnaryOp1 unary_op1_{};
    UnaryOp2 unary_op2_{};
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
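The composition semantics of the new header are easiest to see host-side: y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1)). The sketch below is a simplified standalone mirror of BinaryWithUnaryCombinedOp with invented functors (Abs, Square, AddOp); it is not the CK implementation, which uses __host__ __device__ functors and ck::Tuple as shown above.

#include <cmath>
#include <cstdio>

struct Abs    { void operator()(float& y, const float& x) const { y = std::fabs(x); } };
struct Square { void operator()(float& y, const float& x) const { y = x * x; } };
struct AddOp  { void operator()(float& y, const float& a, const float& b) const { y = a + b; } };

// Mirrors BinaryWithUnaryCombinedOp: y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1)).
template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
struct BinaryWithUnary
{
    BinaryOp binary_op_;
    UnaryOp0 unary_op0_;
    UnaryOp1 unary_op1_;

    void operator()(float& y, const float& x0, const float& x1) const
    {
        float t0, t1;
        unary_op0_(t0, x0);    // t0 = UnaryOp0(x0)
        unary_op1_(t1, x1);    // t1 = UnaryOp1(x1)
        binary_op_(y, t0, t1); // y  = BinaryOp(t0, t1)
    }
};

int main()
{
    BinaryWithUnary<AddOp, Abs, Square> op{};
    float y = 0.f;
    op(y, -3.f, 2.f); // |-3| + 2*2 = 7
    std::printf("y = %g\n", y);
}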
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -12,10 +12,6 @@ namespace ck {

 namespace tensor_operation {
 namespace element_wise {

-#if CK_WORKAROUND_SWDEV_383542
-extern "C" __device__ float __ocml_native_recip_f32(float);
-#endif
-
 struct PassThroughPack2
 {
     template <typename Y, typename X>

@@ -449,11 +445,7 @@ struct FastGelu

         const float u   = x * (c1 * x * x + c2);
         const float emu = __expf(u);
-#if !CK_WORKAROUND_SWDEV_383542
-        y = x * __frcp_rn(1.f + emu);
-#else
-        y = x * __ocml_native_recip_f32(1.f + emu);
-#endif
+        y = x * ck::math::rcp(1.f + emu);
     }

     template <>

@@ -559,6 +551,244 @@ struct TanH

     };
 };

+struct ACos
+{
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
+                      "Data type is not supported by this operation!");
+
+        y = ck::math::acos(x);
+    };
+};
+
+// The remaining operators added in this hunk -- Neg, ATan, Sin, ASinH, Cos, ACosH, Tan,
+// ATanH, SinH, Ceil, Exp, CosH, Floor, Log, ASin and Rcp -- have exactly the same shape:
+// the same static_assert over {float, double, ck::half_t, int8_t, int32_t}, followed by
+// y = ck::math::neg/atan/sin/asinh/cos/acosh/tan/atanh/sinh/ceil/exp/cosh/floor/log/asin/rcp(x);
+
 struct Swish
 {
     Swish(float beta = 1.0f) : beta_(beta) {}
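The shared pattern of these operators is a fully templated call operator gated by a compile-time type whitelist, so an unsupported element type fails at compile time instead of silently converting. A minimal standalone sketch of that pattern (SinOp is a local stand-in restricted to float/double, not the CK functor, which also admits half_t, int8_t and int32_t):

#include <cmath>
#include <cstdio>
#include <type_traits>

struct SinOp
{
    template <typename T>
    void operator()(T& y, const T& x) const
    {
        // Whitelist the element types up front, mirroring the static_assert above.
        static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
                      "Data type is not supported by this operation!");
        y = std::sin(x);
    }
};

int main()
{
    float  yf; SinOp{}(yf, 1.0f);
    double yd; SinOp{}(yd, 1.0);
    std::printf("%f %f\n", yf, yd);
    // Instantiating SinOp for, say, int would fail to compile, by design.
}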
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp

@@ -118,8 +118,16 @@ struct GridwiseElementwise

             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
         const index_t m1_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock);

-        const auto thread_grid_offset =
-            make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
+        const auto input_thread_grid_offset = generate_tuple(
+            [&](auto) {
+                return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
+            },
+            Number<NumInput>{});
+        const auto output_thread_grid_offset = generate_tuple(
+            [&](auto) {
+                return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
+            },
+            Number<NumOutput>{});

         using ThisThreadBlock = ThisThreadBlock<BlockSize>;

         // If src and dst have same vector dim, then:

@@ -157,9 +165,9 @@ struct GridwiseElementwise

             uniform_sequence_gen_t<NumOutput, 1>,
             uniform_sequence_gen_t<NumInput, false>,
             uniform_sequence_gen_t<NumOutput, false>>{in_grid_desc_tuple,
-                                                      thread_grid_offset,
+                                                      input_thread_grid_offset,
                                                       out_grid_desc_tuple,
-                                                      thread_grid_offset,
+                                                      output_thread_grid_offset,
                                                       elementwise_op};

         global_to_global_transfer.Run(
             in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0);
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -30,7 +30,7 @@ namespace ck {

 // D0, D1, ... and E have the same layout
 template <typename AsDataType,
           typename BsDataType,
-          typename ComputeDataType_,
+          typename AComputeDataType_,
           typename AccDataType,
           typename CShuffleDataType,
           typename DsDataType,

@@ -71,7 +71,8 @@ template <typename AsDataType,

           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched,
-          PipelineVersion PipelineVer = PipelineVersion::v1>
+          PipelineVersion PipelineVer = PipelineVersion::v1,
+          typename BComputeDataType_ = AComputeDataType_>
 struct GridwiseGemmMultipleABD_xdl_cshuffle
 {
     static constexpr index_t NumATensor = AsDataType::Size();

@@ -101,10 +102,13 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle

         decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

 #if CK_WORKAROUND_DENORM_FIX
-    using ComputeDataType =
-        conditional_t<is_same_v<ComputeDataType_, ck::half_t>, ck::bhalf_t, ComputeDataType_>;
+    using AComputeDataType =
+        conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
+    using BComputeDataType =
+        conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
 #else
-    using ComputeDataType = ComputeDataType_;
+    using AComputeDataType = AComputeDataType_;
+    using BComputeDataType = BComputeDataType_;
 #endif

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()

@@ -195,8 +199,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle

         constexpr auto c_block_size =
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

-        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
-                             sizeof(ComputeDataType),
+        return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) +
+                             b_block_space_size_aligned * sizeof(BComputeDataType),
                          c_block_size * sizeof(CShuffleDataType));
     }

@@ -597,7 +601,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle

         auto a_blockwise_copy =
             ThreadGroupTensorSliceTransfer_v7r2<ThisThreadBlock,
                                                 AsDataType,
-                                                Tuple<ComputeDataType>,
+                                                Tuple<AComputeDataType>,
                                                 decltype(as_grid_desc_ak0_m_ak1),
                                                 decltype(tie(a_block_desc_ak0_m_ak1)),
                                                 AElementwiseOperation,

@@ -628,7 +632,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle

         auto b_blockwise_copy =
             ThreadGroupTensorSliceTransfer_v7r2<ThisThreadBlock,
                                                 BsDataType,
-                                                Tuple<ComputeDataType>,
+                                                Tuple<BComputeDataType>,
                                                 decltype(bs_grid_desc_bk0_n_bk1),
                                                 decltype(tie(b_block_desc_bk0_n_bk1)),
                                                 BElementwiseOperation,

@@ -656,14 +660,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle

         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         // register
         // sanity check
-        constexpr index_t KPack =
-            math::max(math::lcm(AK1, BK1),
-                      MfmaSelector<ComputeDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+        constexpr index_t KPack =
+            math::max(math::lcm(AK1, BK1),
+                      MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::
+                          selected_mfma.k_per_blk);

         auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
             BlockSize,
-            ComputeDataType, // ComputeDataType for A
-            ComputeDataType, // ComputeDataType for B
+            AComputeDataType,
+            BComputeDataType,
             AccDataType,
             decltype(a_block_desc_ak0_m_ak1),
             decltype(b_block_desc_bk0_n_bk1),

@@ -681,10 +686,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle

             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<ComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
+            static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<ComputeDataType*>(p_shared) + a_block_space_size_aligned,
+            static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
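The LDS sizing hunk above reflects that A and B may now stage in different compute types, so the byte budget has to be summed per operand rather than multiplying a combined element count by one common sizeof. A small standalone illustration of that arithmetic; the element counts and the AComputeT/BComputeT choices below are made-up placeholders, not values taken from the kernel.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main()
{
    const std::size_t a_elems = 4096;  // aligned A tile element count (placeholder)
    const std::size_t b_elems = 2048;  // aligned B tile element count (placeholder)
    const std::size_t c_bytes = 16384; // CShuffle staging bytes (placeholder)

    using AComputeT = std::uint16_t;   // e.g. an fp16-sized tile for A
    using BComputeT = std::uint8_t;    // e.g. an int8-sized tile for B

    // Per-operand byte count, as in the updated GetSharedMemoryNumberOfByte-style logic.
    const std::size_t ab_bytes  = a_elems * sizeof(AComputeT) + b_elems * sizeof(BComputeT);
    const std::size_t lds_bytes = std::max(ab_bytes, c_bytes);

    std::printf("A+B staging: %zu bytes, LDS required: %zu bytes\n", ab_bytes, lds_bytes);
}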
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -73,7 +73,7 @@ template <typename ADataType,

           index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched,
           PipelineVersion PipelineVer = PipelineVersion::v1,
-          typename BComputeDataType = AComputeDataType_>
+          typename BComputeDataType_ = AComputeDataType_>
 struct GridwiseGemmMultipleD_xdl_cshuffle
 {
     static constexpr index_t NumDTensor = DsDataType::Size();

@@ -103,8 +103,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle

 #if CK_WORKAROUND_DENORM_FIX
     using AComputeDataType =
         conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
+    using BComputeDataType =
+        conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
 #else
     using AComputeDataType = AComputeDataType_;
+    using BComputeDataType = BComputeDataType_;
 #endif

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp

@@ -31,7 +31,8 @@ namespace ck {

 // D0, D1, ... and E have the same layout
 template <typename ADataType,
           typename BDataType,
-          typename ComputeType,
+          typename AComputeType,
+          typename BComputeType,
           typename AccDataType,
           typename CShuffleDataType,
           typename DsDataType,

@@ -71,7 +72,9 @@ template <typename ADataType,

           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched,
-          PipelineVersion PipelineVer = PipelineVersion::v1>
+          PipelineVersion PipelineVer,
+          typename ALDSType,
+          typename BLDSType>
 struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
 {
     static constexpr index_t NumDTensor = DsDataType::Size();

@@ -186,8 +189,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle

         constexpr auto c_block_size =
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

-        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
-                             sizeof(ComputeType),
+        return math::max(a_block_space_size_aligned * sizeof(ALDSType) +
+                             b_block_space_size_aligned * sizeof(BLDSType),
                          c_block_size * sizeof(CShuffleDataType));
     }

@@ -455,6 +458,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle

               InMemoryDataOperationEnum EGlobalMemoryDataOperation,
               index_t NumDTensor_,
               typename DsDataType_,
+              bool Zeroing,
               typename AGridDesc_KBatch_AK0_M_AK1,
               typename BGridDesc_KBatch_BK0_N_BK1,
               typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,

@@ -530,7 +534,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle

             ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             ADataType,
-            ComputeType,
+            ALDSType,
             decltype(a_grid_desc_kbatch_ak0_m_ak1),
             decltype(a_block_desc_kbatch_ak0_m_ak1),
             ABlockTransferSrcAccessOrder,

@@ -561,7 +565,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle

             BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             BDataType,
-            ComputeType,
+            BLDSType,
             decltype(b_grid_desc_kbatch_bk0_n_bk1),
             decltype(b_block_desc_kbatch_bk0_n_bk1),
             BBlockTransferSrcAccessOrder,

@@ -597,12 +601,12 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle

         // sanity check
         constexpr index_t KPack =
             math::max(math::lcm(AK1, BK1),
-                      MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+                      MfmaSelector<AComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
             BlockSize,
-            ComputeType,
-            ComputeType,
+            ALDSType,
+            BLDSType,
             AccDataType,
             decltype(a_block_desc_ak0_m_ak1),
             decltype(b_block_desc_bk0_n_bk1),

@@ -611,62 +615,65 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle

             MXdlPerWave,
             NXdlPerWave,
             KPack,
-            LoopSched>();
+            LoopSched,
+            AComputeType,
+            BComputeType>();

 #if 1
-        if(block_work_idx[I0] == 0)
+        if constexpr(Zeroing)
         {
+            if(block_work_idx[I0] == 0)
+            {
                 const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
                 const index_t numNThreads = NPerBlock / nThreadSize;
                 const index_t numMThreads = BlockSize / numNThreads;
                 const index_t mThreadSize = MPerBlock / numMThreads;

                 const index_t m_tid = get_thread_local_1d_id() / numNThreads;
                 const index_t n_tid = get_thread_local_1d_id() % numNThreads;

                 auto c_thread_desc_mblock_mperblock_nblock_nperblock =
                     make_naive_tensor_descriptor_packed(
                         make_tuple(I1, Number<mThreadSize>{}, I1, Number<nThreadSize>{}));

                 StaticBuffer<AddressSpaceEnum::Vgpr,
                              EDataType,
                              c_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(),
                              true>
                     e_thread_zero_buf;

                 auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
                     EDataType,
                     EDataType,
                     decltype(c_thread_desc_mblock_mperblock_nblock_nperblock),
                     decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
                     ck::tensor_operation::element_wise::PassThrough,
                     Sequence<1, mThreadSize, 1, nThreadSize>,
                     Sequence<0, 1, 2, 3>,
                     3,
                     CDEShuffleBlockTransferScalarPerVector_NPerBlock,
                     InMemoryDataOperationEnum::Set,
                     1,
                     true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
                           make_multi_index(block_work_idx[I1],
                                            m_tid * mThreadSize,
                                            block_work_idx[I2],
                                            n_tid * nThreadSize),
                           ck::tensor_operation::element_wise::PassThrough{}};

                 c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock,
                                   make_tuple(I0, I0, I0, I0),
                                   e_thread_zero_buf,
                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
                                   e_grid_buf);

-                __syncthreads();
+                __builtin_amdgcn_s_barrier();
                 if(threadIdx.x == 0)
                 {
                     atomicAdd(barrier_count_finished, 1);
                 }
+            }
         }
 #endif

         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

@@ -675,10 +682,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle

             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
+            static_cast<ALDSType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
+            static_cast<BLDSType*>(p_shared) + a_block_space_size_aligned,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

         constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);

@@ -711,13 +718,15 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle

         // shuffle C and write out
         {
-            if(threadIdx.x == 0)
+            if constexpr(Zeroing)
             {
-                while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
+                if(threadIdx.x == 0)
+                {
+                    while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
+                }
+                __builtin_amdgcn_s_barrier();
             }

             __syncthreads();

             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                           "wrong!");

@@ -951,13 +960,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle

             }
         });

-        if(threadIdx.x == 0)
+        if constexpr(Zeroing)
         {
-            index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);
-
-            if(k_id_finished_t == KBatch)
+            if(threadIdx.x == 0)
             {
-                *barrier_count_finished = 0;
+                index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);
+
+                if(k_id_finished_t == KBatch)
+                {
+                    *barrier_count_finished = 0;
+                }
             }
         }
     }

@@ -971,24 +983,24 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle

               typename DsLayout,
               typename ELayout,
               typename Block2ETileMap>
-    __device__ static void Run(const void* __restrict__ p_a_grid_,
+    __device__ static void RunWithZeroing(const void* __restrict__ p_a_grid_,
                                const void* __restrict__ p_b_grid_,
                                DsGridPointer p_ds_grid,
                                void* __restrict__ p_e_grid_,
                                void* __restrict__ p_shared,
                                uint32_t* barrier_count_finished,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
                                const CDEElementwiseOperation& cde_element_op,
                                const index_t M,
                                const index_t N,
                                const index_t K,
                                const index_t StrideA,
                                const index_t StrideB,
                                const std::array<index_t, NumDTensor> StrideDs,
                                const index_t StrideE,
                                const index_t KBatch,
                                const Block2ETileMap& block_2_etile_map)
     {
         const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
         const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);

@@ -1035,7 +1047,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle

         if(kbatch_id == KBatch - 1)
         {
-            Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType>(
+            Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType, true>(
                 p_a_grid,
                 p_b_grid,
                 p_ds_grid,

@@ -1054,7 +1066,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle

         }
         else
         {
-            Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>>(
+            Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>, true>(
                 p_a_grid,
                 p_b_grid,
                 p_ds_grid,

@@ -1072,6 +1084,89 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle

                 block_2_etile_map);
         }
     }

+    template <bool HasMainKBlockLoop,
+              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
+              GemmSpecialization GemmSpec,
+              typename ALayout,
+              typename BLayout,
+              typename DsLayout,
+              typename ELayout,
+              typename Block2ETileMap>
+    __device__ static void Run(const void* __restrict__ p_a_grid_,
+                               const void* __restrict__ p_b_grid_,
+                               DsGridPointer p_ds_grid,
+                               void* __restrict__ p_e_grid_,
+                               void* __restrict__ p_shared,
+                               uint32_t*,
+                               const AElementwiseOperation& a_element_op,
+                               const BElementwiseOperation& b_element_op,
+                               const CDEElementwiseOperation& cde_element_op,
+                               const index_t M,
+                               const index_t N,
+                               const index_t K,
+                               const index_t StrideA,
+                               const index_t StrideB,
+                               const std::array<index_t, NumDTensor> StrideDs,
+                               const index_t StrideE,
+                               const index_t KBatch,
+                               const Block2ETileMap& block_2_etile_map)
+    {
+        const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
+        const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
+        const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
+
+        using DsGridDesc_M_N =
+            remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
+
+        DsGridDesc_M_N ds_grid_desc_m_n;
+
+        static_for<0, NumDTensor, 1>{}([&](auto j) {
+            using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
+
+            ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
+        });
+
+        const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
+
+        // tensor descriptors for block/thread-wise copy
+        const auto a_grid_desc_kbatch_ak0_m_ak1 =
+            MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, KBatch);
+
+        const auto b_grid_desc_kbatch_bk0_n_bk1 =
+            MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, KBatch);
+
+        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
+            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
+
+        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock;
+
+        static_for<0, NumDTensor, 1>{}([&](auto j) {
+            ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
+                MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
+        });
+
+        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
+            MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
+
+        Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType, false>(
+            p_a_grid,
+            p_b_grid,
+            p_ds_grid,
+            p_e_grid,
+            p_shared,
+            nullptr,
+            KBatch,
+            a_element_op,
+            b_element_op,
+            cde_element_op,
+            a_grid_desc_kbatch_ak0_m_ak1,
+            b_grid_desc_kbatch_bk0_n_bk1,
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+            e_grid_desc_mblock_mperblock_nblock_nperblock,
+            block_2_etile_map);
+    }
 };

 } // namespace ck
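The zeroing handshake in this file has one workgroup clear the output tile, publish completion through an atomic counter, and lets the other split-K contributors spin on that counter before accumulating. Below is a host-thread analogue of that counter handshake using std::atomic; it is only an illustration of the synchronization idea, not the GPU code, and the names and values are invented for the sketch.

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main()
{
    const int k_batch = 4;
    std::atomic<unsigned> barrier_count_finished{0};
    float e = -123.f; // stale output value that must be cleared before accumulation

    std::vector<std::thread> workers;
    for(int k = 0; k < k_batch; ++k)
    {
        workers.emplace_back([&, k] {
            if(k == 0)
            {
                e = 0.f;                                                   // zero the output tile
                barrier_count_finished.fetch_add(1, std::memory_order_release); // announce it
            }
            else
            {
                // other split-K slices spin until the zeroing worker has published
                while(barrier_count_finished.load(std::memory_order_acquire) == 0) {}
            }
            // each slice would then atomically add its partial result into the output
        });
    }
    for(auto& t : workers) { t.join(); }

    std::printf("output after zeroing handshake: %g\n", e);
}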
include/ck/utility/math_v2.hpp
View file @
9f1b4276
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -14,6 +14,10 @@
namespace
ck
{
namespace
math
{
#if CK_WORKAROUND_SWDEV_383542
extern
"C"
__device__
float
__ocml_native_recip_f32
(
float
);
#endif
// math functions for the host, some are implemented by calling C++ std functions
static
inline
__host__
float
abs
(
float
x
)
{
return
std
::
abs
(
x
);
};
...
...
@@ -111,6 +115,276 @@ inline __host__ double tanh<double>(double x)
return
std
::
tanh
(
x
);
};
template <typename T> inline __host__ T acos(T x) { return ck::type_convert<T>(std::acosf(ck::type_convert<float>(x))); };
template <> inline __host__ float acos<float>(float x) { return std::acosf(x); };
template <> inline __host__ double acos<double>(double x) { return std::acos(x); };

template <typename T> inline __host__ T neg(T x) { return ck::type_convert<T>(-(ck::type_convert<float>(x))); };
template <> inline __host__ float neg<float>(float x) { return -x; };
template <> inline __host__ double neg<double>(double x) { return -x; };
template <> inline __host__ int32_t neg<int32_t>(int32_t x) { return -x; };
template <> inline __host__ int8_t neg<int8_t>(int8_t x) { return -x; };

template <typename T> inline __host__ T atan(T x) { return ck::type_convert<T>(std::atanf(ck::type_convert<float>(x))); };
template <> inline __host__ float atan<float>(float x) { return std::atanf(x); };
template <> inline __host__ double atan<double>(double x) { return std::atan(x); };

template <typename T> inline __host__ T sin(T x) { return ck::type_convert<T>(std::sinf(ck::type_convert<float>(x))); };
template <> inline __host__ float sin<float>(float x) { return std::sinf(x); };
template <> inline __host__ double sin<double>(double x) { return std::sin(x); };

template <typename T> inline __host__ T asin(T x) { return ck::type_convert<T>(std::asinf(ck::type_convert<float>(x))); };
template <> inline __host__ float asin<float>(float x) { return std::asinf(x); };
template <> inline __host__ double asin<double>(double x) { return std::asin(x); };

template <typename T> inline __host__ T asinh(T x) { return ck::type_convert<T>(std::asinhf(ck::type_convert<float>(x))); };
template <> inline __host__ float asinh<float>(float x) { return std::asinhf(x); };
template <> inline __host__ double asinh<double>(double x) { return std::asinh(x); };

template <typename T> inline __host__ T cos(T x) { return ck::type_convert<T>(std::cosf(ck::type_convert<float>(x))); };
template <> inline __host__ float cos<float>(float x) { return std::cosf(x); };
template <> inline __host__ double cos<double>(double x) { return std::cos(x); };

template <typename T> inline __host__ T acosh(T x) { return ck::type_convert<T>(std::acoshf(ck::type_convert<float>(x))); };
template <> inline __host__ float acosh<float>(float x) { return std::acoshf(x); };
template <> inline __host__ double acosh<double>(double x) { return std::acosh(x); };

template <typename T> inline __host__ T tan(T x) { return ck::type_convert<T>(std::tanf(ck::type_convert<float>(x))); };
template <> inline __host__ float tan<float>(float x) { return std::tanf(x); };
template <> inline __host__ double tan<double>(double x) { return std::tan(x); };

template <typename T> inline __host__ T atanh(T x) { return ck::type_convert<T>(std::atanhf(ck::type_convert<float>(x))); };
template <> inline __host__ float atanh<float>(float x) { return std::atanhf(x); };
template <> inline __host__ double atanh<double>(double x) { return std::atanh(x); };

template <typename T> inline __host__ T sinh(T x) { return ck::type_convert<T>(std::sinhf(ck::type_convert<float>(x))); };
template <> inline __host__ float sinh<float>(float x) { return std::sinhf(x); };
template <> inline __host__ double sinh<double>(double x) { return std::sinh(x); };

template <typename T> inline __host__ T ceil(T x) { return ck::type_convert<T>(std::ceilf(ck::type_convert<float>(x))); };
template <> inline __host__ float ceil<float>(float x) { return std::ceilf(x); };
template <> inline __host__ double ceil<double>(double x) { return std::ceil(x); };

template <typename T> inline __host__ T cosh(T x) { return ck::type_convert<T>(std::coshf(ck::type_convert<float>(x))); };
template <> inline __host__ float cosh<float>(float x) { return std::coshf(x); };
template <> inline __host__ double cosh<double>(double x) { return std::cosh(x); };

template <typename T> inline __host__ T floor(T x) { return ck::type_convert<T>(std::floorf(ck::type_convert<float>(x))); };
template <> inline __host__ float floor<float>(float x) { return std::floorf(x); };
template <> inline __host__ double floor<double>(double x) { return std::floor(x); };

template <typename T> inline __host__ T rcp(T x) { return ck::type_convert<T>(1.f / ck::type_convert<float>(x)); };

template <typename T> inline __host__ T exp(T x)
{
...
@@ -282,6 +556,286 @@ inline __device__ double tanh<double>(double x)
    return ::tanh(x);
};
template <typename T> inline __device__ T acos(T x) { return ck::type_convert<T>(::acosf(ck::type_convert<float>(x))); };
template <> inline __device__ float acos<float>(float x) { return ::acosf(x); };
template <> inline __device__ double acos<double>(double x) { return ::acos(x); };

template <typename T> inline __device__ T neg(T x) { return ck::type_convert<T>(-(ck::type_convert<float>(x))); };
template <> inline __device__ float neg<float>(float x) { return -x; };
template <> inline __device__ double neg<double>(double x) { return -x; };
template <> inline __device__ int32_t neg<int32_t>(int32_t x) { return -x; };
template <> inline __device__ int8_t neg<int8_t>(int8_t x) { return -x; };
template <> inline __device__ half_t neg<half_t>(half_t x) { return __hneg(x); };

template <typename T> inline __device__ T atan(T x) { return ck::type_convert<T>(::atanf(ck::type_convert<float>(x))); };
template <> inline __device__ float atan<float>(float x) { return ::atanf(x); };
template <> inline __device__ double atan<double>(double x) { return ::atan(x); };

template <typename T> inline __device__ T sin(T x) { return ck::type_convert<T>(::sinf(ck::type_convert<float>(x))); };
template <> inline __device__ float sin<float>(float x) { return ::sinf(x); };
template <> inline __device__ double sin<double>(double x) { return ::sin(x); };
template <> inline __device__ half_t sin<half_t>(half_t x) { return ::hsin(x); };

template <typename T> inline __device__ T asin(T x) { return ck::type_convert<T>(::asinf(ck::type_convert<float>(x))); };
template <> inline __device__ float asin<float>(float x) { return ::asinf(x); };
template <> inline __device__ double asin<double>(double x) { return ::asin(x); };

template <typename T> inline __device__ T asinh(T x) { return ck::type_convert<T>(::asinhf(ck::type_convert<float>(x))); };
template <> inline __device__ float asinh<float>(float x) { return ::asinhf(x); };
template <> inline __device__ double asinh<double>(double x) { return ::asinh(x); };

template <typename T> inline __device__ T acosh(T x) { return ck::type_convert<T>(::acoshf(ck::type_convert<float>(x))); };
template <> inline __device__ float acosh<float>(float x) { return ::acoshf(x); };
template <> inline __device__ double acosh<double>(double x) { return ::acosh(x); };

template <typename T> inline __device__ T tan(T x) { return ck::type_convert<T>(::tanf(ck::type_convert<float>(x))); };
template <> inline __device__ float tan<float>(float x) { return ::tanf(x); };
template <> inline __device__ double tan<double>(double x) { return ::tan(x); };

template <typename T> inline __device__ T atanh(T x) { return ck::type_convert<T>(::atanhf(ck::type_convert<float>(x))); };
template <> inline __device__ float atanh<float>(float x) { return ::atanhf(x); };
template <> inline __device__ double atanh<double>(double x) { return ::atanh(x); };

template <typename T> inline __device__ T sinh(T x) { return ck::type_convert<T>(::sinhf(ck::type_convert<float>(x))); };
template <> inline __device__ float sinh<float>(float x) { return ::sinhf(x); };
template <> inline __device__ double sinh<double>(double x) { return ::sinh(x); };

template <typename T> inline __device__ T ceil(T x) { return ck::type_convert<T>(::ceilf(ck::type_convert<float>(x))); };
template <> inline __device__ float ceil<float>(float x) { return ::ceilf(x); };
template <> inline __device__ double ceil<double>(double x) { return ::ceil(x); };
template <> inline __device__ half_t ceil<half_t>(half_t x) { return ::hceil(x); };

template <typename T> inline __device__ T cosh(T x) { return ck::type_convert<T>(::coshf(ck::type_convert<float>(x))); };
template <> inline __device__ float cosh<float>(float x) { return ::coshf(x); };
template <> inline __device__ double cosh<double>(double x) { return ::cosh(x); };

template <typename T> inline __device__ T floor(T x) { return ck::type_convert<T>(::floorf(ck::type_convert<float>(x))); };
template <> inline __device__ float floor<float>(float x) { return ::floorf(x); };
template <> inline __device__ double floor<double>(double x) { return ::floor(x); };
template <> inline __device__ half_t floor<half_t>(half_t x) { return ::hfloor(x); };

template <typename T>
inline __device__ T rcp(T x)
{
#if !CK_WORKAROUND_SWDEV_383542
    return __frcp_rn(x);
#else
    return __ocml_native_recip_f32(x);
#endif
};

template <typename T> inline __device__ T exp(T x)
{
...
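These host-side and device-side overload sets exist so a single element-wise functor body resolves to std:: calls in host reference code and to ::acosf/::hsin/etc. (or the __ocml workaround for rcp) in device code. Below is a host-only sketch of that dispatch pattern; the functor and the math_demo namespace are stand-ins, not CK's ck::math or an operator defined in this commit, and the __host__/__device__ qualifiers are omitted so it builds as plain C++:

#include <cmath>
#include <iostream>

namespace math_demo {
// Mirrors the __host__ float/double specializations above; the generic
// fallback in CK routes other types through float via ck::type_convert.
inline float  acos(float x)  { return std::acos(x); }
inline double acos(double x) { return std::acos(x); }
} // namespace math_demo

struct UnaryAcos
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        // In a device build the same call would resolve to the __device__
        // overloads (::acosf / ::acos) declared in math_v2.hpp.
        y = static_cast<Y>(math_demo::acos(x));
    }
};

int main()
{
    double y = 0.0;
    UnaryAcos{}(y, 0.5);
    std::cout << "acos(0.5) = " << y << '\n'; // ~1.0472
    return 0;
}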
library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {

template <index_t NumATensors, typename ADataType, typename BDataType, typename ElementOp>
struct ReferenceElementwise : public device::BaseOperator
{
    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const std::array<Tensor<ADataType>, NumATensors>& a_tensors,
                 Tensor<BDataType>& b_tensor,
                 ElementOp element_op)
            : a_tensors_{a_tensors}, b_tensor_{b_tensor}, element_op_{element_op}
        {
        }

        const std::array<Tensor<ADataType>, NumATensors>& a_tensors_;
        Tensor<BDataType>& b_tensor_;
        ElementOp element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceElementwise::Argument;

        float Run(const Argument& arg)
        {
            if constexpr(NumATensors == 1)
            {
                arg.b_tensor_.ForEach([&](auto& self, auto idx) {
                    arg.element_op_(self(idx), arg.a_tensors_[0](idx));
                });
            }
            else if constexpr(NumATensors == 2)
            {
                arg.b_tensor_.ForEach([&](auto& self, auto idx) {
                    arg.element_op_(self(idx), arg.a_tensors_[0](idx), arg.a_tensors_[1](idx));
                });
            }
            else if constexpr(NumATensors == 3)
            {
                arg.b_tensor_.ForEach([&](auto& self, auto idx) {
                    arg.element_op_(self(idx),
                                    arg.a_tensors_[0](idx),
                                    arg.a_tensors_[1](idx),
                                    arg.a_tensors_[2](idx));
                });
            }

            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const std::array<Tensor<ADataType>, NumATensors>& a_tensors,
                             Tensor<BDataType>& b_tensor,
                             ElementOp element_op)
    {
        return Argument{a_tensors, b_tensor, element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceElementwise" << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
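A hedged usage sketch for this new host reference: the include paths and the PassThrough operator come from this repository, while the tensor shape, the HostTensorDescriptor construction from a std::vector of lengths, and the ForEach-based initialization are assumptions about the existing host-tensor utilities rather than code from this commit:

#include <array>
#include <cstddef>
#include <vector>

#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

int main()
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    constexpr ck::index_t NumInputs = 1;
    using RefOp =
        ck::tensor_operation::host::ReferenceElementwise<NumInputs, float, float, PassThrough>;

    // One input and one output tensor with identical 16x32 extents (arbitrary sizes).
    std::array<Tensor<float>, NumInputs> inputs{
        Tensor<float>(HostTensorDescriptor(std::vector<std::size_t>{16, 32}))};
    Tensor<float> output(HostTensorDescriptor(std::vector<std::size_t>{16, 32}));

    // Same ForEach idiom the reference itself uses over b_tensor_.
    inputs[0].ForEach([](auto& self, auto idx) { self(idx) = 1.f; });

    auto argument = RefOp::MakeArgument(inputs, output, PassThrough{});
    auto invoker  = RefOp::MakeInvoker();
    invoker.Run(argument); // output(idx) = input(idx) for every index
    return 0;
}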
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp (diff collapsed)
library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp
...
@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL)
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
                                                    Row,
...
@@ -69,7 +69,7 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance
                                                    PassThrough,
                                                    Bilinear>>>& instances);
#endif

#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA)
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
                                                    Row,
...
@@ -159,7 +159,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL)
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
        {
...
@@ -189,7 +189,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
            }
        }
#endif
#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA)
        if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
                     is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
        {
...
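The guard change above means bilinear GEMM instances are registered only when both the data type (CK_ENABLE_FP16 / CK_ENABLE_INT8) and the matching instruction path (CK_USE_XDL / CK_USE_WMMA) were enabled at build time. A hedged client sketch follows; the layout/type combination is one of those declared above, and an empty result simply means no instance matched the build configuration:

#include <memory>
#include <vector>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp"

int main()
{
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Col         = ck::tensor_layout::gemm::ColumnMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using Bilinear    = ck::tensor_operation::element_wise::Bilinear;
    using F16         = ck::half_t;

    // E = alpha * (A * B) + beta * D, with A row-major and B column-major (mk_nk_mn_mn).
    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
        Row, Col, ck::Tuple<Row>, Row,
        F16, F16, ck::Tuple<F16>, F16,
        PassThrough, PassThrough, Bilinear>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    // Empty when the build did not enable both CK_ENABLE_FP16 and CK_USE_XDL.
    return op_ptrs.empty() ? 1 : 0;
}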
library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <memory>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#if defined(CK_ENABLE_FP16)
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

#if defined(CK_ENABLE_FP32)
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

#if defined(CK_ENABLE_INT8)
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
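A hedged sketch of how the DL/DPP registrations declared above are typically consumed through the umbrella gemm.hpp factory: collect every instance for one layout/type combination, probe IsSupportedArgument on the problem shape, and keep the usable ones. The null data pointers are passed only for the support check (an assumption that the check inspects sizes and strides, not buffers); a real run would bind device memory and then call the invoker:

#include <iostream>
#include <memory>
#include <vector>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"

int main()
{
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // Row-major A/B/C in fp32 corresponds to the *_f32_f32_f32_mk_kn_mn_* declarations above.
    using DeviceOp = ck::tensor_operation::device::DeviceGemm<
        Row, Row, Row, float, float, float, PassThrough, PassThrough, PassThrough>;

    const auto gemm_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    const ck::index_t M = 1024, N = 1024, K = 64;

    for(const auto& op : gemm_ptrs)
    {
        auto arg = op->MakeArgumentPointer(nullptr, nullptr, nullptr,
                                           M, N, K,
                                           /*StrideA=*/K, /*StrideB=*/N, /*StrideC=*/N,
                                           PassThrough{}, PassThrough{}, PassThrough{});

        if(op->IsSupportedArgument(arg.get()))
        {
            std::cout << "usable: " << op->GetTypeString() << '\n';
        }
    }
    return 0;
}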
library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
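How an .inc file like this one is pulled in is not visible here because the gemm.hpp diff is collapsed above; the following is only a hedged sketch of the include pattern such umbrella headers use, and gemm_xdl.inc is a hypothetical name, not a file shown in this commit:

// Hypothetical umbrella-header sketch (not the actual gemm.hpp from this commit).
#pragma once

#include "ck/ck.hpp"

#ifdef CK_USE_XDL
#include "ck/library/tensor_operation_instance/gpu/gemm_xdl.inc" // hypothetical name
#endif

#ifdef CK_USE_WMMA
#include "ck/library/tensor_operation_instance/gpu/gemm_wmma.inc"
#endif

// DL/DPP declarations carry their own per-data-type guards inside gemm_dl.inc.
#include "ck/library/tensor_operation_instance/gpu/gemm_dl.inc"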