Unverified Commit 9e366482 authored by Li Zhang, committed by GitHub

[Fix] Fix building with CUDA 11.3 (#280)

* disable cache hint for CUDA < 11.4

* fix lint

* fix lint

* fix cuda-11.3 build
parent 06327355
@@ -293,10 +293,4 @@ struct Shape {
     }
 };
 
-template<int... Ns>
-Shape(std::integral_constant<int, Ns>...) -> Shape<Ns...>;
-
-template<int... Ns>
-inline constexpr Shape<Ns...> shape_c{};
-
 }  // namespace turbomind
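
The removed deduction guide and `shape_c` variable template gave callers a CTAD-based shorthand over an `std::integral_constant` pack, a construct the CUDA 11.3 front end apparently fails to compile; spelling the template arguments out sidesteps it. A minimal sketch, assuming a `Shape` like the one above (the `rank` member is made up for illustration):

```cpp
// Minimal sketch, assuming a Shape template like the one in the diff above;
// shape_c<Ns...> was shorthand for Shape<Ns...>{}. With the deduction guide
// and variable template gone, the instantiation is spelled out explicitly.
template<int... Ns>
struct Shape {
    static constexpr int rank = sizeof...(Ns);  // hypothetical member, for illustration
};

int main()
{
    Shape<1, 8, 128> s{};        // explicit template arguments build on every nvcc
    static_assert(s.rank == 3);  // C++17 single-argument static_assert
    return 0;
}
```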
@@ -7,20 +7,28 @@
 namespace turbomind {
 
+#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4)
+#define L2_CACHEHINT(size) ".L2::" #size "B"
+#else
+#define L2_CACHEHINT(size)
+#endif
+
 template<typename T>
 __inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
 {
 #if TURBOMIND_ARCH_SM80
     constexpr int cp_size = sizeof(T);
     static_assert(cp_size == 16, "cp.async.cg requires cp_size == 16");
+    // clang-format off
     asm volatile("{\n"
                  "  .reg .pred p;\n"
                  "  setp.ne.b32 p, %0, 0;\n"
-                 "  @p cp.async.cg.shared.global.L2::256B [%1], [%2], %3;\n"
+                 "  @p cp.async.cg.shared.global" L2_CACHEHINT(256) " [%1], [%2], %3;\n"
                  "}\n" ::"r"((int)mask),
                  "r"(smem_int_ptr),
                  "l"(src),
                  "n"(cp_size));
+    // clang-format on
 #else
     assert(TURBOMIND_ARCH_SM80);
 #endif
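
The new `L2_CACHEHINT` macro builds the PTX cache-hint suffix by stringizing its argument and relying on adjacent string-literal concatenation; on toolkits older than 11.4 it expands to nothing, so the plain, unhinted `cp.async` form is emitted. A small host-side sketch of the expansion (the `kPtx` name is made up for illustration):

```cpp
// Sketch of the macro expansion, copied from the diff above.
#include <cstdio>

#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4)
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif

// With CUDA >= 11.4 the pieces concatenate to
//   "cp.async.cg.shared.global.L2::256B [%1], [%2], %3;\n"
// while on older toolkits L2_CACHEHINT(256) vanishes, leaving
//   "cp.async.cg.shared.global [%1], [%2], %3;\n"
static const char kPtx[] = "cp.async.cg.shared.global" L2_CACHEHINT(256) " [%1], [%2], %3;\n";

int main()
{
    std::puts(kPtx);
    return 0;
}
```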
@@ -32,14 +40,16 @@ __inline__ __device__ void cp_async_cg_B(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
 #if TURBOMIND_ARCH_SM80
     constexpr int cp_size = sizeof(T);
     static_assert(cp_size == 16, "cp.async.cg requires cp_size == 16");
+    // clang-format off
     asm volatile("{\n"
                  "  .reg .pred p;\n"
                  "  setp.ne.b32 p, %0, 0;\n"
-                 "  @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n"
+                 "  @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
                  "}\n" ::"r"((int)mask),
                  "r"(smem_int_ptr),
                  "l"(src),
                  "n"(cp_size));
+    // clang-format on
 #else
     assert(TURBOMIND_ARCH_SM80);
 #endif
@@ -50,14 +60,16 @@ __inline__ __device__ void cp_async_ca(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
 {
 #if TURBOMIND_ARCH_SM80
     constexpr int cp_size = sizeof(T);
+    // clang-format off
     asm volatile("{\n"
                  "  .reg .pred p;\n"
                  "  setp.ne.b32 p, %0, 0;\n"
-                 "  @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n"
+                 "  @p cp.async.ca.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
                  "}\n" ::"r"((int)mask),
                  "r"(smem_int_ptr),
                  "l"(src),
                  "n"(cp_size));
+    // clang-format on
 #else
     assert(TURBOMIND_ARCH_SM80);
 #endif
@@ -17,7 +17,7 @@
 #include "src/turbomind/utils/cuda_fp8_utils.h"
 #ifndef CUDART_VERSION
 #error CUDART_VERSION Undefined!
-#elif (CUDART_VERSION >= 11050)
+#elif (CUDART_VERSION >= 11000)
 #include <cub/cub.cuh>
 #else
 #include "3rdparty/cub/cub.cuh"
@@ -20,7 +20,7 @@
 
 #ifndef CUDART_VERSION
 #error CUDART_VERSION Undefined!
-#elif (CUDART_VERSION >= 11050)
+#elif (CUDART_VERSION >= 11000)
 #include <cub/cub.cuh>
 #else
 #include "3rdparty/cub/cub.cuh"
@@ -18,7 +18,7 @@
 #include <stdexcept>
 #ifndef CUDART_VERSION
 #error CUDART_VERSION Undefined!
-#elif (CUDART_VERSION >= 11050)
+#elif (CUDART_VERSION >= 11000)
 #include <cub/cub.cuh>
 #else
 #include "3rdparty/cub/cub.cuh"
@@ -16,7 +16,7 @@
 
 #ifndef CUDART_VERSION
 #error CUDART_VERSION Undefined!
-#elif (CUDART_VERSION >= 11050)
+#elif (CUDART_VERSION >= 11000)
 #include <cub/cub.cuh>
 #else
 #include "3rdparty/cub/cub.cuh"
@@ -115,7 +115,7 @@ __global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
     constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);
 
-    const auto batch_idx = grid.block_rank();
+    const auto batch_idx = block.group_index().x;
 
     uint4* __restrict__ r_ptr       = reinterpret_cast<uint4*>(r_data + batch_idx * n_dims);
     uint4* __restrict__ x_ptr       = reinterpret_cast<uint4*>(x_data + batch_idx * n_dims);
     const uint4* __restrict__ b_ptr = reinterpret_cast<const uint4*>(bias);
@@ -123,7 +123,7 @@ __global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
     res_norm_t<T> ops;
 
     float thread_sum{};
-    for (auto i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.num_threads()) {
+    for (auto i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.size()) {
         auto  r = r_ptr[i];
         auto  x = x_ptr[i];
         uint4 b = b_ptr ? b_ptr[i] : uint4{};
@@ -136,7 +136,7 @@ __global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
     float s_inv_mean = rsqrt(total_sum / n_dims + eps);
 
     const uint4* __restrict__ s_ptr = reinterpret_cast<const uint4*>(scale);
-    for (uint i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.num_threads()) {
+    for (uint i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.size()) {
         auto r = r_ptr[i];
         auto s = s_ptr[i];
         auto o = ops.normvec(r, s, s_inv_mean);
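
Both loops swap cooperative-groups accessors that are missing from CUDA 11.3's headers (`grid.block_rank()`, `block.num_threads()`) for ones 11.3 already provides (`block.group_index()`, `block.size()`). A stand-alone sketch of the same one-block-per-row pattern with the portable calls (the `row_scale` kernel is hypothetical, not from the repo):

```cuda
#include <cooperative_groups.h>

namespace cg = cooperative_groups;

// Hypothetical kernel: block.group_index().x selects the row handled by this
// thread block, and block.size() is the block-stride step of the inner loop.
__global__ void row_scale(float* data, int n_dims, float factor)
{
    const auto block     = cg::this_thread_block();
    const auto batch_idx = block.group_index().x;  // one thread block per row
    float*     row       = data + batch_idx * n_dims;

    for (auto i = block.thread_rank(); i < n_dims; i += block.size()) {
        row[i] *= factor;
    }
}
```

Launched as `row_scale<<<batch_size, 256>>>(d_data, n_dims, 2.f)`, each block scales one row, mirroring how `fusedAddBiasResidualNorm` walks one batch element per block.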
@@ -118,10 +118,13 @@ endif()
 set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})
 
-target_compile_definitions(triton-turbomind-backend
-    PUBLIC
-    USE_TRITONSERVER_DATATYPE
-    BUILD_MULTI_GPU)
+target_compile_definitions(triton-turbomind-backend PUBLIC
+    USE_TRITONSERVER_DATATYPE)
+
+if (BUILD_MULTI_GPU)
+    target_compile_definitions(triton-turbomind-backend PUBLIC
+        BUILD_MULTI_GPU)
+endif ()
 
 target_include_directories(
     triton-turbomind-backend
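
This CMake change stops defining `BUILD_MULTI_GPU` unconditionally and adds it only when the corresponding option is ON, so single-GPU builds of the Triton backend no longer compile multi-GPU-only paths. A sketch of a consumer of that definition (the names below are illustrative, not from the repo):

```cpp
// Illustrative consumer: with the change above, BUILD_MULTI_GPU is defined
// only when the CMake option is enabled, so this path compiles out otherwise.
#include <cstdio>

#ifdef BUILD_MULTI_GPU
constexpr bool kMultiGpu = true;   // NCCL/MPI-dependent code would sit behind this
#else
constexpr bool kMultiGpu = false;  // single-GPU build: no NCCL/MPI required
#endif

int main()
{
    std::printf("multi-GPU support compiled in: %s\n", kMultiGpu ? "yes" : "no");
    return 0;
}
```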