/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

// NOTE(review): the header names of the non-standard includes in this file
// were reconstructed from the symbols used below (cudaError_t, TORCH_CHECK,
// AT_DISPATCH_*, _DISPATCH_CASE_F16/_BF16, c10::Float8_*) -- confirm them
// against the build before merging.
#include <cuda_runtime.h>
#ifndef USE_ROCM
#include <pytorch_extension_utils.h>
#endif
#include <torch/extension.h>

#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>

/**
 * @brief Exception type thrown by `CHECK_CUDA_SUCCESS` when a CUDA runtime
 * call does not return `cudaSuccess`.
 */
struct cuda_error : public std::runtime_error {
  /**
   * @brief Constructs a `cuda_error` object with the given `message`.
   *
   * @param message The error char array used to construct `cuda_error`
   */
  cuda_error(const char* message) : std::runtime_error(message) {}
  /**
   * @brief Constructs a `cuda_error` object with the given `message` string.
   *
   * @param message The `std::string` used to construct `cuda_error`
   */
  cuda_error(std::string const& message) : cuda_error{message.c_str()} {}
};

/**
 * Evaluates `cmd` exactly once and throws `cuda_error` if the result is not
 * `cudaSuccess`. The exception message is the CUDA error string, a newline,
 * then `__FILE__:__LINE__` of the expansion site.
 *
 * The argument is parenthesized and the macro's locals carry a trailing
 * underscore so that an expression passed as `cmd` cannot be mis-parsed or
 * accidentally capture one of the macro's own variables.
 */
#define CHECK_CUDA_SUCCESS(cmd)                                                           \
  do {                                                                                    \
    cudaError_t cuda_check_err_ = (cmd);                                                  \
    if (cuda_check_err_ != cudaSuccess) {                                                 \
      std::stringstream cuda_check_msg_;                                                  \
      auto cuda_check_str_ = cudaGetErrorString(cuda_check_err_);                         \
      cuda_check_msg_ << std::string(cuda_check_str_) + "\n" << __FILE__ << ':' << __LINE__; \
      throw cuda_error(cuda_check_msg_.str());                                            \
    }                                                                                     \
  } while (0)

/** Fails a `TORCH_CHECK` unless tensor `x` lives on a CUDA device. */
#define CHECK_IS_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")

/** Fails a `TORCH_CHECK` unless tensor `x` is contiguous. */
#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")

/**
 * Runs both `CHECK_IS_CUDA` and `CHECK_IS_CONTIGUOUS` on `x`.
 *
 * Wrapped in `do { } while (0)` so the two checks form a single statement;
 * otherwise `if (cond) CHECK_CUDA_INPUT(x);` would guard only the first check.
 * Call sites must terminate the macro with a semicolon.
 */
#define CHECK_CUDA_INPUT(x)   \
  do {                        \
    CHECK_IS_CUDA(x);         \
    CHECK_IS_CONTIGUOUS(x);   \
  } while (0)

/**
 * @brief Returns the compute capability of the current CUDA device encoded as
 * `major * 10 + minor`.
 *
 * @throws cuda_error if querying the current device or its attributes fails.
 */
inline int getSMVersion() {
  int device{-1};
  CHECK_CUDA_SUCCESS(cudaGetDevice(&device));
  int sm_major = 0;
  int sm_minor = 0;
  CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
  CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
  return sm_major * 10 + sm_minor;
}

#ifndef USE_ROCM
/**
 * Dispatches on `pytorch_dtype`: binds the alias `c_type` to the matching C++
 * type and invokes the callable passed as the trailing argument. `Float` is
 * handled here; the half and bfloat16 cases are delegated to
 * `_DISPATCH_CASE_F16` / `_DISPATCH_CASE_BF16`, which are defined outside this
 * file. Any other dtype fails a `TORCH_CHECK`.
 *
 * The whole expression is an immediately-invoked lambda returning `bool`.
 */
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...)           \
  [&]() -> bool {                                                                        \
    switch (pytorch_dtype) {                                                             \
      case at::ScalarType::Float: {                                                      \
        using c_type = float;                                                            \
        return __VA_ARGS__();                                                            \
      }                                                                                  \
        _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                          \
        _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                         \
      default:                                                                           \
        std::ostringstream oss;                                                          \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
        TORCH_CHECK(false, oss.str());                                                   \
        return false;                                                                    \
    }                                                                                    \
  }()
