Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
df35f46d
Commit
df35f46d
authored
Oct 07, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
9c0811f3
7733ae16
Changes
69
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
184 additions
and
48 deletions
+184
-48
include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
...ps/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
+13
-9
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+7
-21
profiler/src/CMakeLists.txt
profiler/src/CMakeLists.txt
+6
-6
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+3
-0
script/cmake-ck-release.sh
script/cmake-ck-release.sh
+3
-0
test/CMakeLists.txt
test/CMakeLists.txt
+5
-12
test/ck_tile/CMakeLists.txt
test/ck_tile/CMakeLists.txt
+1
-0
test/ck_tile/image_to_column/CMakeLists.txt
test/ck_tile/image_to_column/CMakeLists.txt
+4
-0
test/ck_tile/image_to_column/test_tile_image_to_column.cpp
test/ck_tile/image_to_column/test_tile_image_to_column.cpp
+142
-0
No files found.
include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
View file @
df35f46d
...
@@ -14,17 +14,21 @@ template <typename XDataType_,
...
@@ -14,17 +14,21 @@ template <typename XDataType_,
typename
YDataType_
,
typename
YDataType_
,
typename
MeanDataType_
,
typename
MeanDataType_
,
typename
InvStdDataType_
,
typename
InvStdDataType_
,
typename
BlockShape_
>
typename
BlockShape_
,
bool
kPadM_
,
bool
kPadN_
>
struct
BlockLayernorm2dFwdProblem
struct
BlockLayernorm2dFwdProblem
{
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
BetaDataType
=
remove_cvref_t
<
BetaDataType_
>
;
using
BetaDataType
=
remove_cvref_t
<
BetaDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
df35f46d
...
@@ -37,11 +37,7 @@ function(add_instance_library INSTANCE_NAME)
...
@@ -37,11 +37,7 @@ function(add_instance_library INSTANCE_NAME)
endforeach
()
endforeach
()
endif
()
endif
()
if
(
INSTANCES_ONLY
)
set
(
INST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
set
(
INST_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
INST_TARGETS
${
GPU_TARGETS
}
)
endif
()
# Do not build DL instances if DL_KERNELS macro is not set
# Do not build DL instances if DL_KERNELS macro is not set
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
...
@@ -64,9 +60,9 @@ function(add_instance_library INSTANCE_NAME)
...
@@ -64,9 +60,9 @@ function(add_instance_library INSTANCE_NAME)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endif
()
endforeach
()
endforeach
()
# Do not build mha instances if gfx94 targets are not on the target list
# Do not build mha instances if gfx94
or gfx90a
targets are not on the target list
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND source MATCHES
"mha"
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND
NOT INST_TARGETS MATCHES
"gfx90a"
AND
source MATCHES
"mha"
)
message
(
"removing mha instance
${
source
}
"
)
message
(
"removing mha instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endif
()
...
@@ -75,17 +71,13 @@ function(add_instance_library INSTANCE_NAME)
...
@@ -75,17 +71,13 @@ function(add_instance_library INSTANCE_NAME)
if
(
ARGN
)
if
(
ARGN
)
set
(
INST_OBJ
)
set
(
INST_OBJ
)
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
if
(
INSTANCES_ONLY
)
set
(
INST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
set
(
INST_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
INST_TARGETS
${
GPU_TARGETS
}
)
endif
()
if
(
source MATCHES
"_xdl"
)
if
(
source MATCHES
"_xdl"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
elseif
(
ARGN MATCHES
"_wmma"
)
elseif
(
ARGN MATCHES
"_wmma"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
elseif
(
ARGN MATCHES
"mha"
)
elseif
(
ARGN MATCHES
"mha"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908
gfx90a
gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
endif
()
endif
()
set
(
offload_targets
)
set
(
offload_targets
)
foreach
(
target IN LISTS INST_TARGETS
)
foreach
(
target IN LISTS INST_TARGETS
)
...
@@ -191,12 +183,7 @@ FOREACH(subdir_path ${dir_list})
...
@@ -191,12 +183,7 @@ FOREACH(subdir_path ${dir_list})
set
(
add_inst 1
)
set
(
add_inst 1
)
endif
()
endif
()
if
(
INSTANCES_ONLY
)
set
(
INST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
set
(
INST_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
INST_TARGETS
${
GPU_TARGETS
}
)
endif
()
if
((
"
${
cmake_instance
}
"
MATCHES
"quantization"
)
AND
(
DEFINED DTYPES
)
AND
(
NOT DTYPES MATCHES
"int8"
))
if
((
"
${
cmake_instance
}
"
MATCHES
"quantization"
)
AND
(
DEFINED DTYPES
)
AND
(
NOT DTYPES MATCHES
"int8"
))
message
(
"quantization instances will not be built!"
)
message
(
"quantization instances will not be built!"
)
...
@@ -320,8 +307,7 @@ if(CK_DEVICE_CONV_INSTANCES)
...
@@ -320,8 +307,7 @@ if(CK_DEVICE_CONV_INSTANCES)
endif
()
endif
()
if
(
CK_DEVICE_MHA_INSTANCES
)
if
(
CK_DEVICE_MHA_INSTANCES
)
set
(
gpu_list
${
INST_TARGETS
}
)
set
(
gpu_list
${
INST_TARGETS
}
)
list
(
FILTER gpu_list INCLUDE REGEX
"^gfx94"
)
if
(
gpu_list MATCHES
"gfx94"
OR gpu_list MATCHES
"gfx90a"
)
if
(
gpu_list
)
add_library
(
device_mha_operations STATIC
${
CK_DEVICE_MHA_INSTANCES
}
)
add_library
(
device_mha_operations STATIC
${
CK_DEVICE_MHA_INSTANCES
}
)
add_library
(
composablekernels::device_mha_operations ALIAS device_mha_operations
)
add_library
(
composablekernels::device_mha_operations ALIAS device_mha_operations
)
target_compile_features
(
device_mha_operations PUBLIC
)
target_compile_features
(
device_mha_operations PUBLIC
)
...
...
profiler/src/CMakeLists.txt
View file @
df35f46d
...
@@ -24,7 +24,7 @@ set(PROFILER_SOURCES
...
@@ -24,7 +24,7 @@ set(PROFILER_SOURCES
profile_permute_scale.cpp
profile_permute_scale.cpp
)
)
if
(
GPU_TARGETS MATCHES
"gfx9"
)
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx9"
)
if
(
DTYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
if
(
DTYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
list
(
APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp
)
list
(
APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp
)
list
(
APPEND PROFILER_SOURCES profile_contraction_scale.cpp
)
list
(
APPEND PROFILER_SOURCES profile_contraction_scale.cpp
)
...
@@ -49,7 +49,7 @@ if(GPU_TARGETS MATCHES "gfx9")
...
@@ -49,7 +49,7 @@ if(GPU_TARGETS MATCHES "gfx9")
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp
)
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp
)
endif
()
endif
()
list
(
APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp
)
if
(
GPU_TARGETS MATCHES
"gfx94"
)
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx94"
)
list
(
APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp
)
endif
()
endif
()
...
@@ -69,7 +69,7 @@ if(GPU_TARGETS MATCHES "gfx9")
...
@@ -69,7 +69,7 @@ if(GPU_TARGETS MATCHES "gfx9")
endif
()
endif
()
if
(
GPU_TARGETS MATCHES
"gfx11"
OR GPU_TARGETS MATCHES
"gfx12"
OR GPU_TARGETS MATCHES
"gfx9"
)
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx11"
OR
SUPPORTED_
GPU_TARGETS MATCHES
"gfx12"
OR
SUPPORTED_
GPU_TARGETS MATCHES
"gfx9"
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp
)
endif
()
endif
()
...
@@ -111,7 +111,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_inst
...
@@ -111,7 +111,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_inst
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_transpose_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_transpose_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_permute_scale_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_permute_scale_instance
)
if
(
GPU_TARGETS MATCHES
"gfx9"
)
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx9"
)
if
(
DTYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
if
(
DTYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_bilinear_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_bilinear_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_scale_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_scale_instance
)
...
@@ -135,7 +135,7 @@ if(GPU_TARGETS MATCHES "gfx9")
...
@@ -135,7 +135,7 @@ if(GPU_TARGETS MATCHES "gfx9")
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_multiply_add_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_multiply_add_instance
)
if
(
GPU_TARGETS MATCHES
"gfx94"
)
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx94"
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_multiply_multiply_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_multiply_multiply_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_ab_scale_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_ab_scale_instance
)
endif
()
endif
()
...
@@ -159,7 +159,7 @@ if(GPU_TARGETS MATCHES "gfx9")
...
@@ -159,7 +159,7 @@ if(GPU_TARGETS MATCHES "gfx9")
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_fwd_convinvscale_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_fwd_convinvscale_instance
)
endif
()
endif
()
if
(
GPU_TARGETS MATCHES
"gfx9"
OR GPU_TARGETS MATCHES
"gfx11"
OR GPU_TARGETS MATCHES
"gfx12"
)
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx9"
OR
SUPPORTED_
GPU_TARGETS MATCHES
"gfx11"
OR
SUPPORTED_
GPU_TARGETS MATCHES
"gfx12"
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_bilinear_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_bilinear_instance
)
endif
()
endif
()
...
...
script/cmake-ck-dev.sh
View file @
df35f46d
...
@@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1
...
@@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1
if
[
$#
-ge
2
]
;
then
if
[
$#
-ge
2
]
;
then
GPU_TARGETS
=
$2
GPU_TARGETS
=
$2
REST_ARGS
=
${
@
:3
}
else
else
GPU_TARGETS
=
"gfx908;gfx90a;gfx940"
GPU_TARGETS
=
"gfx908;gfx90a;gfx940"
REST_ARGS
=
fi
fi
cmake
\
cmake
\
...
@@ -20,4 +22,5 @@ cmake
...
@@ -20,4 +22,5 @@ cmake
-D
GPU_TARGETS
=
$GPU_TARGETS
\
-D
GPU_TARGETS
=
$GPU_TARGETS
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
$REST_ARGS
\
${
MY_PROJECT_SOURCE
}
${
MY_PROJECT_SOURCE
}
script/cmake-ck-release.sh
View file @
df35f46d
...
@@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1
...
@@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1
if
[
$#
-ge
2
]
;
then
if
[
$#
-ge
2
]
;
then
GPU_TARGETS
=
$2
GPU_TARGETS
=
$2
REST_ARGS
=
${
@
:3
}
else
else
GPU_TARGETS
=
"gfx908;gfx90a;gfx940"
GPU_TARGETS
=
"gfx908;gfx90a;gfx940"
REST_ARGS
=
fi
fi
cmake
\
cmake
\
...
@@ -20,5 +22,6 @@ cmake
...
@@ -20,5 +22,6 @@ cmake
-D
GPU_TARGETS
=
$GPU_TARGETS
\
-D
GPU_TARGETS
=
$GPU_TARGETS
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
$REST_ARGS
\
${
MY_PROJECT_SOURCE
}
${
MY_PROJECT_SOURCE
}
test/CMakeLists.txt
View file @
df35f46d
...
@@ -41,11 +41,7 @@ function(add_test_executable TEST_NAME)
...
@@ -41,11 +41,7 @@ function(add_test_executable TEST_NAME)
endforeach
()
endforeach
()
endif
()
endif
()
if
(
INSTANCES_ONLY
)
set
(
TEST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
set
(
TEST_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
TEST_TARGETS
${
GPU_TARGETS
}
)
endif
()
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
...
@@ -122,11 +118,7 @@ function(add_gtest_executable TEST_NAME)
...
@@ -122,11 +118,7 @@ function(add_gtest_executable TEST_NAME)
endforeach
()
endforeach
()
endif
()
endif
()
if
(
INSTANCES_ONLY
)
set
(
TEST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
set
(
TEST_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
TEST_TARGETS
${
GPU_TARGETS
}
)
endif
()
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
...
@@ -173,6 +165,7 @@ function(add_gtest_executable TEST_NAME)
...
@@ -173,6 +165,7 @@ function(add_gtest_executable TEST_NAME)
endfunction
()
endfunction
()
add_compile_options
(
-Wno-c++20-extensions
)
add_compile_options
(
-Wno-c++20-extensions
)
add_subdirectory
(
ck_tile
)
add_subdirectory
(
magic_number_division
)
add_subdirectory
(
magic_number_division
)
add_subdirectory
(
space_filling_curve
)
add_subdirectory
(
space_filling_curve
)
add_subdirectory
(
conv_util
)
add_subdirectory
(
conv_util
)
...
@@ -210,10 +203,10 @@ add_subdirectory(conv_tensor_rearrange)
...
@@ -210,10 +203,10 @@ add_subdirectory(conv_tensor_rearrange)
add_subdirectory
(
transpose
)
add_subdirectory
(
transpose
)
add_subdirectory
(
permute_scale
)
add_subdirectory
(
permute_scale
)
add_subdirectory
(
wrapper
)
add_subdirectory
(
wrapper
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx11"
)
add_subdirectory
(
wmma_op
)
add_subdirectory
(
wmma_op
)
endif
()
endif
()
if
(
GPU_TARGETS MATCHES
"gfx942"
AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2
)
# smfmac needs ROCm6.2
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx942"
AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2
)
# smfmac needs ROCm6.2
add_subdirectory
(
smfmac_op
)
add_subdirectory
(
smfmac_op
)
endif
()
endif
()
add_subdirectory
(
position_embedding
)
add_subdirectory
(
position_embedding
)
test/ck_tile/CMakeLists.txt
0 → 100644
View file @
df35f46d
# Pull in the ck_tile image-to-column test directory.
add_subdirectory(image_to_column)
test/ck_tile/image_to_column/CMakeLists.txt
0 → 100644
View file @
df35f46d
# Currently ck_tile is only built on gfx9, so the image-to-column test is only
# registered when a gfx9 architecture is among the targets being built.
# The check uses SUPPORTED_GPU_TARGETS so that it agrees with the other
# architecture gates in test/CMakeLists.txt (e.g. the wmma_op / smfmac_op checks).
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
    add_gtest_executable(test_tile_image_to_column test_tile_image_to_column.cpp)
endif()
test/ck_tile/image_to_column/test_tile_image_to_column.cpp
0 → 100644
View file @
df35f46d
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <gtest/gtest.h>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/image_to_column.hpp"
// Host API implementation
/// Test fixture that runs the ck_tile ImageToColumn kernel for one element
/// type and compares its output against ck_tile::reference_im2col.
///
/// @tparam DataType element type used for both the input image and the
///         column output.
template <typename DataType>
class TestCkTileImageToColumn : public ::testing::Test
{
    // Passed as the two trailing arguments of BlockImageToColumnProblem.
    static constexpr ck_tile::index_t VectorSize = 1;
    // Number of spatial dimensions handled by this fixture.
    static constexpr ck_tile::index_t NDimSpatial = 2;

    protected:
    /// Runs the kernel and the host reference for one convolution problem and
    /// records a gtest failure (EXPECT_TRUE) when the two outputs differ.
    ///
    /// @param conv_params convolution problem description; its G_/N_/C_ fields
    ///        and the spatial length, stride, dilation and padding vectors
    ///        are forwarded to the kernel arguments.
    void Run(const ck_tile::conv::ConvParam& conv_params)
    {
        using ImLayout = ck_tile::tensor_layout::convolution::NHWGC;

        const auto G = conv_params.G_;
        const auto N = conv_params.N_;
        const auto C = conv_params.C_;

        // Seed the products with a long_index_t so std::accumulate keeps the
        // running product in the wide type; an `int` seed would make the
        // accumulator an int and narrow every intermediate result.
        const ck_tile::long_index_t NDoHoWo =
            N * std::accumulate(conv_params.output_spatial_lengths_.begin(),
                                std::next(conv_params.output_spatial_lengths_.begin(), NDimSpatial),
                                ck_tile::long_index_t{1},
                                std::multiplies<>());
        const ck_tile::long_index_t CZYX =
            C * std::accumulate(conv_params.filter_spatial_lengths_.begin(),
                                std::next(conv_params.filter_spatial_lengths_.begin(), NDimSpatial),
                                ck_tile::long_index_t{1},
                                std::multiplies<>());

        const auto in_desc =
            ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(
                conv_params);
        // Output has lengths {G, N * output spatial product, C * filter spatial product}.
        const auto out_desc = ck_tile::HostTensorDescriptor({G, NDoHoWo, CZYX});

        // Host-side tensors: input, device result copy, and reference result.
        ck_tile::HostTensor<DataType> in(in_desc);
        ck_tile::HostTensor<DataType> out_device(out_desc);
        ck_tile::HostTensor<DataType> out_host(out_desc);

        std::cout << "input: " << in.mDesc << std::endl;
        std::cout << "output: " << out_device.mDesc << std::endl;

        // NOTE(review): presumably fills `in` with integer-valued samples in
        // [-5, 5] so the device/host comparison is exact - confirm against the filler.
        ck_tile::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(in);

        ck_tile::DeviceMem in_device_buf(in.get_element_space_size_in_bytes());
        ck_tile::DeviceMem out_device_buf(out_device.get_element_space_size_in_bytes());

        in_device_buf.ToDevice(in.data());

        using thread_tile = ck_tile::sequence<4, 4>;
        using warp_tile   = ck_tile::sequence<8, 128>;
        using block_tile  = ck_tile::sequence<32, 128>;

        using Shape = ck_tile::TileImageToColumnShape<thread_tile, warp_tile, block_tile>;

        using PipelineProblem = ck_tile::BlockImageToColumnProblem<DataType,
                                                                   DataType,
                                                                   Shape,
                                                                   NDimSpatial,
                                                                   VectorSize,
                                                                   VectorSize>;

        using Kernel = ck_tile::ImageToColumn<PipelineProblem>;

        auto kargs = Kernel::MakeKargs(
            in_device_buf.GetDeviceBuffer(),
            out_device_buf.GetDeviceBuffer(),
            G,
            N,
            C,
            ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
                conv_params.input_spatial_lengths_),
            ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
                conv_params.filter_spatial_lengths_),
            ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
                conv_params.output_spatial_lengths_),
            ck_tile::to_array<ck_tile::long_index_t, NDimSpatial + 3>(in_desc.get_strides()),
            ck_tile::to_array<ck_tile::long_index_t, 3>(out_desc.get_strides()),
            ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.conv_filter_strides_),
            ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
                conv_params.conv_filter_dilations_),
            ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_left_pads_),
            ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_right_pads_));

        // Grid size is derived from the output rows, output columns and group count.
        const dim3 grids = Kernel::GridSize(
            kargs.N * kargs.output_spatial_lengths[0] * kargs.output_spatial_lengths[1],
            kargs.filter_spatial_lengths[0] * kargs.filter_spatial_lengths[1] * kargs.C,
            kargs.G);
        constexpr dim3 blocks = Kernel::BlockSize();

        constexpr ck_tile::index_t kBlockPerCu = 2;

        ck_tile::launch_kernel(
            ck_tile::stream_config{},
            ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

        // Host reference result for the same input and problem.
        ck_tile::reference_im2col<DataType, DataType, NDimSpatial>(in, out_host, conv_params);

        out_device_buf.FromDevice(out_device.data());
        bool pass = ck_tile::check_err(out_device, out_host);

        EXPECT_TRUE(pass);
    }
};
// Fixture instantiation of TestCkTileImageToColumn for float elements.
class TestCkTileImageToColumnFloat : public TestCkTileImageToColumn<float>
{
};
// Fixture instantiation of TestCkTileImageToColumn for ck_tile::half_t elements.
class TestCkTileImageToColumnHalf : public TestCkTileImageToColumn<ck_tile::half_t>
{
};
// Runs the float fixture over five convolution problems.
// NOTE(review): the ConvParam arguments presumably follow the order
// (spatial dims, G, N, K, C, filter lengths, input lengths, strides,
// dilations, left pads, right pads) - confirm against ck_tile::conv::ConvParam.
TEST_F(TestCkTileImageToColumnFloat, TestCorrectness)
{
    this->Run({2, 2, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
    this->Run({2, 2, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
    this->Run({2, 1, 64, 1, 64, {1, 1}, {7, 7}, {3, 3}, {1, 1}, {0, 0}, {0, 0}});
    this->Run({2, 1, 64, 1, 64, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
    this->Run({2, 2, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}});
}
// Runs the half-precision fixture over the same five convolution problems as
// the float test.
// NOTE(review): the ConvParam arguments presumably follow the order
// (spatial dims, G, N, K, C, filter lengths, input lengths, strides,
// dilations, left pads, right pads) - confirm against ck_tile::conv::ConvParam.
TEST_F(TestCkTileImageToColumnHalf, TestCorrectness)
{
    this->Run({2, 2, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
    this->Run({2, 2, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
    this->Run({2, 1, 64, 1, 64, {1, 1}, {7, 7}, {3, 3}, {1, 1}, {0, 0}, {0, 0}});
    this->Run({2, 1, 64, 1, 64, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
    this->Run({2, 2, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}});
}
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment