Unverified commit af4b9bae authored by Hubert Lu, committed by GitHub

[AMD] Add silu_and_mul, gelu_and_mul, gelu_tanh_and_mul, and gelu_quick kernels for AMD GPUs (#7135)
Co-authored-by: yiakwy-xpu-ml-framework-team <961186938@qq.com>
Co-authored-by: HAI <hixiao@gmail.com>
parent 7ad6b766
......@@ -33,6 +33,7 @@ from sglang.srt.utils import (
cpu_has_amx_support,
is_cpu,
is_cuda,
is_hip,
is_npu,
set_weight_attrs,
)
......@@ -42,9 +43,12 @@ _is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_hip = is_hip()
if _is_cuda:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
elif _is_hip:
from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
if is_npu():
import torch_npu
......@@ -126,9 +130,13 @@ class QuickGELU(CustomOp):
return x * torch.sigmoid(1.702 * x)
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
# TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
return self.forward_native(x)
def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
gelu_quick(x, out)
return out
class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.
......@@ -222,8 +230,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
return nn.Identity()
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
logger.info(
"sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
"sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
)
from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
......@@ -3,9 +3,12 @@ import unittest
import torch
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.activation import GeluAndMul, QuickGELU
from sglang.srt.utils import is_hip
from sglang.test.test_utils import CustomTestCase
_is_hip = is_hip()
class TestGeluAndMul(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
......@@ -52,5 +55,51 @@ class TestGeluAndMul(CustomTestCase):
self._run_gelu_and_mul_test(*params)
class TestQuickGELU(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
NUM_TOKENS = [7, 83, 2048]  # number of tokens (flattened batch x sequence length)
DIMS = [512, 4096, 5120, 13824]  # last-dim byte sizes are multiples of 16 for fp16/bf16
SEEDS = [0]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")
def _run_gelu_quick_test(self, n_tok: int, dim: int, dtype: torch.dtype, seed: int):
torch.manual_seed(seed)
layer = QuickGELU().to(dtype=dtype)
x = torch.randn(n_tok, dim, dtype=dtype, device="cuda")
with torch.inference_mode():
ref = layer.forward_native(x) # x * sigmoid(1.702 * x), fp32 math
if _is_hip:
out = layer.forward_hip(x) # 128-bit vectorised kernel from sgl-kernel
else:
out = layer.forward_cuda(x)
tol = 1e-2 if dtype is torch.bfloat16 else 1e-3
self.assertTrue(
torch.allclose(out, ref, atol=tol, rtol=tol),
msg=f"Mismatch @ B={n_tok}, D={dim}, dtype={dtype}",
)
print(f"Match @ B={n_tok}, D={dim}, dtype={dtype}")
def test_quick_gelu(self):
for params in itertools.product(
self.NUM_TOKENS, self.DIMS, self.DTYPES, self.SEEDS
):
with self.subTest(
num_tokens=params[0],
dim=params[1],
dtype=params[2],
seed=params[3],
):
self._run_gelu_quick_test(*params)
if __name__ == "__main__":
unittest.main(verbosity=2)
# Benchmarks SGLang kernels versus vLLM across
# (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
import argparse
import itertools
import re
from typing import List, Tuple
import sgl_kernel
import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import gelu_quick # activation-only kernel
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from vllm import _custom_ops as vllm_ops
if not hasattr(vllm_ops, "silu_and_mul"):
vllm_ops = torch.ops._C
def str2int_list(arg: str) -> List[int]:
if arg in ("", None):
return []
if re.fullmatch(r"\d+(,\d+)*", arg.strip()) is None:
raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
return [int(x) for x in arg.split(",")]
def calculate_diff(
kernel: str, dtype: torch.dtype, batch_size: int, seq_len: int, dim: int
) -> bool:
"""Compare vLLM with SGLang for one shape."""
device = torch.device("cuda")
# activation-only quick GELU
if kernel == "gelu_quick":
x = torch.randn(batch_size, seq_len, dim, dtype=dtype, device=device)
ref_out = torch.zeros_like(x)
getattr(vllm_ops, kernel)(ref_out, x)
test_out = getattr(sgl_kernel, kernel)(x)
# fused activation x mul kernels
else:
x = torch.randn(batch_size, seq_len, 2 * dim, dtype=dtype, device=device)
ref_out = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
getattr(vllm_ops, kernel)(ref_out, x)
test_out = getattr(sgl_kernel, kernel)(x)
ok = torch.allclose(ref_out, test_out, rtol=1e-3, atol=1e-5)
tag = "✅ match" if ok else "❌ mismatch"
print(
f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
f"L={seq_len:3d} | D={dim:5d}] {tag}"
)
return ok
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"]
dtypes = [torch.float16, torch.bfloat16]
def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
return list(itertools.product(kernels, dtypes, bsizes, slens, dims_))
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64
default_dims = [2**i for i in range(7, 15)] # 128...16384
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["kernel", "dtype", "batch_size", "seq_len", "dim"],
x_vals=[],
line_arg="provider",
line_vals=["vllm", "sglang", "speedup"],
line_names=["vLLM", "SGL Kernel", "Speed-up (x)"],
styles=[("blue", "-"), ("green", "-"), ("red", "--")],
ylabel="µs (median) or × (speed-up)",
plot_name="activation-performance",
args={},
)
)
def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
device = torch.device("cuda")
in_mult = 1 if kernel == "gelu_quick" else 2
x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
vllm_kernel = getattr(vllm_ops, kernel)
sglang_kernel = getattr(sgl_kernel, kernel)
def baseline():
tmp = y0.clone()
vllm_kernel(tmp, x)
return tmp
def sglang():
return sglang_kernel(x)
# one-time correctness check
if provider == "vllm" and not calculate_diff(
kernel, dtype, batch_size, seq_len, dim
):
raise ValueError("Mismatch – abort benchmark")
# timing helper
def timed(fn):
for _ in range(5):
fn()
torch.cuda.synchronize()
ms, qmin, qmax = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
return 1000 * ms, 1000 * qmax, 1000 * qmin
if provider == "vllm":
return timed(baseline)
if provider == "sglang":
return timed(sglang)
# provider == "speedup"
t_ref, _, _ = timed(baseline)
t_sgl, _, _ = timed(sglang)
spd = t_ref / t_sgl
return (spd, spd, spd)
if __name__ == "__main__":
p = argparse.ArgumentParser("Activation kernel benchmark")
p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes)
p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens)
p.add_argument("--dims", type=str2int_list, default=default_dims)
p.add_argument("--verify_only", action="store_true")
args = p.parse_args()
# coerce lists
if isinstance(args.batch_sizes, str):
args.batch_sizes = str2int_list(args.batch_sizes)
if isinstance(args.seq_lens, str):
args.seq_lens = str2int_list(args.seq_lens)
if isinstance(args.dims, str):
args.dims = str2int_list(args.dims)
# patch perf_report grid
benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.dims)
if hasattr(benchmark, "benchmarks"):
benchmark.benchmarks.x_vals = benchmark_grid
else:
benchmark.benchmark.x_vals = benchmark_grid
if args.verify_only:
ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0])
print("✅ sanity pass" if ok else "❌ mismatch")
else:
benchmark.run(print_data=True)
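For reference, the correctness path of this script can also be driven directly from Python; the sketch below assumes the file is saved as bench_activation.py (the file name is not shown in this diff) and that both sgl_kernel and vLLM are installed on the GPU.

# Command-line form (hypothetical file name):
#   python bench_activation.py --batch_sizes 1,4 --seq_lens 16 --dims 4096 --verify_only
# Equivalent programmatic check using the calculate_diff helper defined above:
import torch

from bench_activation import calculate_diff  # assumed module name

for kernel in ("silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"):
    assert calculate_diff(kernel, torch.float16, batch_size=2, seq_len=8, dim=1024)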
......@@ -78,13 +78,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
m.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
m.def(
......
......@@ -13,70 +13,158 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#ifndef USE_ROCM
#include <flashinfer/activation.cuh>
#include "pytorch_extension_utils.h"
#include "utils.h"
#else
#include "hip_act_and_mul.cuh"
#endif
// Adapted from flashinfer activation
// https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/csrc/activation.cu#L44
namespace detail {
template <typename T>
__device__ __forceinline__ float to_f32(const T& x) {
#if USE_ROCM
return castToFloat(x);
#else
return static_cast<float>(x);
#endif
}
template <typename T>
__device__ __forceinline__ T from_f32(float f32) {
#if USE_ROCM
return castFromFloat<T>(f32);
#else
return static_cast<T>(f32);
#endif
}
using namespace flashinfer;
} // namespace detail
__device__ __forceinline__ float silu(const float& val) {
return val / (1.0f + __expf(-val));
template <typename T>
__device__ __forceinline__ T silu(const T& x) {
float f32_val = detail::to_f32(x);
return detail::from_f32<T>(f32_val / (1.0f + expf(-f32_val)));
}
__device__ __forceinline__ float gelu(const float& val) {
template <typename T>
__device__ __forceinline__ T gelu(const T& x) {
constexpr float kAlpha = M_SQRT1_2;
return val * 0.5f * (1.0f + ::erf(val * kAlpha));
float f32_val = detail::to_f32(x);
return detail::from_f32<T>(f32_val * (0.5f * (1.0f + erf(f32_val * kAlpha))));
}
// gelu_quick(x) = x * torch.sigmoid(1.702 * x)
template <typename T>
__device__ __forceinline__ T gelu_quick_act(const T& x) {
float f32_val = detail::to_f32(x);
return detail::from_f32<T>(f32_val / (1.0f + expf(-f32_val * 1.702f)));
}
__device__ __forceinline__ float gelu_tanh(const float& val) {
const float cdf = 0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val))));
return val * cdf;
template <typename T>
__device__ __forceinline__ T gelu_tanh(const T& x) {
constexpr float kAlpha = 0.044715f;
constexpr float kBeta = 0.7978845608028654f;
float f32_val = detail::to_f32(x);
const float cdf = 0.5f * (1.0f + tanhf((kBeta * (f32_val + kAlpha * f32_val * f32_val * f32_val))));
return detail::from_f32<T>(f32_val * cdf);
}
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) {
void silu_and_mul(at::Tensor& out, at::Tensor& input) {
int d = input.size(-1) / 2;
int64_t num_tokens = input.numel() / input.size(-1);
dim3 grid(num_tokens);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
uint32_t vec_size = 16 / sizeof(c_type);
dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
sgl_hip::activation::act_and_mul_kernel<c_type, silu>
<<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
flashinfer::activation::act_and_mul_kernel<c_type, silu>
<<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
return true;
});
}
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) {
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) {
int d = input.size(-1) / 2;
int64_t num_tokens = input.numel() / input.size(-1);
dim3 grid(num_tokens);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
uint32_t vec_size = 16 / sizeof(c_type);
dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
sgl_hip::activation::act_and_mul_kernel<c_type, gelu_tanh>
<<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
flashinfer::activation::act_and_mul_kernel<c_type, gelu_tanh>
<<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
return true;
});
}
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) {
void gelu_and_mul(at::Tensor& out, at::Tensor& input) {
int d = input.size(-1) / 2;
int64_t num_tokens = input.numel() / input.size(-1);
dim3 grid(num_tokens);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
uint32_t vec_size = 16 / sizeof(c_type);
dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
sgl_hip::activation::act_and_mul_kernel<c_type, gelu>
<<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
flashinfer::activation::act_and_mul_kernel<c_type, gelu>
<<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
return true;
});
}
#if USE_ROCM
void gelu_quick(at::Tensor& out, const at::Tensor& input) {
int d = input.size(-1);
int64_t num_tokens = input.numel() / input.size(-1);
dim3 grid(num_tokens);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
uint32_t vec_size = 16 / sizeof(c_type);
dim3 block(std::min(d / vec_size, 1024U));
sgl_hip::activation::act_only_kernel<c_type, gelu_quick_act>
<<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
return true;
});
}
#endif
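For clarity, the fused launchers above treat the last dimension of input as two halves [x | y] with d = input.size(-1) / 2 and write Activation(x) * y into out, computing the activation in fp32 and casting back, while gelu_quick is activation-only. A rough PyTorch reference of that contract (a validation sketch, not the kernel implementation):

import torch
import torch.nn.functional as F

def act_and_mul_ref(x2d: torch.Tensor, act) -> torch.Tensor:
    # x2d has shape (..., 2 * d): activate the first half, multiply by the second.
    d = x2d.shape[-1] // 2
    x, y = x2d[..., :d], x2d[..., d:]
    return (act(x.float()) * y.float()).to(x2d.dtype)

ACTIVATIONS = {
    "silu_and_mul": lambda t: t * torch.sigmoid(t),
    "gelu_and_mul": lambda t: F.gelu(t),                           # erf-based GELU
    "gelu_tanh_and_mul": lambda t: F.gelu(t, approximate="tanh"),  # tanh approximation
}

def gelu_quick_ref(x: torch.Tensor) -> torch.Tensor:
    # Activation-only: no halving of the last dimension.
    return (x.float() * torch.sigmoid(1.702 * x.float())).to(x.dtype)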
......@@ -19,6 +19,20 @@ limitations under the License.
#include "sgl_kernel_ops.h"
TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/*
* From csrc/activation
*/
m.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
m.impl("gelu_quick", torch::kCUDA, &gelu_quick);
/*
* From csrc/allreduce
*/
......
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include "utils.h"
#define kBitsToLoad 128
#define kBytesToLoad (kBitsToLoad / 8)
// Adapted from
// [flashinfer::activation::act_and_mul_kernel](https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/include/flashinfer/activation.cuh#L29)
namespace sgl_hip {
namespace activation {
template <typename T, T (*Activation)(const T&)>
__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
const int64_t token_idx = blockIdx.x;
const int64_t thread_idx = threadIdx.x;
const int64_t stride = blockDim.x;
const int64_t offset = token_idx * 2 * d;
#pragma unroll 1
for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
sgl_hip::vec_t<T, vec_size> x_vec, y_vec, out_vec;
x_vec.cast_load(input + offset + idx * vec_size);
y_vec.cast_load(input + offset + d + idx * vec_size);
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
out_vec[i] = Activation(x_vec[i]) * y_vec[i];
}
out_vec.cast_store(out + token_idx * d + idx * vec_size);
}
const int64_t remaining_offset = d - d % (stride * vec_size);
// process the remaining elements
#pragma unroll 1
for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) {
T x = input[offset + remaining_offset + idx], y = input[offset + remaining_offset + d + idx];
out[token_idx * d + remaining_offset + idx] = Activation(x) * y;
}
}
template <typename T, T (*Activation)(const T&)>
__global__ void act_only_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
const int64_t token_idx = blockIdx.x;
const int64_t thread_idx = threadIdx.x;
const int64_t stride = blockDim.x;
const int64_t offset = token_idx * d;
#pragma unroll 1
for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
sgl_hip::vec_t<T, vec_size> x_vec, y_vec, out_vec;
x_vec.cast_load(input + offset + idx * vec_size);
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
out_vec[i] = Activation(x_vec[i]);
}
out_vec.cast_store(out + token_idx * d + idx * vec_size);
}
const int64_t remaining_offset = d - d % (stride * vec_size);
// process the remaining elements
#pragma unroll 1
for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) {
T x = input[offset + remaining_offset + idx];
out[token_idx * d + remaining_offset + idx] = Activation(x);
}
}
} // namespace activation
} // namespace sgl_hip
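For orientation, the host launchers in activation.cu pair these kernels with a one-block-per-token grid and min(d / vec_size, 1024) threads per block, where vec_size = 16 / sizeof(T) elements fill one 128-bit load. A small Python sketch of that shape math (illustrative only, not the build's actual dispatch code):

# Mirrors dim3 grid(num_tokens) / dim3 block(std::min(d / vec_size, 1024U)) above.
def launch_shape(num_tokens: int, d: int, itemsize: int):
    vec_size = 16 // itemsize           # elements per 128-bit vector load
    block = min(d // vec_size, 1024)    # threads per block
    grid = num_tokens                   # one block per token
    return grid, block, vec_size

# fp16 (2-byte) example: d=4096 -> vec_size=8, block=512, grid=num_tokens
print(launch_shape(num_tokens=7, d=4096, itemsize=2))  # (7, 512, 8)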
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_bf16.h>
#include <hip/hip_common.h>
#include <hip/hip_fp16.h>
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
namespace amdgpu {
template <typename T>
__forceinline__ __device__ T shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize);
template <typename srcDtype, typename destDtype>
__forceinline__ __device__ destDtype cast(srcDtype val);
// specialization
template <>
__forceinline__ __device__ float shfl_xor_sync(unsigned mask, float var, int laneMask, int width) {
return __shfl_xor(var, laneMask, width);
}
template <>
__forceinline__ __device__ int shfl_xor_sync(unsigned mask, int var, int laneMask, int width) {
return __shfl_xor(var, laneMask, width);
}
template <>
__forceinline__ __device__ float cast(float val) {
return val;
}
template <>
__forceinline__ __device__ float cast(__half val) {
return __half2float(val);
}
template <>
__forceinline__ __device__ float cast(__hip_bfloat16 val) {
return __bfloat162float(val);
}
template <>
__forceinline__ __device__ __half cast(float fval) {
return __float2half(fval);
}
template <>
__forceinline__ __device__ __hip_bfloat16 cast(float fval) {
return __float2bfloat16(fval);
}
} // namespace amdgpu
template <typename T>
__forceinline__ __device__ T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize) {
return amdgpu::shfl_xor_sync(mask, var, laneMask, width);
}
template <typename srcDtype>
__device__ __forceinline__ float castToFloat(srcDtype val) {
return amdgpu::cast<srcDtype, float>(val);
}
template <typename dstDtype>
__device__ __forceinline__ dstDtype castFromFloat(float val) {
return amdgpu::cast<float, dstDtype>(val);
}
// operator overload to support flashinfer
__host__ __device__ __forceinline__ __half operator*(const __half& x, const __half& y) {
__half h_x = x;
__half h_y = y;
return __hmul(h_x, h_y);
}
#endif
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#if USE_ROCM
#include <hip/hip_bf16.h>
#include <hip/hip_common.h>
#include <hip/hip_fp16.h>
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
#define SGL_HIP_INLINE inline __attribute__((always_inline)) __device__
namespace sgl_hip {
template <typename float_t, size_t vec_size>
struct vec_t;
template <typename srcDtype, typename dstDtype, size_t vec_size>
SGL_HIP_INLINE void cast_load_impl(vec_t<dstDtype, vec_size>& dst, const srcDtype* src);
template <typename srcDtype, typename dstDtype, size_t vec_size>
SGL_HIP_INLINE void cast_store_impl(dstDtype* dst_ptr, const vec_t<srcDtype, vec_size>& src);
template <typename float_t, size_t vec_size>
struct vec_t {
SGL_HIP_INLINE float_t& operator[](size_t i);
SGL_HIP_INLINE const float_t& operator[](size_t i) const;
SGL_HIP_INLINE float_t* ptr();
SGL_HIP_INLINE void load(const float_t* ptr);
SGL_HIP_INLINE void store(float_t* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src);
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr);
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const;
};
} // namespace sgl_hip
// **** impl *****
namespace sgl_hip {
template <typename srcDtype, typename dstDtype, size_t vec_size>
SGL_HIP_INLINE void cast_load_impl(vec_t<dstDtype, vec_size>& dst, const srcDtype* src_ptr) {
if constexpr (std::is_same<srcDtype, dstDtype>::value) {
dst.load(src_ptr);
} else {
vec_t<srcDtype, vec_size> tmp;
tmp.load(src_ptr);
dst.cast_from(tmp);
}
}
template <typename srcDtype, typename dstDtype, size_t vec_size>
SGL_HIP_INLINE void cast_store_impl(dstDtype* dst_ptr, const vec_t<srcDtype, vec_size>& src) {
if constexpr (std::is_same<srcDtype, dstDtype>::value) {
src.store(dst_ptr);
} else {
vec_t<dstDtype, vec_size> tmp;
tmp.cast_from(src);
tmp.store(dst_ptr);
}
}
template <typename float_t, size_t vec_size>
template <typename T>
SGL_HIP_INLINE void vec_t<float_t, vec_size>::cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename float_t, size_t vec_size>
template <typename T>
SGL_HIP_INLINE void vec_t<float_t, vec_size>::cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
} // namespace sgl_hip
#include "impl/hip_vec_bf16_impl.h"
#include "impl/hip_vec_fp32_impl.h"
#include "impl/hip_vec_half_impl.h"
#endif
#pragma once
#if USE_ROCM
#include <hip/hip_bf16.h>
#include <hip/hip_common.h>
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
using nv_bfloat16 = __hip_bfloat16;
using nv_bfloat162 = __hip_bfloat162;
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 make_bfloat162(const __hip_bfloat16 x, const __hip_bfloat16 y) {
__hip_bfloat162 t;
t.x = x;
t.y = y;
return t;
}
namespace sgl_hip {
// nv_bfloat16 x 1
template <>
struct vec_t<nv_bfloat16, 1> {
nv_bfloat16 data;
SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
return ((nv_bfloat16*)(&data))[i];
}
SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
return ((const nv_bfloat16*)(&data))[i];
}
SGL_HIP_INLINE nv_bfloat16* ptr() {
return reinterpret_cast<nv_bfloat16*>(&data);
}
SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16* ptr) {
data = *ptr;
}
SGL_HIP_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16* ptr) const {
*ptr = data;
}
// nv_bfloat16 x 2
template <>
struct vec_t<nv_bfloat16, 2> {
nv_bfloat162 data;
SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
return ((nv_bfloat16*)(&data))[i];
}
SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
return ((const nv_bfloat16*)(&data))[i];
}
SGL_HIP_INLINE nv_bfloat16* ptr() {
return reinterpret_cast<nv_bfloat16*>(&data);
}
SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16* ptr) {
data = *((nv_bfloat162*)ptr);
}
SGL_HIP_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16* ptr) const {
*((nv_bfloat162*)ptr) = data;
}
template <>
struct vec_t<nv_bfloat16, 4> {
uint2 data;
SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
return ((nv_bfloat16*)(&data))[i];
}
SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
return ((const nv_bfloat16*)(&data))[i];
}
SGL_HIP_INLINE nv_bfloat16* ptr() {
return reinterpret_cast<nv_bfloat16*>(&data);
}
SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 4>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16* ptr) {
data = *((uint2*)ptr);
}
SGL_HIP_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16* ptr) const {
*((uint2*)ptr) = data;
}
// nv_bfloat16 x 8 or more
template <size_t vec_size>
struct vec_t<nv_bfloat16, vec_size> {
uint4 data[vec_size / 8];
SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
return ((nv_bfloat16*)data)[i];
}
SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
return ((const nv_bfloat16*)data)[i];
}
SGL_HIP_INLINE nv_bfloat16* ptr() {
return reinterpret_cast<nv_bfloat16*>(&data);
}
SGL_HIP_INLINE void load(const nv_bfloat16* ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
data[i] = ((uint4*)ptr)[i];
}
}
SGL_HIP_INLINE void store(nv_bfloat16* ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4*)ptr)[i] = data[i];
}
}
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
} // namespace sgl_hip
#endif
#pragma once
#if USE_ROCM
#include <hip/hip_common.h>
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
namespace sgl_hip {
template <>
struct vec_t<float, 1> {
float data;
SGL_HIP_INLINE float& operator[](size_t i) {
return ((float*)(&data))[i];
}
SGL_HIP_INLINE const float& operator[](size_t i) const {
return ((const float*)(&data))[i];
}
SGL_HIP_INLINE float* ptr() {
return reinterpret_cast<float*>(&data);
}
SGL_HIP_INLINE void load(const float* ptr);
SGL_HIP_INLINE void store(float* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<float, 1>::load(const float* ptr) {
data = *ptr;
}
SGL_HIP_INLINE void vec_t<float, 1>::store(float* ptr) const {
*ptr = data;
}
// float x 2
template <>
struct vec_t<float, 2> {
float2 data;
SGL_HIP_INLINE float& operator[](size_t i) {
return ((float*)(&data))[i];
}
SGL_HIP_INLINE const float& operator[](size_t i) const {
return ((const float*)(&data))[i];
}
SGL_HIP_INLINE float* ptr() {
return reinterpret_cast<float*>(&data);
}
SGL_HIP_INLINE void load(const float* ptr);
SGL_HIP_INLINE void store(float* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<float, 2>::load(const float* ptr) {
data = *((float2*)ptr);
}
SGL_HIP_INLINE void vec_t<float, 2>::store(float* ptr) const {
*((float2*)ptr) = data;
}
// float x 4 or more
template <size_t vec_size>
struct vec_t<float, vec_size> {
float4 data[vec_size / 4];
SGL_HIP_INLINE float& operator[](size_t i) {
return ((float*)(data))[i];
}
SGL_HIP_INLINE const float& operator[](size_t i) const {
return ((const float*)(data))[i];
}
SGL_HIP_INLINE float* ptr() {
return reinterpret_cast<float*>(&data);
}
SGL_HIP_INLINE void load(const float* ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
data[i] = ((float4*)ptr)[i];
}
}
SGL_HIP_INLINE void store(float* ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4*)ptr)[i] = data[i];
}
}
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
} // namespace sgl_hip
#endif
#pragma once
#if USE_ROCM
#include <hip/hip_common.h>
#include <hip/hip_fp16.h>
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
using half = __half;
using half2 = __half2;
namespace sgl_hip {
// half x 1
template <>
struct vec_t<half, 1> {
half data;
SGL_HIP_INLINE half& operator[](size_t i) {
return ((half*)(&data))[i];
}
SGL_HIP_INLINE const half& operator[](size_t i) const {
return ((const half*)(&data))[i];
}
SGL_HIP_INLINE half* ptr() {
return reinterpret_cast<half*>(&data);
}
SGL_HIP_INLINE void load(const half* ptr);
SGL_HIP_INLINE void store(half* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<half, 1>::load(const half* ptr) {
data = *ptr;
}
SGL_HIP_INLINE void vec_t<half, 1>::store(half* ptr) const {
*ptr = data;
}
// half x 2
template <>
struct vec_t<half, 2> {
half2 data;
SGL_HIP_INLINE half& operator[](size_t i) {
return ((half*)(&data))[i];
}
SGL_HIP_INLINE const half& operator[](size_t i) const {
return ((const half*)(&data))[i];
}
SGL_HIP_INLINE half* ptr() {
return reinterpret_cast<half*>(&data);
}
SGL_HIP_INLINE void load(const half* ptr);
SGL_HIP_INLINE void store(half* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<half, 2>::load(const half* ptr) {
data = *((half2*)ptr);
}
SGL_HIP_INLINE void vec_t<half, 2>::store(half* ptr) const {
*((half2*)ptr) = data;
}
// half x 4
template <>
struct vec_t<half, 4> {
uint2 data;
SGL_HIP_INLINE half& operator[](size_t i) {
return ((half*)(&data))[i];
}
SGL_HIP_INLINE const half& operator[](size_t i) const {
return ((const half*)(&data))[i];
}
SGL_HIP_INLINE half* ptr() {
return reinterpret_cast<half*>(&data);
}
SGL_HIP_INLINE void load(const half* ptr);
SGL_HIP_INLINE void store(half* ptr) const;
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, 4>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
SGL_HIP_INLINE void vec_t<half, 4>::load(const half* ptr) {
data = *((uint2*)ptr);
}
SGL_HIP_INLINE void vec_t<half, 4>::store(half* ptr) const {
*((uint2*)ptr) = data;
}
// half x 8 or more
template <size_t vec_size>
struct vec_t<half, vec_size> {
uint4 data[vec_size / 8];
SGL_HIP_INLINE half& operator[](size_t i) {
return ((half*)data)[i];
}
SGL_HIP_INLINE const half& operator[](size_t i) const {
return ((const half*)data)[i];
}
SGL_HIP_INLINE half* ptr() {
return reinterpret_cast<half*>(&data);
}
SGL_HIP_INLINE void load(const half* ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
data[i] = ((uint4*)ptr)[i];
}
}
SGL_HIP_INLINE void store(half* ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4*)ptr)[i] = data[i];
}
}
template <typename T>
SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
cast_from_impl(*this, src);
}
template <typename T>
SGL_HIP_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T>
SGL_HIP_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
};
} // namespace sgl_hip
#endif
......@@ -138,9 +138,10 @@ void sgl_fused_add_rmsnorm(
torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl);
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl);
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void silu_and_mul(at::Tensor& out, at::Tensor& input);
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input);
void gelu_and_mul(at::Tensor& out, at::Tensor& input);
void apply_rope_pos_ids_cos_sin_cache(
at::Tensor q,
at::Tensor k,
......@@ -151,6 +152,9 @@ void apply_rope_pos_ids_cos_sin_cache(
bool interleave,
int64_t cuda_stream);
#ifdef USE_ROCM
void gelu_quick(at::Tensor& out, const at::Tensor& input);
#endif
/*
* From csrc/gemm
*/
......
......@@ -19,7 +19,20 @@ limitations under the License.
#include <cuda_runtime.h>
#include <torch/all.h>
#include <sstream>
#ifdef USE_ROCM
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
#define _DISPATCH_CASE_F16(c_type, ...) \
case at::ScalarType::Half: { \
using c_type = __half; \
return __VA_ARGS__(); \
}
#define _DISPATCH_CASE_BF16(c_type, ...) \
case at::ScalarType::BFloat16: { \
using c_type = __hip_bfloat16; \
return __VA_ARGS__(); \
}
#endif // USE_ROCM
#ifndef USE_ROCM
// Adapted from FlashInfer
......@@ -31,7 +44,7 @@ limitations under the License.
}
#else
#define _DISPATCH_CASE_F16(c_type, ...)
#endif
#endif // FLASHINFER_ENABLE_F16
#ifdef FLASHINFER_ENABLE_BF16
#define _DISPATCH_CASE_BF16(c_type, ...) \
......@@ -41,7 +54,7 @@ limitations under the License.
}
#else
#define _DISPATCH_CASE_BF16(c_type, ...)
#endif
#endif // FLASHINFER_ENABLE_BF16
#ifdef FLASHINFER_ENABLE_FP8_E4M3
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \
......@@ -51,7 +64,7 @@ limitations under the License.
}
#else
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...)
#endif
#endif // FLASHINFER_ENABLE_FP8_E4M3
#ifdef FLASHINFER_ENABLE_FP8_E5M2
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \
......@@ -61,7 +74,7 @@ limitations under the License.
}
#else
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...)
#endif
#endif // FLASHINFER_ENABLE_FP8_E5M2
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
......@@ -197,7 +210,7 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
inline bool is_float8_tensor(const at::Tensor& tensor) {
return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2;
}
#endif
#endif // USE_ROCM
struct cuda_error : public std::runtime_error {
/**
......@@ -267,7 +280,6 @@ inline bool getEnvEnablePDL() {
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
#endif
#ifndef USE_ROCM
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
......@@ -284,7 +296,6 @@ inline bool getEnvEnablePDL() {
return false; \
} \
}()
#endif
#define DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
......@@ -297,52 +308,99 @@ inline bool getEnvEnablePDL() {
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize // 64
#endif
#if defined(__HIP_PLATFORM_AMD__)
#include "hip_math_def.h"
#include "hip_vec_dtypes.h"
#else
template <typename srcDtype>
__device__ __forceinline__ float castToFloat(srcDtype val) {
return static_cast<float>(val);
}
template <typename dstDtype>
__device__ __forceinline__ dstDtype castFromFloat(float val) {
return static_cast<dstDtype>(val);
}
#endif
// add FP8 support
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
#else
#include <c10/util/Float8_e4m3fnuz.h>
#else // USE_ROCM
#if HIP_FP8_TYPE_FNUZ
#include <c10/util/Float8_e4m3fnuz.h>
using FP8_TYPE = c10::Float8_e4m3fnuz;
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif
#else
#if HIP_FP8_TYPE_E4M3
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
#else
#error "fp8 is not supported in this processor (arch < gfx942)."
#endif // HIP_FP8_TYPE_E4M3
#endif // HIP_FP8_TYPE_FNUZ
#endif // USE_ROCM
#define FULL_MASK 0xffffffff
#ifndef USE_ROCM
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
#ifndef USE_ROCM
float old;
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
return old;
#else
int* addr_as_i = (int*)addr;
int old = *addr_as_i, assumed;
do {
assumed = old;
old = atomicCAS(addr_as_i, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed))));
} while (assumed != old);
return __int_as_float(old);
#endif
}
__device__ __forceinline__ float warpReduceMax(float max_value) {
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16));
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 8));
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 4));
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 2));
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 1));
return max_value;
__device__ __forceinline__ float warpReduceMax(float value) {
value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 16));
value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 8));
value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 4));
value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 2));
value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 1));
return value;
}
__device__ __forceinline__ float blockReduceMax(float max_value) {
__device__ __forceinline__ float blockReduceMax(float value) {
static __shared__ float warpLevelMaxs[WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
max_value = warpReduceMax(max_value);
value = warpReduceMax(value);
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
if (laneId == 0) warpLevelMaxs[warpId] = value;
__syncthreads();
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
if (warpId == 0) max_value = warpReduceMax(max_value);
value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
if (warpId == 0) value = warpReduceMax(value);
return max_value;
return value;
}
#endif
// Pads to a multiple of `alignment` rows.
inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) {
......
......@@ -31,6 +31,10 @@ from sgl_kernel.elementwise import (
silu_and_mul,
)
from sgl_kernel.fused_moe import fused_marlin_moe
if torch.version.hip is not None:
from sgl_kernel.elementwise import gelu_quick
from sgl_kernel.gemm import (
awq_dequantize,
bmm_fp8,
......
......@@ -179,7 +179,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream())
torch.ops.sgl_kernel.silu_and_mul.default(out, input)
return out
......@@ -194,7 +194,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream())
torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input)
return out
......@@ -209,10 +209,34 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream())
torch.ops.sgl_kernel.gelu_and_mul.default(out, input)
return out
if torch.version.hip is not None:
def gelu_quick(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
"""
Quick-GELU: y = x * sigmoid(1.702 * x)
The CUDA/HIP kernel uses 128-bit (16-byte) vector loads & stores,
so the byte length of the last dimension must be a multiple of 16.
"""
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError(
f"The last dimension ({input.shape[-1]}) x itemsize "
f"({input.dtype.itemsize}) must be a multiple of 16 bytes."
)
if out is not None:
assert input.shape == out.shape, f"{input.shape} != {out.shape}"
else:
out = torch.empty_like(input)
torch.ops.sgl_kernel.gelu_quick(out, input)
return out
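A brief usage sketch of the constraint above (shapes are illustrative; assumes a ROCm build of PyTorch, where HIP devices are addressed as "cuda"):

import torch
from sgl_kernel import gelu_quick  # exported only on ROCm builds

x_ok = torch.randn(4, 4096, dtype=torch.float16, device="cuda")  # 4096 * 2 B = 8192 B, multiple of 16
y = gelu_quick(x_ok)  # output tensor allocated internally

x_bad = torch.randn(4, 100, dtype=torch.float16, device="cuda")  # 100 * 2 B = 200 B, not a multiple of 16
# gelu_quick(x_bad)  # would raise ValueError per the check above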
def apply_rope_with_cos_sin_cache_inplace(
positions: torch.Tensor,
query: torch.Tensor,
......
......@@ -36,16 +36,18 @@ def _get_version():
operator_namespace = "sgl_kernel"
include_dirs = [
root / "include",
root / "include" / "impl",
root / "csrc",
]
sources = [
"csrc/allreduce/custom_all_reduce.hip",
"csrc/allreduce/quick_all_reduce.cu",
"csrc/elementwise/activation.cu",
"csrc/moe/moe_align_kernel.cu",
"csrc/moe/moe_topk_softmax_kernels.cu",
"csrc/torch_extension_rocm.cc",
"csrc/speculative/eagle_utils.cu",
"csrc/torch_extension_rocm.cc",
]
cxx_flags = ["-O3"]
......@@ -69,6 +71,7 @@ if amdgpu_target not in ["gfx942", "gfx950"]:
)
sys.exit(1)
hipcc_flags = [
"-DNDEBUG",
f"-DOPERATOR_NAMESPACE={operator_namespace}",
......