Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9a7fa123
"...composable_kernel.git" did not exist on "5245a0162bbbe5f49be9cc8f2189f53465f691ec"
Commit
9a7fa123
authored
May 17, 2022
by
carlushuang
Browse files
support gcc with cpu only compile
parent
ad09ebdb
Changes
32
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
420 additions
and
336 deletions
+420
-336
CMakeLists.txt
CMakeLists.txt
+27
-14
include/ck/config.hpp
include/ck/config.hpp
+12
-1
include/ck/options.hpp.in
include/ck/options.hpp.in
+1
-0
include/ck/stream_config.hpp
include/ck/stream_config.hpp
+4
-0
include/ck/tensor/static_tensor.hpp
include/ck/tensor/static_tensor.hpp
+2
-0
include/ck/tensor_description/tensor_descriptor.hpp
include/ck/tensor_description/tensor_descriptor.hpp
+5
-0
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
...e/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
+21
-20
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
.../threadwise_tensor_slice_transfer_avx2_specialization.hpp
+10
-0
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+3
-0
include/ck/utility/amd_inline_asm.hpp
include/ck/utility/amd_inline_asm.hpp
+3
-0
include/ck/utility/amd_llvm_intrinsic.hpp
include/ck/utility/amd_llvm_intrinsic.hpp
+2
-0
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+2
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+14
-1
include/ck/utility/data_type_cpu.hpp
include/ck/utility/data_type_cpu.hpp
+300
-296
include/ck/utility/debug.hpp
include/ck/utility/debug.hpp
+2
-1
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+2
-0
include/ck/utility/generic_memory_space_atomic_add.hpp
include/ck/utility/generic_memory_space_atomic_add.hpp
+2
-0
include/ck/utility/get_id.hpp
include/ck/utility/get_id.hpp
+2
-0
include/ck/utility/inner_product.hpp
include/ck/utility/inner_product.hpp
+2
-1
include/ck/utility/magic_division.hpp
include/ck/utility/magic_division.hpp
+4
-2
No files found.
CMakeLists.txt
View file @
9a7fa123
...
@@ -7,6 +7,10 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
...
@@ -7,6 +7,10 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
enable_testing
()
enable_testing
()
option
(
CK_NOGPU
"build without gpu backend"
OFF
)
if
(
NOT CK_NOGPU
)
find_package
(
ROCM REQUIRED PATHS /opt/rocm
)
find_package
(
ROCM REQUIRED PATHS /opt/rocm
)
include
(
ROCMInstallTargets
)
include
(
ROCMInstallTargets
)
...
@@ -19,6 +23,7 @@ include(CheckCXXCompilerFlag)
...
@@ -19,6 +23,7 @@ include(CheckCXXCompilerFlag)
rocm_setup_version
(
VERSION 1.0.0
)
rocm_setup_version
(
VERSION 1.0.0
)
include
(
TargetFlags
)
include
(
TargetFlags
)
list
(
APPEND CMAKE_PREFIX_PATH
${
CMAKE_INSTALL_PREFIX
}
${
CMAKE_INSTALL_PREFIX
}
/llvm
${
CMAKE_INSTALL_PREFIX
}
/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip
)
list
(
APPEND CMAKE_PREFIX_PATH
${
CMAKE_INSTALL_PREFIX
}
${
CMAKE_INSTALL_PREFIX
}
/llvm
${
CMAKE_INSTALL_PREFIX
}
/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip
)
endif
()
## C++
## C++
enable_language
(
CXX
)
enable_language
(
CXX
)
...
@@ -31,25 +36,26 @@ option(CK_TIME_KERNEL "Turning off will disable kernel timing globally" ON)
...
@@ -31,25 +36,26 @@ option(CK_TIME_KERNEL "Turning off will disable kernel timing globally" ON)
## OpenMP
## OpenMP
if
(
CMAKE_CXX_COMPILER_ID MATCHES
"Clang"
)
if
(
CMAKE_CXX_COMPILER_ID MATCHES
"Clang"
)
# workaround issue hipcc in rocm3.5 cannot find openmp
set
(
OMP_CXX_FLAG -fopenmp=libomp -Wno-unused-command-line-argument
)
set
(
O
penMP_CXX
"
${
CMAKE_CXX_COMPILER
}
"
)
set
(
O
MP_LIBRARY /opt/rocm/llvm/lib/libomp.so
)
set
(
O
penMP_CXX_FLAGS
"-fopenmp=libomp -Wno-unused-command-line-argument"
)
set
(
O
MP_LINK_FLAG -Wl,-rpath,/opt/rocm/llvm/lib
)
set
(
OpenMP_CXX_LIB_NAMES
"libomp"
"libgomp"
"libiomp5
"
)
elseif
(
CMAKE_CXX_COMPILER_ID MATCHES
"GNU
"
)
set
(
O
penMP_libomp_LIBRARY
${
OpenMP_CXX_LIB_NAMES
}
)
set
(
O
MP_CXX_FLAG -fopenmp
)
set
(
O
penMP_libgomp_LIBRARY
${
OpenMP_CXX_LIB_NAMES
}
)
set
(
O
MP_LIBRARY
""
)
set
(
O
penMP_libiomp5_LIBRARY
${
OpenMP_CXX_LIB_NAMES
}
)
set
(
O
MP_LINK_FLAG -fopenmp
)
else
()
else
()
find_package
(
OpenMP REQUIRED
)
find_package
(
OpenMP REQUIRED
)
endif
()
endif
()
message
(
"OpenMP_CXX_LIB_NAMES:
${
OpenMP_CXX_LIB_NAMES
}
"
)
#
message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
message
(
"OpenMP_gomp_LIBRARY:
${
OpenMP_gomp_LIBRARY
}
"
)
#
message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}")
message
(
"OpenMP_pthread_LIBRARY:
${
OpenMP_pthread_LIBRARY
}
"
)
#
message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}")
message
(
"OpenMP_CXX_FLAGS:
${
OpenMP_CXX_FLAGS
}
"
)
#
message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}")
link_libraries
(
${
OpenMP_gomp_LIBRARY
}
)
#
link_libraries(${OpenMP_gomp_LIBRARY})
link_libraries
(
${
OpenMP_pthread_LIBRARY
}
)
#
link_libraries(${OpenMP_pthread_LIBRARY})
if
(
NOT CK_NOGPU
)
## HIP
## HIP
find_package
(
HIP REQUIRED
)
find_package
(
HIP REQUIRED
)
# Override HIP version in config.h, if necessary.
# Override HIP version in config.h, if necessary.
...
@@ -79,6 +85,7 @@ rocm_create_package(
...
@@ -79,6 +85,7 @@ rocm_create_package(
MAINTAINER
"MIOpen Kernels Dev Team <dl.MIOpen@amd.com>"
MAINTAINER
"MIOpen Kernels Dev Team <dl.MIOpen@amd.com>"
LDCONFIG
LDCONFIG
)
)
endif
()
## half
## half
set
(
HALF_INCLUDE_DIR
"
${
PROJECT_SOURCE_DIR
}
/external/include/half"
)
set
(
HALF_INCLUDE_DIR
"
${
PROJECT_SOURCE_DIR
}
/external/include/half"
)
...
@@ -94,7 +101,8 @@ elseif(CK_BACKEND STREQUAL "HIP" OR CK_BACKEND STREQUAL "HIPNOGPU")
...
@@ -94,7 +101,8 @@ elseif(CK_BACKEND STREQUAL "HIP" OR CK_BACKEND STREQUAL "HIPNOGPU")
set
(
CK_TIDY_ERRORS ALL
)
set
(
CK_TIDY_ERRORS ALL
)
endif
()
endif
()
if
(
NOT CK_NOGPU
)
# currently tidy and cppcheck seems also need something from rocm environment
include
(
ClangTidy
)
include
(
ClangTidy
)
enable_clang_tidy
(
enable_clang_tidy
(
CHECKS
CHECKS
...
@@ -224,6 +232,11 @@ enable_cppcheck(
...
@@ -224,6 +232,11 @@ enable_cppcheck(
CPPCHECK=1
CPPCHECK=1
__linux__=1
__linux__=1
)
)
else
()
function
(
clang_tidy_check TARGET
)
# dummy empty functoin
endfunction
()
endif
()
set
(
CMAKE_LIBRARY_OUTPUT_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
/lib
)
set
(
CMAKE_LIBRARY_OUTPUT_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
/lib
)
set
(
CMAKE_ARCHIVE_OUTPUT_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
/lib
)
set
(
CMAKE_ARCHIVE_OUTPUT_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
/lib
)
...
...
include/ck/config.hpp
View file @
9a7fa123
#ifndef CK_CONFIG_AMD_HPP
#ifndef CK_CONFIG_AMD_HPP
#define CK_CONFIG_AMD_HPP
#define CK_CONFIG_AMD_HPP
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
#include "ck/options.hpp"
#ifdef CK_NOGPU
#define __host__
#define __device__
#else
#include "hip/hip_runtime.h"
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "hip/hip_fp16.h"
#endif
#endif
...
@@ -26,6 +31,12 @@
...
@@ -26,6 +31,12 @@
#endif
#endif
#endif
#endif
#if defined(__GNUC__) && !defined(__clang__) && !defined(__llvm__)
#if __GNUC__ < 9
#error "If use gcc, need make sure use at least gcc-9"
#endif
#endif
// buffer resource
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
...
...
include/ck/options.hpp.in
View file @
9a7fa123
#pragma once
#pragma once
#cmakedefine01 CK_TIME_KERNEL
#cmakedefine01 CK_TIME_KERNEL
#cmakedefine CK_NOGPU
include/ck/stream_config.hpp
View file @
9a7fa123
#pragma once
#pragma once
#ifndef CK_NOGPU
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp16.h>
#endif
struct
StreamConfig
struct
StreamConfig
{
{
#ifndef CK_NOGPU
hipStream_t
stream_id_
=
nullptr
;
hipStream_t
stream_id_
=
nullptr
;
#endif
bool
time_kernel_
=
false
;
bool
time_kernel_
=
false
;
};
};
include/ck/tensor/static_tensor.hpp
View file @
9a7fa123
...
@@ -79,6 +79,7 @@ struct StaticTensor
...
@@ -79,6 +79,7 @@ struct StaticTensor
T
ignored_element_scalar_
;
T
ignored_element_scalar_
;
};
};
#ifndef CK_NOGPU
// StaticTensor for vector
// StaticTensor for vector
template
<
AddressSpaceEnum
AddressSpace
,
template
<
AddressSpaceEnum
AddressSpace
,
typename
S
,
typename
S
,
...
@@ -244,6 +245,7 @@ struct StaticTensorTupleOfVectorBuffer
...
@@ -244,6 +245,7 @@ struct StaticTensorTupleOfVectorBuffer
const
S
invalid_element_scalar_value_
=
S
{
0
};
const
S
invalid_element_scalar_value_
=
S
{
0
};
S
ignored_element_scalar_
;
S
ignored_element_scalar_
;
};
};
#endif
template
<
AddressSpaceEnum
AddressSpace
,
template
<
AddressSpaceEnum
AddressSpace
,
typename
T
,
typename
T
,
...
...
include/ck/tensor_description/tensor_descriptor.hpp
View file @
9a7fa123
...
@@ -277,7 +277,12 @@ struct TensorCoordinateStep
...
@@ -277,7 +277,12 @@ struct TensorCoordinateStep
MultiIndex
<
NTransform
>
do_transforms_
;
MultiIndex
<
NTransform
>
do_transforms_
;
// HACK: control UpdateLowerIndex()
// HACK: control UpdateLowerIndex()
#if defined(__GNUC__) && !defined(__clang__) && !defined(__llvm__)
// constexpr static data member ‘update_lower_index_hack_’ must have an initializer
static
constexpr
UpdateLowerIndexHack
update_lower_index_hack_
{};
#else
static
constexpr
UpdateLowerIndexHack
update_lower_index_hack_
;
static
constexpr
UpdateLowerIndexHack
update_lower_index_hack_
;
#endif
};
};
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// TODO: How to fix this? It uses an struct instead of lambda because lambda
...
...
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
View file @
9a7fa123
#ifndef CK_THREADWISE_GEMM_AVX2_HPP
#ifndef CK_THREADWISE_GEMM_AVX2_HPP
#define CK_THREADWISE_GEMM_AVX2_HPP
#define CK_THREADWISE_GEMM_AVX2_HPP
#include <assert.h>
#if CK_USE_X86_INLINE_ASM == 0
#if CK_USE_X86_INLINE_ASM == 0
#include <immintrin.h>
#include <immintrin.h>
#endif
#endif
...
@@ -122,22 +123,22 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -122,22 +123,22 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".macro vbroadcast_a%= i_k, i_m, ymm
\n
"
// A in rax(r8, r9), lda in rcx
".macro vbroadcast_a%= i_k, i_m, ymm
\n
"
// A in rax(r8, r9), lda in rcx
".if m_ABytes == 4
\n
"
".if m_ABytes == 4
\n
"
".if m_TransA == 0
\n
"
".if m_TransA == 0
\n
"
"vbroadcastss_%= %%rax, 0, 0, (
\\
i_m +
\\
i_k * m_Mr) * m_ABytes,
\\
ymm
\n
"
"vbroadcastss_%= %%rax, 0, 0,
(
(
\\
i_m +
\\
i_k * m_Mr) * m_ABytes
)
,
\\
ymm
\n
"
".else
\n
"
".else
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1) || (
\\
i_m == 2)
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1) || (
\\
i_m == 2)
\n
"
"vbroadcastss_%= %%rax, %%rcx,
\\
i_m,
\\
i_k * m_ABytes,
\\
ymm
\n
"
"vbroadcastss_%= %%rax, %%rcx,
\\
i_m,
(
\\
i_k * m_ABytes
)
,
\\
ymm
\n
"
".else
\n
"
".else
\n
"
"vbroadcastss_%= %%r8, %%rcx,
\\
i_m-3,
\\
i_k * m_ABytes,
\\
ymm
\n
"
"vbroadcastss_%= %%r8, %%rcx,
\\
i_m-3,
(
\\
i_k * m_ABytes
)
,
\\
ymm
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".else
\n
"
".else
\n
"
".if m_TransA == 0
\n
"
".if m_TransA == 0
\n
"
"vpbroadcastw_%= %%rax, 0, 0, (
\\
i_m +
\\
i_k * m_Mr) * m_ABytes, %%xmm15
\n
"
"vpbroadcastw_%= %%rax, 0, 0,
(
(
\\
i_m +
\\
i_k * m_Mr) * m_ABytes
)
, %%xmm15
\n
"
".else
\n
"
".else
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1) || (
\\
i_m == 2)
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1) || (
\\
i_m == 2)
\n
"
"vpbroadcastw_%= %%rax, %%rcx,
\\
i_m,
\\
i_k * m_ABytes, %%xmm15
\n
"
"vpbroadcastw_%= %%rax, %%rcx,
\\
i_m,
(
\\
i_k * m_ABytes
)
, %%xmm15
\n
"
".else
\n
"
".else
\n
"
"vpbroadcastw_%= %%r8, %%rcx,
\\
i_m-3,
\\
i_k * m_ABytes, %%xmm15
\n
"
"vpbroadcastw_%= %%r8, %%rcx,
\\
i_m-3,
(
\\
i_k * m_ABytes
)
, %%xmm15
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
"vcvtph2ps %%xmm15,
\\
ymm
\n
"
"vcvtph2ps %%xmm15,
\\
ymm
\n
"
...
@@ -147,15 +148,15 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -147,15 +148,15 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".macro vload_b%= i_k, i_n, ymm
\n
"
// B in rbx, lda in rdx, i_n should be 0, 1
".macro vload_b%= i_k, i_n, ymm
\n
"
// B in rbx, lda in rdx, i_n should be 0, 1
".if m_BBytes == 4
\n
"
".if m_BBytes == 4
\n
"
".if m_TransB == 0
\n
"
".if m_TransB == 0
\n
"
"vmovups_%= %%rbx, %%rdx,
\\
i_n,
\\
i_k*m_BBytes*8,
\\
ymm
\n
"
"vmovups_%= %%rbx, %%rdx,
\\
i_n,
(
\\
i_k*m_BBytes*8
)
,
\\
ymm
\n
"
".else
\n
"
".else
\n
"
"vmovups_%= %%rbx, 0, 0, (
\\
i_k*m_Nr +
\\
i_n*8)*m_BBytes,
\\
ymm
\n
"
"vmovups_%= %%rbx, 0, 0,
(
(
\\
i_k*m_Nr +
\\
i_n*8)*m_BBytes
)
,
\\
ymm
\n
"
".endif
\n
"
".endif
\n
"
".else
\n
"
".else
\n
"
".if m_TransB == 0
\n
"
".if m_TransB == 0
\n
"
"vcvtph2ps_%= %%rbx, %%rdx,
\\
i_n,
\\
i_k*m_BBytes*8,
\\
ymm
\n
"
"vcvtph2ps_%= %%rbx, %%rdx,
\\
i_n,
(
\\
i_k*m_BBytes*8
)
,
\\
ymm
\n
"
".else
\n
"
".else
\n
"
"vcvtph2ps_%= %%rbx, 0, 0, (
\\
i_k*m_Nr +
\\
i_n*8)*m_BBytes,
\\
ymm
\n
"
"vcvtph2ps_%= %%rbx, 0, 0,
(
(
\\
i_k*m_Nr +
\\
i_n*8)*m_BBytes
)
,
\\
ymm
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endm
\n
"
".endm
\n
"
...
@@ -682,22 +683,22 @@ struct ThreadwiseGemmAvx2_MxN_4x24
...
@@ -682,22 +683,22 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".macro vbroadcast_a%= i_k, i_m, ymm
\n
"
// A in rax(r8), lda in rcx
".macro vbroadcast_a%= i_k, i_m, ymm
\n
"
// A in rax(r8), lda in rcx
".if m_ABytes == 4
\n
"
".if m_ABytes == 4
\n
"
".if m_TransA == 0
\n
"
".if m_TransA == 0
\n
"
"vbroadcastss_%= %%rax, 0, 0, (
\\
i_m +
\\
i_k * m_Mr) * m_ABytes,
\\
ymm
\n
"
"vbroadcastss_%= %%rax, 0, 0,
(
(
\\
i_m +
\\
i_k * m_Mr) * m_ABytes
)
,
\\
ymm
\n
"
".else
\n
"
".else
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1)
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1)
\n
"
"vbroadcastss_%= %%rax, %%rcx,
\\
i_m,
\\
i_k * m_ABytes,
\\
ymm
\n
"
"vbroadcastss_%= %%rax, %%rcx,
\\
i_m,
(
\\
i_k * m_ABytes
)
,
\\
ymm
\n
"
".else
\n
"
".else
\n
"
"vbroadcastss_%= %%r8, %%rcx,
\\
i_m-2,
\\
i_k * m_ABytes,
\\
ymm
\n
"
"vbroadcastss_%= %%r8, %%rcx,
\\
i_m-2,
(
\\
i_k * m_ABytes
)
,
\\
ymm
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".else
\n
"
".else
\n
"
".if m_TransA == 0
\n
"
".if m_TransA == 0
\n
"
"vpbroadcastw_%= %%rax, 0, 0, (
\\
i_m +
\\
i_k * m_Mr) * m_ABytes, %%xmm15
\n
"
"vpbroadcastw_%= %%rax, 0, 0,
(
(
\\
i_m +
\\
i_k * m_Mr) * m_ABytes
)
, %%xmm15
\n
"
".else
\n
"
".else
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1)
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1)
\n
"
"vpbroadcastw_%= %%rax, %%rcx,
\\
i_m,
\\
i_k * m_ABytes, %%xmm15
\n
"
"vpbroadcastw_%= %%rax, %%rcx,
\\
i_m,
(
\\
i_k * m_ABytes
)
, %%xmm15
\n
"
".else
\n
"
".else
\n
"
"vpbroadcastw_%= %%r8, %%rcx,
\\
i_m-2,
\\
i_k * m_ABytes, %%xmm15
\n
"
"vpbroadcastw_%= %%r8, %%rcx,
\\
i_m-2,
(
\\
i_k * m_ABytes
)
, %%xmm15
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
"vcvtph2ps %%xmm15,
\\
ymm
\n
"
"vcvtph2ps %%xmm15,
\\
ymm
\n
"
...
@@ -707,15 +708,15 @@ struct ThreadwiseGemmAvx2_MxN_4x24
...
@@ -707,15 +708,15 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".macro vload_b%= i_k, i_n, ymm
\n
"
// B in rbx, lda in rdx, i_n should be 0, 1, 2
".macro vload_b%= i_k, i_n, ymm
\n
"
// B in rbx, lda in rdx, i_n should be 0, 1, 2
".if m_BBytes == 4
\n
"
".if m_BBytes == 4
\n
"
".if m_TransB == 0
\n
"
".if m_TransB == 0
\n
"
"vmovups_%= %%rbx, %%rdx,
\\
i_n,
\\
i_k*m_BBytes*8,
\\
ymm
\n
"
"vmovups_%= %%rbx, %%rdx,
\\
i_n,
(
\\
i_k*m_BBytes*8
)
,
\\
ymm
\n
"
".else
\n
"
".else
\n
"
"vmovups_%= %%rbx, 0, 0, (
\\
i_k*m_Nr +
\\
i_n*8)*m_BBytes,
\\
ymm
\n
"
"vmovups_%= %%rbx, 0, 0,
(
(
\\
i_k*m_Nr +
\\
i_n*8)*m_BBytes
)
,
\\
ymm
\n
"
".endif
\n
"
".endif
\n
"
".else
\n
"
".else
\n
"
".if m_TransB == 0
\n
"
".if m_TransB == 0
\n
"
"vcvtph2ps_%= %%rbx, %%rdx,
\\
i_n,
\\
i_k*m_BBytes*8,
\\
ymm
\n
"
"vcvtph2ps_%= %%rbx, %%rdx,
\\
i_n,
(
\\
i_k*m_BBytes*8
)
,
\\
ymm
\n
"
".else
\n
"
".else
\n
"
"vcvtph2ps_%= %%rbx, 0, 0, (
\\
i_k*m_Nr +
\\
i_n*8)*m_BBytes,
\\
ymm
\n
"
"vcvtph2ps_%= %%rbx, 0, 0,
(
(
\\
i_k*m_Nr +
\\
i_n*8)*m_BBytes
)
,
\\
ymm
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endm
\n
"
".endm
\n
"
...
...
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
View file @
9a7fa123
...
@@ -46,7 +46,13 @@ void memcpy32_avx2(void* dst, const void* src, const ck::index_t n, const Elemen
...
@@ -46,7 +46,13 @@ void memcpy32_avx2(void* dst, const void* src, const ck::index_t n, const Elemen
}
}
if
(
i_n
&
2
)
if
(
i_n
&
2
)
{
{
#if defined(__GNUC__) && !defined(__clang__) && !defined(__llvm__)
__m128i
s
=
_mm_loadu_si64
(
p_src
);
__m128
v
=
element_op
.
Apply
(
*
reinterpret_cast
<
__m128
*>
(
&
s
));
_mm_storeu_si64
(
p_dst
,
*
reinterpret_cast
<
__m128i
*>
(
&
v
));
#else
_mm_storeu_si64
(
p_dst
,
element_op
.
Apply
(
_mm_loadu_si64
(
p_src
)));
_mm_storeu_si64
(
p_dst
,
element_op
.
Apply
(
_mm_loadu_si64
(
p_src
)));
#endif
p_dst
+=
2
;
p_dst
+=
2
;
p_src
+=
2
;
p_src
+=
2
;
}
}
...
@@ -82,7 +88,11 @@ inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n)
...
@@ -82,7 +88,11 @@ inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n)
}
}
if
(
i_n
&
2
)
if
(
i_n
&
2
)
{
{
#if defined(__GNUC__) && !defined(__clang__) && !defined(__llvm__)
_mm_storeu_si64
(
p_dst
,
*
reinterpret_cast
<
__m128i
*>
(
&
xmm
));
#else
_mm_storeu_si64
(
p_dst
,
xmm
);
_mm_storeu_si64
(
p_dst
,
xmm
);
#endif
p_dst
+=
2
;
p_dst
+=
2
;
}
}
if
(
i_n
&
1
)
if
(
i_n
&
1
)
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
9a7fa123
#pragma once
#pragma once
#include "data_type.hpp"
#include "data_type.hpp"
#ifndef CK_NOGPU
namespace
ck
{
namespace
ck
{
template
<
typename
T
>
template
<
typename
T
>
...
@@ -1047,3 +1048,5 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
...
@@ -1047,3 +1048,5 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
}
}
}
// namespace ck
}
// namespace ck
#endif
include/ck/utility/amd_inline_asm.hpp
View file @
9a7fa123
...
@@ -4,6 +4,8 @@
...
@@ -4,6 +4,8 @@
#include "data_type.hpp"
#include "data_type.hpp"
#include "c_style_pointer_cast.hpp"
#include "c_style_pointer_cast.hpp"
#ifndef CK_NOGPU
// TODO: deprecate all amd_assembly_outer_product_xxx
// TODO: deprecate all amd_assembly_outer_product_xxx
namespace
ck
{
namespace
ck
{
...
@@ -354,3 +356,4 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
...
@@ -354,3 +356,4 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
}
// namespace ck
}
// namespace ck
#endif
#endif
#endif
include/ck/utility/amd_llvm_intrinsic.hpp
View file @
9a7fa123
#ifndef CK_AMD_LLVM_INTRINSIC_HPP
#ifndef CK_AMD_LLVM_INTRINSIC_HPP
#define CK_AMD_LLVM_INTRINSIC_HPP
#define CK_AMD_LLVM_INTRINSIC_HPP
#ifndef CK_NOGPU
#include "data_type.hpp"
#include "data_type.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -9,3 +10,4 @@ __device__ int32_t llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.r
...
@@ -9,3 +10,4 @@ __device__ int32_t llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.r
}
// namespace ck
}
// namespace ck
#endif
#endif
#endif
include/ck/utility/amd_xdlops.hpp
View file @
9a7fa123
#ifndef CK_AMD_XDLOPS_HPP
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
#ifndef CK_NOGPU
#include "data_type.hpp"
#include "data_type.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -296,3 +297,4 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
...
@@ -296,3 +297,4 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
}
// namespace ck
}
// namespace ck
#endif
#endif
#endif
include/ck/utility/data_type.hpp
View file @
9a7fa123
#pragma once
#pragma once
#include "statically_indexed_array.hpp"
#include "statically_indexed_array.hpp"
#ifdef CK_NOGPU
#include "half.hpp"
#endif
namespace
ck
{
namespace
ck
{
using
bhalf_t
=
ushort
;
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
#ifdef CK_NOGPU
using
half_t
=
half_float
::
half
;
#else
using
half_t
=
_Float16
;
#endif
// vector_type
// vector_type
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
...
@@ -14,8 +21,10 @@ struct vector_type;
...
@@ -14,8 +21,10 @@ struct vector_type;
// intentionally have only declaration but no definition to cause compilation failure when trying to
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
// vectors"
// vectors"
#ifdef __clang__
template
<
typename
T
,
index_t
V
,
index_t
N
>
template
<
typename
T
,
index_t
V
,
index_t
N
>
struct
vector_type
<
T
__attribute__
((
ext_vector_type
(
V
))),
N
>
;
struct
vector_type
<
T
__attribute__
((
ext_vector_type
(
V
))),
N
>
;
#endif
// Caution: DO NOT REMOVE
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// intentionally have only declaration but no definition to cause compilation failure when trying to
...
@@ -32,11 +41,13 @@ struct vector_type_maker
...
@@ -32,11 +41,13 @@ struct vector_type_maker
using
type
=
vector_type
<
T
,
N
>
;
using
type
=
vector_type
<
T
,
N
>
;
};
};
#ifdef __clang__
template
<
typename
T
,
index_t
N0
,
index_t
N1
>
template
<
typename
T
,
index_t
N0
,
index_t
N1
>
struct
vector_type_maker
<
T
__attribute__
((
ext_vector_type
(
N1
))),
N0
>
struct
vector_type_maker
<
T
__attribute__
((
ext_vector_type
(
N1
))),
N0
>
{
{
using
type
=
vector_type
<
T
,
N0
*
N1
>
;
using
type
=
vector_type
<
T
,
N0
*
N1
>
;
};
};
#endif
template
<
typename
T
,
index_t
N0
,
index_t
N1
>
template
<
typename
T
,
index_t
N0
,
index_t
N1
>
struct
vector_type_maker
<
vector_type
<
T
,
N1
>
,
N0
>
struct
vector_type_maker
<
vector_type
<
T
,
N1
>
,
N0
>
...
@@ -69,12 +80,14 @@ template <typename X, typename Y>
...
@@ -69,12 +80,14 @@ template <typename X, typename Y>
using
has_same_scalar_type
=
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>>::
type
,
using
has_same_scalar_type
=
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>>::
type
,
typename
scalar_type
<
remove_cvref_t
<
Y
>>::
type
>
;
typename
scalar_type
<
remove_cvref_t
<
Y
>>::
type
>
;
#ifdef __clang__
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
struct
scalar_type
<
T
__attribute__
((
ext_vector_type
(
N
)))
>
struct
scalar_type
<
T
__attribute__
((
ext_vector_type
(
N
)))
>
{
{
using
type
=
T
;
using
type
=
T
;
static
constexpr
index_t
vector_size
=
N
;
static
constexpr
index_t
vector_size
=
N
;
};
};
#endif
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
struct
scalar_type
<
vector_type
<
T
,
N
>>
struct
scalar_type
<
vector_type
<
T
,
N
>>
...
...
include/ck/utility/data_type_cpu.hpp
View file @
9a7fa123
#pragma once
#pragma once
#include <immintrin.h>
#include <immintrin.h>
#include "half.hpp"
namespace
ck
{
namespace
cpu
{
namespace
ck
{
// vector_type
namespace
cpu
{
template
<
typename
T
,
index_t
N
>
struct
vector_type
;
// vector_type
template
<
typename
T
,
index_t
N
>
// Caution: DO NOT REMOVE
struct
vector_type
;
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
// Caution: DO NOT REMOVE
// vectors"
// intentionally have only declaration but no definition to cause compilation failure when trying to
template
<
typename
T
,
index_t
V
,
index_t
N
>
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
struct
vector_type
<
T
__attribute__
((
ext_vector_type
(
V
))),
N
>
;
// vectors"
#ifdef __clang__
// Caution: DO NOT REMOVE
template
<
typename
T
,
index_t
V
,
index_t
N
>
// intentionally have only declaration but no definition to cause compilation failure when trying to
struct
vector_type
<
T
__attribute__
((
ext_vector_type
(
V
))),
N
>
;
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
#endif
// vectors"
template
<
typename
T
,
index_t
V
,
index_t
N
>
// Caution: DO NOT REMOVE
struct
vector_type
<
vector_type
<
T
,
V
>
,
N
>
;
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
// vector_type_maker
// vectors"
// This is the right way to handle "vector of vectors": making a bigger vector instead
template
<
typename
T
,
index_t
V
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
struct
vector_type
<
vector_type
<
T
,
V
>
,
N
>
;
struct
vector_type_maker
{
// vector_type_maker
using
type
=
vector_type
<
T
,
N
>
;
// This is the right way to handle "vector of vectors": making a bigger vector instead
};
template
<
typename
T
,
index_t
N
>
struct
vector_type_maker
template
<
typename
T
,
index_t
N
>
{
using
vector_type_maker_t
=
typename
vector_type_maker
<
T
,
N
>::
type
;
using
type
=
vector_type
<
T
,
N
>
;
};
template
<
typename
T
,
index_t
N
>
constexpr
auto
make_vector_type
(
Number
<
N
>
)
template
<
typename
T
,
index_t
N
>
{
using
vector_type_maker_t
=
typename
vector_type_maker
<
T
,
N
>::
type
;
return
typename
vector_type_maker
<
T
,
N
>::
type
{};
}
template
<
typename
T
,
index_t
N
>
constexpr
auto
make_vector_type
(
Number
<
N
>
)
template
<
>
{
struct
vector_type
<
float
,
1
>
return
typename
vector_type_maker
<
T
,
N
>::
type
{};
{
}
using
d1_t
=
float
;
// SSE
template
<
>
using
type
=
float
;
struct
vector_type
<
float
,
1
>
{
type
data_
;
using
d1_t
=
float
;
// SSE
vector_type
()
:
data_
{
0
}
{}
using
type
=
float
;
// vector_type(float x) : data_{x} {}
type
data_
;
vector_type
(
type
v
)
:
data_
{
v
}
{}
vector_type
()
:
data_
{
0
}
{}
vector_type
(
const
float
*
mem
)
:
data_
{
*
mem
}
{}
// vector_type(float x) : data_{x} {}
template
<
typename
X
>
vector_type
(
type
v
)
:
data_
{
v
}
{}
constexpr
const
auto
&
AsType
()
const
{
vector_type
(
const
float
*
mem
)
:
data_
{
*
mem
}
{}
static_assert
(
std
::
is_same
<
X
,
type
>::
value
,
"wrong!"
);
template
<
typename
X
>
return
data_
;
constexpr
const
auto
&
AsType
()
const
}
{
static_assert
(
std
::
is_same
<
X
,
type
>::
value
,
"wrong!"
);
template
<
typename
X
>
constexpr
auto
&
AsType
()
return
data_
;
{
}
static_assert
(
std
::
is_same
<
X
,
type
>::
value
,
"wrong!"
);
template
<
typename
X
>
return
data_
;
constexpr
auto
&
AsType
()
}
{
static_assert
(
std
::
is_same
<
X
,
type
>::
value
,
"wrong!"
);
constexpr
void
Load
(
const
float
*
mem
)
{
data_
=
*
mem
;
}
return
data_
;
constexpr
void
Store
(
float
*
mem
)
const
{
*
mem
=
data_
;
}
}
};
constexpr
void
Load
(
const
float
*
mem
)
{
data_
=
*
mem
;
}
template
<
>
struct
vector_type
<
float
,
4
>
constexpr
void
Store
(
float
*
mem
)
const
{
*
mem
=
data_
;
}
{
};
using
d1_t
=
float
;
// SSE
template
<
>
using
type
=
__m128
;
struct
vector_type
<
float
,
4
>
{
type
data_
;
using
d1_t
=
float
;
// SSE
vector_type
()
:
data_
{
_mm_setzero_ps
()}
{}
using
type
=
__m128
;
vector_type
(
float
x
)
:
data_
{
_mm_set1_ps
(
x
)}
{}
type
data_
;
vector_type
(
type
v
)
:
data_
{
v
}
{}
vector_type
()
:
data_
{
_mm_setzero_ps
()}
{}
vector_type
(
const
float
*
mem
)
:
data_
{
_mm_loadu_ps
(
mem
)}
{}
vector_type
(
float
x
)
:
data_
{
_mm_set1_ps
(
x
)}
{}
template
<
typename
X
>
vector_type
(
type
v
)
:
data_
{
v
}
{}
constexpr
const
auto
&
AsType
()
const
{
vector_type
(
const
float
*
mem
)
:
data_
{
_mm_loadu_ps
(
mem
)}
{}
static_assert
(
std
::
is_same
<
X
,
type
>::
value
,
"wrong!"
);
template
<
typename
X
>
return
data_
;
constexpr
const
auto
&
AsType
()
const
}
{
static_assert
(
std
::
is_same
<
X
,
type
>::
value
,
"wrong!"
);
template
<
typename
X
>
constexpr
auto
&
AsType
()
return
data_
;
{
}
static_assert
(
std
::
is_same
<
X
,
type
>::
value
,
"wrong!"
);
template
<
typename
X
>
return
data_
;
constexpr
auto
&
AsType
()
}
{
static_assert
(
std
::
is_same
<
X
,
type
>::
value
,
"wrong!"
);
constexpr
void
Load
(
const
float
*
mem
)
{
data_
=
_mm_loadu_ps
(
mem
);
}
return
data_
;
constexpr
void
Store
(
float
*
mem
)
const
{
_mm_storeu_ps
(
mem
,
data_
);
}
}
};
void
Load
(
const
float
*
mem
)
{
data_
=
_mm_loadu_ps
(
mem
);
}
template
<
>
struct
vector_type
<
float
,
8
>
void
Store
(
float
*
mem
)
const
{
_mm_storeu_ps
(
mem
,
data_
);
}
{
};
using
d1_t
=
float
;
// SSE
template
<
>
using
type
=
__m256
;
struct
vector_type
<
float
,
8
>
{
type
data_
;
using
d1_t
=
float
;
// SSE
vector_type
()
:
data_
{
_mm256_setzero_ps
()}
{}
using
type
=
__m256
;
vector_type
(
float
x
)
:
data_
{
_mm256_set1_ps
(
x
)}
{}
type
data_
;
vector_type
(
type
v
)
:
data_
{
v
}
{}
vector_type
()
:
data_
{
_mm256_setzero_ps
()}
{}
vector_type
(
const
float
*
mem
)
:
data_
{
_mm256_loadu_ps
(
mem
)}
{}
vector_type
(
float
x
)
:
data_
{
_mm256_set1_ps
(
x
)}
{}
template
<
typename
X
>
vector_type
(
type
v
)
:
data_
{
v
}
{}
constexpr
const
auto
&
AsType
()
const
{
vector_type
(
const
float
*
mem
)
:
data_
{
_mm256_loadu_ps
(
mem
)}
{}
static_assert
(
std
::
is_same
<
X
,
type
>::
value
,
"wrong!"
);
template
<
typename
X
>
return
data_
;
constexpr
const
auto
&
AsType
()
const
}
{
static_assert
(
std
::
is_same
<
X
,
type
>::
value
,
"wrong!"
);
template
<
typename
X
>
constexpr
auto
&
AsType
()
return
data_
;
{
}
static_assert
(
std
::
is_same
<
X
,
type
>::
value
,
"wrong!"
);
template
<
typename
X
>
return
data_
;
constexpr
auto
&
AsType
()
}
{
static_assert
(
std
::
is_same
<
X
,
type
>::
value
,
"wrong!"
);
constexpr
void
Load
(
const
float
*
mem
)
{
data_
=
_mm256_loadu_ps
(
mem
);
}
return
data_
;
constexpr
void
Store
(
float
*
mem
)
const
{
_mm256_storeu_ps
(
mem
,
data_
);
}
}
};
void
Load
(
const
float
*
mem
)
{
data_
=
_mm256_loadu_ps
(
mem
);
}
template
<
typename
T
>
struct
to_vector_type
void
Store
(
float
*
mem
)
const
{
_mm256_storeu_ps
(
mem
,
data_
);
}
{
};
using
type
=
T
;
};
template
<
typename
T
>
struct
to_vector_type
template
<
>
{
struct
to_vector_type
<
__m128
>
using
type
=
T
;
{
};
using
type
=
vector_type
<
float
,
4
>
;
};
template
<
>
struct
to_vector_type
<
__m128
>
template
<
>
{
struct
to_vector_type
<
__m256
>
using
type
=
vector_type
<
float
,
4
>
;
{
};
using
type
=
vector_type
<
float
,
8
>
;
};
template
<
>
struct
to_vector_type
<
__m256
>
template
<
typename
Tv
,
typename
Tp
>
{
inline
void
load_vector
(
Tv
&
v
,
const
Tp
*
mem
)
using
type
=
vector_type
<
float
,
8
>
;
{
};
v
=
*
reinterpret_cast
<
const
Tv
*>
(
mem
);
}
template
<
typename
Tv
,
typename
Tp
>
inline
void
load_vector
(
Tv
&
v
,
const
Tp
*
mem
)
template
<
>
{
inline
void
load_vector
(
__m128
&
v
,
const
float
*
mem
)
v
=
*
reinterpret_cast
<
const
Tv
*>
(
mem
);
{
}
v
=
_mm_loadu_ps
(
mem
);
}
template
<
>
inline
void
load_vector
(
__m128
&
v
,
const
float
*
mem
)
template
<
>
{
inline
void
load_vector
(
__m256
&
v
,
const
float
*
mem
)
v
=
_mm_loadu_ps
(
mem
);
{
}
v
=
_mm256_loadu_ps
(
mem
);
}
template
<
>
inline
void
load_vector
(
__m256
&
v
,
const
float
*
mem
)
template
<
typename
Tv
,
typename
Tp
>
{
inline
void
store_vector
(
const
Tv
&
v
,
Tp
*
mem
)
v
=
_mm256_loadu_ps
(
mem
);
{
}
*
reinterpret_cast
<
Tv
*>
(
mem
)
=
v
;
}
template
<
typename
Tv
,
typename
Tp
>
inline
void
store_vector
(
const
Tv
&
v
,
Tp
*
mem
)
template
<
>
{
inline
void
store_vector
(
const
__m128
&
v
,
float
*
mem
)
*
reinterpret_cast
<
Tv
*>
(
mem
)
=
v
;
{
}
_mm_storeu_ps
(
mem
,
v
);
}
template
<
>
inline
void
store_vector
(
const
__m128
&
v
,
float
*
mem
)
template
<
>
{
inline
void
store_vector
(
const
__m256
&
v
,
float
*
mem
)
_mm_storeu_ps
(
mem
,
v
);
{
}
_mm256_storeu_ps
(
mem
,
v
);
}
template
<
>
inline
void
store_vector
(
const
__m256
&
v
,
float
*
mem
)
template
<
typename
Tv
,
typename
Tx
>
{
inline
void
set_vector
(
Tv
&
v
,
const
Tx
x
)
_mm256_storeu_ps
(
mem
,
v
);
{
}
v
=
static_cast
<
const
Tv
>
(
x
);
}
template
<
typename
Tv
,
typename
Tx
>
inline
void
set_vector
(
Tv
&
v
,
const
Tx
x
)
template
<
>
{
inline
void
set_vector
(
__m128
&
v
,
const
float
x
)
v
=
static_cast
<
const
Tv
>
(
x
);
{
}
v
=
_mm_set1_ps
(
x
);
}
template
<
>
inline
void
set_vector
(
__m128
&
v
,
const
float
x
)
template
<
>
{
inline
void
set_vector
(
__m256
&
v
,
const
float
x
)
v
=
_mm_set1_ps
(
x
);
{
}
v
=
_mm256_set1_ps
(
x
);
}
template
<
>
inline
void
set_vector
(
__m256
&
v
,
const
float
x
)
template
<
typename
Tv
>
{
inline
void
clear_vector
(
Tv
&
v
)
v
=
_mm256_set1_ps
(
x
);
{
}
v
=
static_cast
<
Tv
>
(
0
);
}
template
<
typename
Tv
>
inline
void
clear_vector
(
Tv
&
v
)
template
<
>
{
inline
void
clear_vector
(
__m128
&
v
)
v
=
static_cast
<
Tv
>
(
0
);
{
}
v
=
_mm_setzero_ps
();
}
template
<
>
inline
void
clear_vector
(
__m128
&
v
)
template
<
>
{
inline
void
clear_vector
(
__m256
&
v
)
v
=
_mm_setzero_ps
();
{
}
v
=
_mm256_setzero_ps
();
}
template
<
>
inline
void
clear_vector
(
__m256
&
v
)
using
float4_t
=
typename
vector_type
<
float
,
4
>::
type
;
{
using
float8_t
=
typename
vector_type
<
float
,
8
>::
type
;
v
=
_mm256_setzero_ps
();
}
// scalar_type
template
<
typename
TV
>
using
float4_t
=
typename
vector_type
<
float
,
4
>::
type
;
struct
scalar_type
;
using
float8_t
=
typename
vector_type
<
float
,
8
>::
type
;
// is_scalar_type
// scalar_type
template
<
typename
TV
>
template
<
typename
TV
>
struct
is_scalar_type
struct
scalar_type
;
{
static
constexpr
bool
value
=
(
scalar_type
<
remove_cvref_t
<
TV
>>::
vector_size
==
1
);
// is_scalar_type
};
template
<
typename
TV
>
struct
is_scalar_type
// has_same_scalar_type
{
template
<
typename
X
,
typename
Y
>
static
constexpr
bool
value
=
(
scalar_type
<
remove_cvref_t
<
TV
>>::
vector_size
==
1
);
using
has_same_scalar_type
=
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>>::
type
,
};
typename
scalar_type
<
remove_cvref_t
<
Y
>>::
type
>
;
// has_same_scalar_type
template
<
typename
T
,
index_t
N
>
template
<
typename
X
,
typename
Y
>
struct
scalar_type
<
vector_type
<
T
,
N
>>
using
has_same_scalar_type
=
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>>::
type
,
{
typename
scalar_type
<
remove_cvref_t
<
Y
>>::
type
>
;
using
type
=
T
;
static
constexpr
index_t
vector_size
=
N
;
template
<
typename
T
,
index_t
N
>
};
struct
scalar_type
<
vector_type
<
T
,
N
>>
{
template
<
>
using
type
=
T
;
struct
scalar_type
<
float4_t
>
static
constexpr
index_t
vector_size
=
N
;
{
};
using
type
=
float
;
static
constexpr
index_t
vector_size
=
4
;
template
<
>
};
struct
scalar_type
<
float4_t
>
{
template
<
>
using
type
=
float
;
struct
scalar_type
<
float8_t
>
static
constexpr
index_t
vector_size
=
4
;
{
};
using
type
=
float
;
static
constexpr
index_t
vector_size
=
8
;
template
<
>
};
struct
scalar_type
<
float8_t
>
{
//
using
type
=
float
;
template
<
>
static
constexpr
index_t
vector_size
=
8
;
struct
scalar_type
<
float
>
};
{
using
type
=
float
;
//
static
constexpr
index_t
vector_size
=
1
;
template
<
>
};
struct
scalar_type
<
float
>
{
}
// namespace cpu
using
type
=
float
;
}
// namespace ck
static
constexpr
index_t
vector_size
=
1
;
};
}
// namespace cpu
}
// namespace ck
include/ck/utility/debug.hpp
View file @
9a7fa123
#ifndef UTILITY_DEBUG_HPP
#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
#ifndef CK_NOGPU
namespace
ck
{
namespace
ck
{
namespace
debug
{
namespace
debug
{
...
@@ -74,4 +74,5 @@ __device__ void print_shared(T const* p_shared, index_t num_elements)
...
@@ -74,4 +74,5 @@ __device__ void print_shared(T const* p_shared, index_t num_elements)
}
// namespace debug
}
// namespace debug
}
// namespace ck
}
// namespace ck
#endif
#endif // UTILITY_DEBUG_HPP
#endif // UTILITY_DEBUG_HPP
include/ck/utility/dynamic_buffer.hpp
View file @
9a7fa123
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "amd_buffer_addressing.hpp"
#include "amd_buffer_addressing.hpp"
#include "generic_memory_space_atomic_add.hpp"
#include "generic_memory_space_atomic_add.hpp"
#ifndef CK_NOGPU
namespace
ck
{
namespace
ck
{
// T may be scalar or vector
// T may be scalar or vector
...
@@ -351,3 +352,4 @@ make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element
...
@@ -351,3 +352,4 @@ make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element
}
}
}
// namespace ck
}
// namespace ck
#endif
include/ck/utility/generic_memory_space_atomic_add.hpp
View file @
9a7fa123
#pragma once
#pragma once
#ifndef CK_NOGPU
#include "data_type.hpp"
#include "data_type.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -42,3 +43,4 @@ __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
...
@@ -42,3 +43,4 @@ __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
}
}
}
// namespace ck
}
// namespace ck
#endif
include/ck/utility/get_id.hpp
View file @
9a7fa123
#pragma once
#pragma once
#include "config.hpp"
#include "config.hpp"
#ifndef CK_NOGPU
namespace
ck
{
namespace
ck
{
__host__
__device__
constexpr
index_t
get_warp_size
()
__host__
__device__
constexpr
index_t
get_warp_size
()
...
@@ -18,3 +19,4 @@ __device__ index_t get_block_1d_id() { return blockIdx.x; }
...
@@ -18,3 +19,4 @@ __device__ index_t get_block_1d_id() { return blockIdx.x; }
__device__
index_t
get_grid_size
()
{
return
gridDim
.
x
;
}
__device__
index_t
get_grid_size
()
{
return
gridDim
.
x
;
}
}
// namespace ck
}
// namespace ck
#endif
\ No newline at end of file
include/ck/utility/inner_product.hpp
View file @
9a7fa123
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#define CK_INNER_PRODUCT_HPP
#define CK_INNER_PRODUCT_HPP
#include "data_type.hpp"
#include "data_type.hpp"
#ifndef CK_NOGPU
namespace
ck
{
namespace
ck
{
template
<
typename
TA
,
typename
TB
,
typename
TC
>
template
<
typename
TA
,
typename
TB
,
typename
TC
>
...
@@ -203,3 +203,4 @@ inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t
...
@@ -203,3 +203,4 @@ inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t
}
// namespace ck
}
// namespace ck
#endif
#endif
#endif
\ No newline at end of file
include/ck/utility/magic_division.hpp
View file @
9a7fa123
...
@@ -118,7 +118,7 @@ struct MagicDivision
...
@@ -118,7 +118,7 @@ struct MagicDivision
{
{
return
CalculateMagicShift
(
integral_constant
<
uint32_t
,
Divisor
>
{});
return
CalculateMagicShift
(
integral_constant
<
uint32_t
,
Divisor
>
{});
}
}
#ifndef CK_NOGPU
// magic division for uint32_t
// magic division for uint32_t
__device__
static
constexpr
uint32_t
__device__
static
constexpr
uint32_t
DoMagicDivision
(
uint32_t
dividend
,
uint32_t
multiplier
,
uint32_t
shift
)
DoMagicDivision
(
uint32_t
dividend
,
uint32_t
multiplier
,
uint32_t
shift
)
...
@@ -126,7 +126,7 @@ struct MagicDivision
...
@@ -126,7 +126,7 @@ struct MagicDivision
uint32_t
tmp
=
__umulhi
(
dividend
,
multiplier
);
uint32_t
tmp
=
__umulhi
(
dividend
,
multiplier
);
return
(
tmp
+
dividend
)
>>
shift
;
return
(
tmp
+
dividend
)
>>
shift
;
}
}
#endif
__host__
static
constexpr
uint32_t
__host__
static
constexpr
uint32_t
DoMagicDivision
(
uint32_t
dividend
,
uint32_t
multiplier
,
uint32_t
shift
)
DoMagicDivision
(
uint32_t
dividend
,
uint32_t
multiplier
,
uint32_t
shift
)
{
{
...
@@ -138,6 +138,7 @@ struct MagicDivision
...
@@ -138,6 +138,7 @@ struct MagicDivision
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
// non-negative for result to be correct
// non-negative for result to be correct
// TODO: figure out how to do magic number divison for int32_t as dividended
// TODO: figure out how to do magic number divison for int32_t as dividended
#ifndef CK_NOGPU
__device__
static
constexpr
int32_t
__device__
static
constexpr
int32_t
DoMagicDivision
(
int32_t
dividend_i32
,
uint32_t
multiplier
,
uint32_t
shift
)
DoMagicDivision
(
int32_t
dividend_i32
,
uint32_t
multiplier
,
uint32_t
shift
)
{
{
...
@@ -145,6 +146,7 @@ struct MagicDivision
...
@@ -145,6 +146,7 @@ struct MagicDivision
uint32_t
tmp
=
__umulhi
(
dividend_u32
,
multiplier
);
uint32_t
tmp
=
__umulhi
(
dividend_u32
,
multiplier
);
return
(
tmp
+
dividend_u32
)
>>
shift
;
return
(
tmp
+
dividend_u32
)
>>
shift
;
}
}
#endif
__host__
static
constexpr
int32_t
__host__
static
constexpr
int32_t
DoMagicDivision
(
int32_t
dividend_i32
,
uint32_t
multiplier
,
uint32_t
shift
)
DoMagicDivision
(
int32_t
dividend_i32
,
uint32_t
multiplier
,
uint32_t
shift
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment