gaoqiong / composable_kernel / Commits / 769d992c

Commit 769d992c
Authored Jun 23, 2020 by Chao Liu

    nvidia build

Parent: 820320ef

Showing 10 changed files with 120 additions and 58 deletions (+120 −58)
CMakeLists.txt                                                        +8   -3
composable_kernel/include/utility/config.nvidia.hpp.in               +3   -4
composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in  +81  -31
composable_kernel/include/utility/synchronization.nvidia.hpp.in      +13  -0
driver/include/device.hpp                                            +1   -3
driver/src/device.cpp                                                 +4   -5
external/half/include/half.hpp                                        +0   -0
external/rocm/include/bfloat16_dev.hpp                                +0   -0
script/cmake-cuda.sh                                                  +8   -11
script/cmake-rocm3.5.sh                                               +2   -1
CMakeLists.txt

@@ -44,7 +44,6 @@ if(DEVICE_BACKEND STREQUAL "AMD")
     find_package(HIP REQUIRED)
 elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
     enable_language(CUDA)
     include_directories(BEFORE ${CUDA_COMMON_INCLUDE_DIR})
 endif()
 #

@@ -54,11 +53,17 @@ include_directories(BEFORE
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm
-    ${PROJECT_SOURCE_DIR}/external/include
+    ${PROJECT_SOURCE_DIR}/external/half/include
     ${PROJECT_SOURCE_DIR}/driver/include
     ${PROJECT_BINARY_DIR}/composable_kernel/include/utility
 )

+if(DEVICE_BACKEND STREQUAL "AMD")
+    include_directories(BEFORE ${PROJECT_SOURCE_DIR}/external/rocm/include)
+endif()
+
 if(DEVICE_BACKEND STREQUAL "AMD")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in"
                    "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in"
                    "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
composable_kernel/include/utility/config.nvidia.hpp.in

 #ifndef CK_CONFIG_NVIDIA_HPP
 #define CK_CONFIG_NVIDIA_HPP

-#include "cuda_runtime.h"
-#include "cuda_fp16.h"
-#include "nvToolsExt.h"
-#include "helper_cuda.h"
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <nvToolsExt.h>

 // index type: unsigned or signed
 #define CK_UNSIGNED_INDEX_TYPE 0
composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in

@@ -3,56 +3,106 @@
 namespace ck {

-template <typename T,
-          index_t DataPerAccess,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace>
-__device__ void copy_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
+template <typename T>
+__device__ void atomic_add_impl(T* p_dst, T src)
+{
+    atomicAdd(p_dst, src);
+}
+
+// atomicAdd for float does not support vector type
+template <>
+__device__ void atomic_add_impl<float2_t>(float2_t* p_dst, float2_t src)
+{
+    float* p_dst_float       = reinterpret_cast<float*>(p_dst);
+    const float* p_src_float = reinterpret_cast<const float*>(&src);
+
+    for(index_t i = 0; i < 2; ++i)
+    {
+        atomicAdd(&(p_dst_float[i]), p_src_float[i]);
+    }
+}
+
+template <>
+__device__ void atomic_add_impl<float4_t>(float4_t* p_dst, float4_t src)
+{
+    float* p_dst_float       = reinterpret_cast<float*>(p_dst);
+    const float* p_src_float = reinterpret_cast<const float*>(&src);
+
+    for(index_t i = 0; i < 4; ++i)
+    {
+        atomicAdd(&(p_dst_float[i]), p_src_float[i]);
+    }
+}
+
+template <typename T, index_t DataPerAccess>
+struct SetData
 {
     using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;

-    *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
-        *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
-}
+    template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
+    __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
+    {
+        *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
+            *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
+    }
+};

-template <typename T,
-          index_t DataPerAccess,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace>
-__device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
+template <typename T, index_t DataPerAccess>
+struct AtomicAddData
 {
     using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;

-    static_if<SrcAddressSpace == AddressSpace::Vgpr &&
-              DstAddressSpace == AddressSpace::Global>{}([&](auto) {
-        atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
-                  *reinterpret_cast<const vector_t*>(&p_src[src_offset]));
-    }).Else([&](auto fwd) {
-        static_assert(fwd(false), "atomic_add doesn't support this memory space");
-    });
-}
+    template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
+    __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
+    {
+        atomic_add_impl(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
+                        *reinterpret_cast<const vector_t*>(&p_src[src_offset]));
+    }
+};

 template <typename T,
           index_t DataPerAccess,
           AddressSpace SrcAddressSpace,
           AddressSpace DstAddressSpace,
-          InMemoryDataOperation DstInMemOp>
+          InMemoryDataOperation DstInMemOp,
+          index_t SrcDataStride = 1,
+          index_t DstDataStride = 1>
 __device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
 {
     static_assert(DstInMemOp == InMemoryDataOperation::Set ||
                       DstInMemOp == InMemoryDataOperation::AtomicAdd,
                   "wrong! InMemoryDataOperation not supported!");

-    // TODO: use static_if::ElseIf
-    static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
-        copy_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
-            p_src, src_offset, p_dst, dst_offset);
-    });
-
-    static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
-        atomic_add_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
-            p_src, src_offset, p_dst, dst_offset);
-    });
+    // keep it simple, don't use static_if here, otherwise compiler will do weird things
+    if(SrcDataStride == 1 && DstDataStride == 1)
+    {
+        // TODO: use static_if::ElseIf
+        static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
+            SetData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
+                p_src, src_offset, p_dst, dst_offset);
+        });
+
+        static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
+            AtomicAddData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
+                p_src, src_offset, p_dst, dst_offset);
+        });
+    }
+    else
+    {
+        for(index_t i = 0; i < DataPerAccess; i++)
+        {
+            // TODO: use static_if::ElseIf
+            static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
+                SetData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
+                    p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride);
+            });
+
+            static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
+                AtomicAddData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
+                    p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride);
+            });
+        }
+    }
 }

 } // namespace ck
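The float2_t and float4_t specializations above exist because CUDA's atomicAdd only takes scalar operands, so a vector accumulation has to be decomposed into per-element atomics. A minimal standalone sketch of the same idea, using plain float4 and hypothetical names (atomic_add_float4, accumulate) rather than the kernel's own vector_type machinery:

#include <cuda_runtime.h>
#include <cstdio>

// Decompose a float4 atomic add into four scalar atomicAdds,
// mirroring the element-wise loop in atomic_add_impl above.
__device__ void atomic_add_float4(float4* p_dst, float4 src)
{
    float* d       = reinterpret_cast<float*>(p_dst);
    const float* s = reinterpret_cast<const float*>(&src);

    for(int i = 0; i < 4; ++i)
    {
        atomicAdd(&d[i], s[i]); // hardware atomicAdd accepts only scalar float
    }
}

__global__ void accumulate(float4* acc)
{
    // every thread adds the same vector, so each lane ends up equal to the thread count
    atomic_add_float4(acc, make_float4(1.f, 1.f, 1.f, 1.f));
}

int main()
{
    float4* acc;
    cudaMalloc(&acc, sizeof(float4));
    cudaMemset(acc, 0, sizeof(float4));

    accumulate<<<4, 64>>>(acc);

    float4 host;
    cudaMemcpy(&host, acc, sizeof(float4), cudaMemcpyDeviceToHost);
    printf("%f %f %f %f\n", host.x, host.y, host.z, host.w); // expect 256 in each lane

    cudaFree(acc);
    return 0;
}

With 4 blocks of 64 threads, each lane should read 256, which is a quick check that the element-wise atomics behave like a single vector-wide add.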
composable_kernel/include/utility/synchronization.nvidia.hpp.in (new file, 0 → 100644)

#ifndef CK_SYNCHRONIZATION_NVIDIA_HPP
#define CK_SYNCHRONIZATION_NVIDIA_HPP

#include "config.hpp"

namespace ck {

__device__ void block_sync_lds() { __syncthreads(); }

__device__ void block_sync_lds_vmem() { __syncthreads(); }

} // namespace ck

#endif
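block_sync_lds() keeps the AMD name (LDS) but maps to __syncthreads() on the NVIDIA backend. A toy kernel, not taken from the repository, showing why a block-wide barrier is needed between writing and reading shared memory:

#include <cuda_runtime.h>
#include <cstdio>

__global__ void reverse_in_shared(const float* in, float* out, int n)
{
    extern __shared__ float tile[]; // "LDS" in AMD terms, dynamic shared memory in CUDA terms

    int i = threadIdx.x;
    if(i < n)
        tile[i] = in[i];

    __syncthreads(); // what block_sync_lds() expands to; without it threads may read stale data

    if(i < n)
        out[i] = tile[n - 1 - i];
}

int main()
{
    const int n = 256;
    float h_in[n], h_out[n];
    for(int i = 0; i < n; ++i)
        h_in[i] = static_cast<float>(i);

    float *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

    reverse_in_shared<<<1, n, n * sizeof(float)>>>(d_in, d_out, n);

    cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
    printf("out[0] = %f (expect %d)\n", h_out[0], n - 1);

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}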
driver/include/device.hpp

@@ -60,7 +60,7 @@ float launch_and_time_kernel(F kernel,
     timer.End();

-    hipGetErrorString(hipGetLastError());
+    hipGetLastError();

     return timer.GetElapsedTime();
 }

@@ -101,8 +101,6 @@ float launch_and_time_kernel(F kernel,
     timer.End();

-    checkCudaErrors(error);
-
     return timer.GetElapsedTime();
 }
 #endif
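On the NVIDIA path, launch_and_time_kernel now relies on the surrounding timer alone and no longer feeds the launch result through checkCudaErrors. For reference, a hedged sketch of timing a launch with CUDA events; dummy_kernel and the event handling are illustrative and not the driver's own timer:

#include <cuda_runtime.h>
#include <cstdio>

__global__ void dummy_kernel() {}

int main()
{
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start);
    dummy_kernel<<<1, 1>>>();
    cudaEventRecord(stop);
    cudaEventSynchronize(stop); // wait until the kernel and the stop event have completed

    float ms = 0.f;
    cudaEventElapsedTime(&ms, start, stop);
    printf("kernel time: %f ms\n", ms);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return 0;
}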
driver/src/device.cpp

@@ -6,7 +6,7 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
 #if CK_DEVICE_BACKEND_AMD
     hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
 #elif CK_DEVICE_BACKEND_NVIDIA
-    checkCudaErrors(cudaMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+    cudaMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize);
 #endif
 }

@@ -18,8 +18,7 @@ void DeviceMem::ToDevice(const void* p)
     hipGetErrorString(
         hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
 #elif CK_DEVICE_BACKEND_NVIDIA
-    checkCudaErrors(
-        cudaMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, cudaMemcpyHostToDevice));
+    cudaMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, cudaMemcpyHostToDevice);
 #endif
 }

@@ -28,7 +27,7 @@ void DeviceMem::FromDevice(void* p)
 #if CK_DEVICE_BACKEND_AMD
     hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
 #elif CK_DEVICE_BACKEND_NVIDIA
-    checkCudaErrors(cudaMemcpy(p, mpDeviceBuf, mMemSize, cudaMemcpyDeviceToHost));
+    cudaMemcpy(p, mpDeviceBuf, mMemSize, cudaMemcpyDeviceToHost);
 #endif
 }

@@ -37,7 +36,7 @@ DeviceMem::~DeviceMem()
 #if CK_DEVICE_BACKEND_AMD
     hipGetErrorString(hipFree(mpDeviceBuf));
 #elif CK_DEVICE_BACKEND_NVIDIA
-    checkCudaErrors(cudaFree(mpDeviceBuf));
+    cudaFree(mpDeviceBuf);
 #endif
 }
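The CUDA branches of DeviceMem drop the checkCudaErrors() wrapper, which came from the CUDA samples' helper_cuda.h that config.nvidia.hpp.in no longer includes, so allocation and copy failures now pass silently. If explicit checking is still wanted, a minimal sketch along these lines would do; CK_CUDA_CHECK is a hypothetical name, not part of this commit:

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Abort with a readable message if a CUDA runtime call fails.
#define CK_CUDA_CHECK(call)                                                \
    do                                                                     \
    {                                                                      \
        cudaError_t err_ = (call);                                         \
        if(err_ != cudaSuccess)                                            \
        {                                                                  \
            fprintf(stderr,                                                \
                    "CUDA error %s at %s:%d\n",                            \
                    cudaGetErrorString(err_),                              \
                    __FILE__,                                              \
                    __LINE__);                                             \
            exit(EXIT_FAILURE);                                            \
        }                                                                  \
    } while(0)

int main()
{
    void* buf = nullptr;
    CK_CUDA_CHECK(cudaMalloc(&buf, 1024));
    CK_CUDA_CHECK(cudaMemset(buf, 0, 1024));
    CK_CUDA_CHECK(cudaFree(buf));
    return 0;
}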
external/include/half.hpp → external/half/include/half.hpp (file moved)

external/include/bfloat16_dev.hpp → external/rocm/include/bfloat16_dev.hpp (file moved)
script/cmake-cuda.sh

 #!/bin/bash

-MY_PROJECT_SOURCE=../
 MY_PROJECT_INSTALL=../install.dir
+MY_PROJECT_SOURCE=../../../

 export CUDA_ROOT=/usr/local/cuda
 export CPATH=$CPATH:$CUDA_ROOT/include
 export LIBRARY_PATH=$LIBRARY_PATH:$CUDA_ROOT/lib64
 export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CUDA_ROOT/lib64

 cmake                                                                    \
 -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL}                            \
--D CMAKE_CXX_COMPILER=clang++-6.0                                        \
+-D CMAKE_CXX_COMPILER=clang++                                            \
 -D CMAKE_BUILD_TYPE=Release                                              \
 -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON                                        \
 -D DEVICE_BACKEND=NVIDIA                                                 \
 -D CUDA_COMMON_INCLUDE_DIR="/root/NVIDIA_CUDA-10.1_Samples/common/inc"   \
--D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_70,code=sm_70 -Xptxas -v -maxrregcount=128" \
+-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \
 ${MY_PROJECT_SOURCE}
script/cmake-rocm3.5.sh

@@ -10,11 +10,12 @@ cmake
 -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL}            \
 -D CMAKE_BUILD_TYPE=Release                              \
 -D DEVICE_BACKEND="AMD"                                  \
--D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0 -save-temps" \
+-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0" \
 -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc                \
 -D CMAKE_PREFIX_PATH="/opt/rocm"                         \
 -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON                        \
 ${MY_PROJECT_SOURCE}

 #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0" \
 #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0 -save-temps" \
 #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0 -v -gline-tables-only -save-temps" \