Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0b11569f
Commit
0b11569f
authored
Jul 01, 2022
by
Chao Liu
Browse files
Merge remote-tracking branch 'origin/develop' into batched_gemm_c_permute
parents
e8d3a0fb
fa9a0a5c
Changes
554
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
129 additions
and
93 deletions
+129
-93
library/src/host_tensor/device_memory.cpp
library/src/host_tensor/device_memory.cpp
+3
-0
library/src/host_tensor/host_tensor.cpp
library/src/host_tensor/host_tensor.cpp
+3
-22
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+28
-25
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
..._batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
..._batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
..._batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
..._batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
...ice_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
...ice_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
...ice_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
...ice_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
...ice_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
...ice_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
...ice_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
...ice_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
..._batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp
..._batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
..._batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
..._batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
+4
-1
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
...xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
+31
-30
No files found.
library/src/host_tensor/device_memory.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/device_utility/hip_check_error.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
...
...
library/src/host_tensor/host_tensor.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cassert>
#include "ck/library/host_tensor/host_tensor.hpp"
...
...
@@ -51,25 +54,3 @@ std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
return
os
;
}
void
ostream_HostTensorDescriptor
(
const
HostTensorDescriptor
&
desc
,
std
::
ostream
&
os
)
{
os
<<
"dim "
<<
desc
.
GetNumOfDimension
()
<<
", "
;
os
<<
"lengths {"
;
LogRange
(
os
,
desc
.
GetLengths
(),
", "
);
os
<<
"}, "
;
os
<<
"strides {"
;
LogRange
(
os
,
desc
.
GetStrides
(),
", "
);
os
<<
"}"
<<
std
::
endl
;
}
#if 1
// FIXME: remove
void
bf16_to_f32_
(
const
Tensor
<
ck
::
bhalf_t
>&
src
,
Tensor
<
float
>&
dst
)
{
for
(
std
::
size_t
i
=
0
;
i
<
src
.
mData
.
size
();
++
i
)
dst
.
mData
[
i
]
=
ck
::
type_convert
<
float
>
(
src
.
mData
[
i
]);
}
#endif
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
0b11569f
...
...
@@ -5,44 +5,51 @@ function(add_instance_library INSTANCE_NAME)
set_target_properties
(
${
INSTANCE_NAME
}
PROPERTIES POSITION_INDEPENDENT_CODE ON
)
endfunction
(
add_instance_library INSTANCE_NAME
)
add_subdirectory
(
elementwise
)
add_subdirectory
(
gemm
)
add_subdirectory
(
gemm_splitk
)
add_subdirectory
(
gemm_bias2d
)
add_subdirectory
(
gemm_bias_relu
)
add_subdirectory
(
gemm_bias_relu_add
)
add_subdirectory
(
gemm_reduce
)
add_subdirectory
(
gemm_bias_add_reduce
)
add_subdirectory
(
gemm_add_add_fastgelu
)
add_subdirectory
(
batched_gemm
)
add_subdirectory
(
batched_gemm_reduce
)
add_subdirectory
(
grouped_gemm
)
add_subdirectory
(
conv1d_fwd
)
add_subdirectory
(
conv2d_fwd
)
add_subdirectory
(
conv3d_fwd
)
add_subdirectory
(
conv2d_fwd_bias_relu
)
add_subdirectory
(
conv2d_fwd_bias_relu_add
)
add_subdirectory
(
conv2d_bwd_data
)
add_subdirectory
(
reduce
)
add_subdirectory
(
convnd_bwd_data
)
add_subdirectory
(
grouped_gemm
)
add_subdirectory
(
conv2d_bwd_weight
)
add_subdirectory
(
batched_gemm_reduce
)
add_subdirectory
(
gemm_add_add_fastgelu
)
add_subdirectory
(
normalization
)
add_subdirectory
(
reduce
)
add_library
(
device_operations STATIC
$<TARGET_OBJECTS:device_conv1d_fwd_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
$<TARGET_OBJECTS:device_gemm_instance>
$<TARGET_OBJECTS:device_gemm_splitk_instance>
$<TARGET_OBJECTS:device_gemm_bias_relu_instance>
$<TARGET_OBJECTS:device_gemm_bias_relu_add_instance>
$<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
$<TARGET_OBJECTS:device_gemm_bias2d_instance>
$<TARGET_OBJECTS:device_reduce_instance>
$<TARGET_OBJECTS:device_convnd_bwd_data_instance>
$<TARGET_OBJECTS:device_grouped_gemm_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
$<TARGET_OBJECTS:device_grouped_gemm_instance>
$<TARGET_OBJECTS:device_conv1d_fwd_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_instance>
$<TARGET_OBJECTS:device_conv3d_fwd_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
$<TARGET_OBJECTS:device_convnd_bwd_data_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
$<TARGET_OBJECTS:device_elementwise_instance>
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
$<TARGET_OBJECTS:device_reduce_instance>
)
add_library
(
composablekernels::device_operations ALIAS device_operations
)
...
...
@@ -67,29 +74,25 @@ target_include_directories(device_operations PUBLIC
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/thread>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/element>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/library/host_tensor>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/library/host>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/library/tensor_operation_instance>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/library/tensor_operation_instance/gpu>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/library/tensor_operation_instance/gpu/reduce>
)
#once new arches are enabled make this an option on the main cmake file
# and pass down here to be exported
target_compile_options
(
device_operations PRIVATE
target_compile_options
(
device_operations PRIVATE
--offload-arch=gfx908
--offload-arch=gfx90a
)
# install(TARGETS device_operations LIBRARY DESTINATION lib)
install
(
TARGETS device_operations
EXPORT device_operationsTargets
LIBRARY DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
ARCHIVE DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
RUNTIME DESTINATION
${
CMAKE_INSTALL_BINDIR
}
INCLUDES DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
)
install
(
DIRECTORY
${
DEV_OPS_INC_DIRS
}
DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck
)
install
(
EXPORT device_operationsTargets
rocm_install
(
TARGETS device_operations
EXPORT device_operationsTargets
)
rocm_install
(
DIRECTORY
${
DEV_OPS_INC_DIRS
}
DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck
)
rocm_install
(
EXPORT device_operationsTargets
FILE composable_kerneldevice_operationsTargets.cmake
NAMESPACE composable_kernel::
DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
/cmake/composable_kernel
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -41,7 +44,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -41,7 +44,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -45,7 +48,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -46,7 +49,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -41,7 +44,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -41,7 +44,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -50,7 +53,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -46,7 +49,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -41,7 +44,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -41,7 +44,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -41,7 +44,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -46,7 +49,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -56,7 +59,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -56,7 +59,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -56,7 +59,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -48,7 +51,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple<
>
;
void
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
Device
Batched
GemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
...
...
@@ -12,9 +15,9 @@ namespace tensor_operation {
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
D
PtrsGlobal
=
ck
::
Tuple
<
F32
*
,
F32
*>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Reduce
PtrsGlobal
=
ck
::
Tuple
<
F32
*
,
F32
*>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -26,10 +29,10 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using
ReduceSum
=
ck
::
reduce
::
Add
;
using
ReduceOps
=
ck
::
Tuple
<
ReduceSum
,
ReduceSum
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
D
InElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
D
OutElementOps
=
ck
::
Tuple
<
Identity
,
Identity
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
Reduce
InElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
Reduce
OutElementOps
=
ck
::
Tuple
<
Identity
,
Identity
>
;
using
ReduceMemOp
=
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
,
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
>
;
...
...
@@ -40,33 +43,31 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances
=
std
::
tuple
<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc|
D
Data| A| B| C|
Dxs| Dxs
InEleOp|
Dxs
AccEleOp|
D
| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise|
Reduce|
| | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | |
|
Operation
| Operation| Operation| Operation|
|
| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | |
|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
D
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
D
InElementOps
,
D
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc|
Reduce
Data| A| B| C|
Reduce| Reduce
InEleOp|
Reduce
AccEleOp|
Reduce
| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType|
Type Tuple| Elementwise| Elementwise| Elementwise|
Operation|
|
| MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | |
| Operation| Operation| Operation|
| |
| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | |
| | | | |
|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
Reduce
PtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
Reduce
InElementOps
,
Reduce
OutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
// clang-format on
>
;
void
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
std
::
vector
<
DeviceGemmReducePtr
<
0
,
ReducePtrsGlobal
::
Size
()
>>&
instances
)
{
add_device_operation_instances
(
instances
,
...
...
Prev
1
…
12
13
14
15
16
17
18
19
20
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment