Unverified Commit aa3eba8e authored by PGFLMG's avatar PGFLMG Committed by GitHub
Browse files

[sgl-kernel] misc: update deepgemm version for sgl-kernel (#9340)


Co-authored-by: default avatarYineng Zhang <me@zhyncs.com>
Co-authored-by: default avatarfzyzcjy <ch271828n@outlook.com>
parent 07ee0ab7
......@@ -23,7 +23,6 @@ limitations under the License.
#ifndef USE_ROCM
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#include <cuda/functional>
#else
#include <hipcub/hipcub.hpp>
#include <hipcub/util_type.hpp>
......@@ -34,16 +33,6 @@ limitations under the License.
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// Define reduction operators based on CUDA version
// CUDA 13 (12.9+) deprecated cub::Max/Min in favor of cuda::maximum/minimum
#if CUDA_VERSION >= 12090
using MaxReduceOp = cuda::maximum<>;
using MinReduceOp = cuda::minimum<>;
#else
using MaxReduceOp = cub::Max;
using MinReduceOp = cub::Min;
#endif
/// Aligned array type
template <
typename T,
......@@ -83,6 +72,7 @@ __launch_bounds__(TPB) __global__
const int thread_row_offset = blockIdx.x * num_cols;
cub::Sum sum;
float threadData(-FLT_MAX);
// Don't touch finished rows.
......@@ -95,7 +85,7 @@ __launch_bounds__(TPB) __global__
threadData = max(convert_to_float<T>(input[idx]), threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
float_max = maxElem;
......@@ -109,7 +99,7 @@ __launch_bounds__(TPB) __global__
threadData += exp((convert_to_float<T>(input[idx]) - float_max));
}
const auto Z = BlockReduce(tmpStorage).Sum(threadData);
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
......
......@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"
[project]
name = "sgl-kernel"
version = "0.3.6.post2"
version = "0.3.7"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.10"
......
......@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"
[project]
name = "sgl-kernel"
version = "0.3.6.post2"
version = "0.3.7"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.10"
......
......@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "sgl-kernel"
version = "0.3.6.post2"
version = "0.3.7"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.10"
......
__version__ = "0.3.6.post2"
__version__ = "0.3.7"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment