/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <torch/all.h>

#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#endif

#include <sstream>
#include <stdexcept>

struct cuda_error : public std::runtime_error {
  /**
   * @brief Constructs a `cuda_error` object with the given `message`.
   *
   * @param message The error char array used to construct `cuda_error`
   */
  cuda_error(const char* message) : std::runtime_error(message) {}

  /**
   * @brief Constructs a `cuda_error` object with the given `message` string.
   *
   * @param message The `std::string` used to construct `cuda_error`
   */
  cuda_error(std::string const& message) : cuda_error{message.c_str()} {}
};

#define CHECK_CUDA_SUCCESS(cmd)                                         \
  do {                                                                  \
    cudaError_t e = cmd;                                                \
    if (e != cudaSuccess) {                                             \
      std::stringstream _message;                                      \
      auto s = cudaGetErrorString(e);                                   \
      _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \
      throw cuda_error(_message.str());                                 \
    }                                                                   \
  } while (0)

#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

#define CHECK_CUDA_INPUT(x) \
  CHECK_IS_CUDA(x);         \
  CHECK_IS_CONTIGUOUS(x)

inline int getSMVersion() {
  int device{-1};
  CHECK_CUDA_SUCCESS(cudaGetDevice(&device));
  int sm_major = 0;
  int sm_minor = 0;
  CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
  CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
  return sm_major * 10 + sm_minor;
}

// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28
#ifndef USE_ROCM
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor_sync((mask), (var), (lane_mask), (width))
#else
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask))
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
#endif

#ifndef USE_ROCM
// Dispatch cases for Half/BFloat16, mapping PyTorch scalar types to the
// corresponding CUDA types (as in FlashInfer's dispatch utilities).
#define _DISPATCH_CASE_F16(c_type, ...) \
  case at::ScalarType::Half: {          \
    using c_type = nv_half;             \
    return __VA_ARGS__();               \
  }

#define _DISPATCH_CASE_BF16(c_type, ...) \
  case at::ScalarType::BFloat16: {       \
    using c_type = nv_bfloat16;          \
    return __VA_ARGS__();                \
  }

#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...)            \
  [&]() -> bool {                                                                         \
    switch (pytorch_dtype) {                                                              \
      case at::ScalarType::Float: {                                                       \
        using c_type = float;                                                             \
        return __VA_ARGS__();                                                             \
      }                                                                                   \
        _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                           \
        _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                          \
      default:                                                                            \
        std::ostringstream oss;                                                           \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype;  \
        TORCH_CHECK(false, oss.str());                                                    \
        return false;                                                                     \
    }                                                                                     \
  }()
#endif

#define DISPATCH_CASE_INTEGRAL_TYPES(...)              \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
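// Usage sketch (illustrative, not part of the original header): wiring the
// input checks and the float/half/bfloat16 dispatch macro together for a
// hypothetical elementwise op. `scale_kernel` and its launch shape are
// assumptions for this example, not functions defined in sgl-kernel.
//
//   template <typename T>
//   __global__ void scale_kernel(const T* in, T* out, float factor, int64_t n);
//
//   void scale(torch::Tensor x, torch::Tensor out, double factor) {
//     CHECK_CUDA_INPUT(x);
//     CHECK_CUDA_INPUT(out);
//     const int64_t n = x.numel();
//     const int threads = 256;
//     const int blocks = static_cast<int>((n + threads - 1) / threads);
//     DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(x.scalar_type(), c_type, [&] {
//       scale_kernel<c_type><<<blocks, threads>>>(
//           static_cast<const c_type*>(x.data_ptr()), static_cast<c_type*>(out.data_ptr()),
//           static_cast<float>(factor), n);
//       CHECK_CUDA_SUCCESS(cudaGetLastError());
//       return true;  // the dispatch lambda reports success as a bool
//     });
//   }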
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))

#define WARP_SIZE 32

#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>

using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
#else
#include <c10/util/Float8_e4m3fnuz.h>

using FP8_TYPE = c10::Float8_e4m3fnuz;
// ROCm's FNUZ E4M3 format has a smaller largest finite value (224.0) than the
// OCP E4M3 format, so clamp to that instead of numeric_limits::max().
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif

#ifndef USE_ROCM
// Atomic float max via integer atomics: for non-negative floats the IEEE-754
// bit pattern is monotonic under signed-int comparison, and for negative
// floats the ordering reverses under unsigned-int comparison, so atomicMax /
// atomicMin on the reinterpreted bits implement a float max.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
                     : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
  return old;
}

__device__ __forceinline__ float warpReduceMax(float max_value) {
  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16));
  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 8));
  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 4));
  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 2));
  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 1));
  return max_value;
}

__device__ __forceinline__ float blockReduceMax(float max_value) {
  static __shared__ float warpLevelMaxs[WARP_SIZE];
  const int laneId = threadIdx.x % WARP_SIZE;
  const int warpId = threadIdx.x / WARP_SIZE;

  max_value = warpReduceMax(max_value);

  if (laneId == 0) warpLevelMaxs[warpId] = max_value;
  __syncthreads();

  max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
  if (warpId == 0) max_value = warpReduceMax(max_value);

  return max_value;
}
#endif

// Pads to a multiple of `alignment` rows.
inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) {
  int64_t rows = tensor.size(0);
  int64_t cols = tensor.size(1);
  int64_t pad_rows = (alignment - (rows % alignment)) % alignment;  // Compute padding size

  if (pad_rows == 0) {
    return tensor;  // Already aligned
  }

  torch::Tensor padding = torch::zeros({pad_rows, cols}, tensor.options());
  torch::Tensor tensor_padded = torch::cat({tensor, padding}, 0);  // Pad along rows

  // Ensure column-major layout
  if (is_column_major) {
    return tensor_padded.t().contiguous().t();
  }
  return tensor_padded;
}
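// Usage sketch (illustrative, not part of the original header): combining the
// block-level reduction with atomicMaxFloat to compute a tensor-wide absolute
// maximum, as one would when deriving a per-tensor FP8 scale. The kernel name
// `abs_max_kernel` is a hypothetical example; `result` must be zero-initialized
// before launch for the atomic merge to be correct.
//
//   __global__ void abs_max_kernel(const float* __restrict__ input, float* __restrict__ result, int64_t n) {
//     float local_max = 0.0f;
//     // Grid-stride loop: each thread folds its strided slice into a local max.
//     for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (int64_t)gridDim.x * blockDim.x) {
//       local_max = fmaxf(local_max, fabsf(input[i]));
//     }
//     local_max = blockReduceMax(local_max);  // block-wide max, valid on the first warp
//     if (threadIdx.x == 0) {
//       atomicMaxFloat(result, local_max);  // merge block results; safe since local_max >= 0
//     }
//   }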