#endif

/** `AT_DISPATCH_CASE` entries for Byte, Char, Short, Int and Long. */
#define DISPATCH_CASE_INTEGRAL_TYPES(...)              \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

/** `AT_DISPATCH_SWITCH` over the integral scalar types listed above. */
#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

/** Integer division of `x` by `y`, rounded up (for positive operands). */
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))

/** Number of lanes per warp assumed by the reduction helpers below. */
#define WARP_SIZE 32

#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
#else
#include <c10/util/Float8_e4m3fnuz.h>

using FP8_TYPE = c10::Float8_e4m3fnuz;
// Hard-coded rather than taken from std::numeric_limits on this path.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif

#ifndef USE_ROCM
/**
 * @brief Atomically stores max(*addr, value) into `*addr` using integer
 * atomics on the float bit pattern.
 *
 * Non-negative values use a signed `atomicMax`; negative values use an
 * unsigned `atomicMin`. Both orderings match float ordering for the
 * respective sign because of the IEEE-754 sign/magnitude layout.
 *
 * NOTE(review): NaN inputs are not ordered meaningfully by these integer
 * comparisons -- confirm callers never pass NaN.
 *
 * @param addr  Address of the float to update.
 * @param value Candidate maximum.
 * @return The value held at `addr` before the atomic operation.
 */
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
                     : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
  return old;
}

/**
 * @brief Butterfly (XOR-shuffle) max reduction across a warp.
 *
 * Uses the full mask `0xffffffff`, so all 32 lanes of the warp must call this
 * together. After the last step every lane holds the warp-wide maximum.
 *
 * @param max_value This lane's contribution.
 * @return The maximum over all lanes of the warp.
 */
__device__ __forceinline__ float warpReduceMax(float max_value) {
  // Offsets 16, 8, 4, 2, 1 -- same exchange sequence as a hand-unrolled chain.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
    max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, offset));
  }
  return max_value;
}

/**
 * @brief Max reduction across a 1-D thread block (indexed by `threadIdx.x`).
 *
 * Each warp reduces its values with `warpReduceMax`, lane 0 of every warp
 * publishes the partial result to shared memory, and warp 0 reduces those
 * partials.
 *
 * Contract:
 *  - Must be called by every thread of the block (contains `__syncthreads()`).
 *  - Only threads of warp 0 return the block-wide result; other threads return
 *    0 for block sizes up to WARP_SIZE * WARP_SIZE threads.
 *  - Unused partial slots are filled with 0, so with fewer than WARP_SIZE
 *    warps the result is never below 0. This is only correct for
 *    non-negative inputs (e.g. absolute values).
 *  - The warp count is rounded up so a trailing partial warp is not dropped.
 *    The shuffles still use a full-warp mask, so `blockDim.x` should be a
 *    multiple of WARP_SIZE for a well-defined result.
 *  - NOTE(review): the shared buffer is static and there is no barrier after
 *    warp 0 reads it. Calling this more than once per kernel presumably
 *    needs a `__syncthreads()` between calls -- confirm at call sites.
 *
 * @param max_value This thread's contribution.
 * @return Block-wide maximum (valid in warp 0 only).
 */
__device__ __forceinline__ float blockReduceMax(float max_value) {
  static __shared__ float warpLevelMaxs[WARP_SIZE];
  const int laneId = threadIdx.x % WARP_SIZE;
  const int warpId = threadIdx.x / WARP_SIZE;
  // Round up: plain division would ignore a trailing partial warp and yields
  // 0 warps (and therefore a constant 0 result) when blockDim.x < WARP_SIZE.
  const unsigned int numWarps = CEILDIV(blockDim.x, WARP_SIZE);

  max_value = warpReduceMax(max_value);

  if (laneId == 0) warpLevelMaxs[warpId] = max_value;
  __syncthreads();

  // Threads whose index maps to a published partial load it; the rest
  // contribute the neutral value 0.
  max_value = (threadIdx.x < numWarps) ? warpLevelMaxs[laneId] : 0.0f;
  if (warpId == 0) max_value = warpReduceMax(max_value);

  return max_value;
}
#endif