Unverified Commit 9286740e authored by yinfan98's avatar yinfan98 Committed by GitHub
Browse files

feat: refactor sgl-kernel and use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#3130)


Co-authored-by: default avataryinfan.1024 <yinfan.1024@bytedance.com>
Co-authored-by: default avataryinfan98 <1106110035@qq.com>
Co-authored-by: default avatarYineng Zhang <me@zhyncs.com>
parent 896c0744
...@@ -26,10 +26,11 @@ Third-party libraries: ...@@ -26,10 +26,11 @@ Third-party libraries:
Steps to add a new kernel: Steps to add a new kernel:
1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc) 1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc)
2. Expose interface in [csrc/sgl_kernel_ops.cu](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu) with pybind11 2. Expose interface in [src/sgl-kernel/include/sgl_kernel_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernel_ops.h)
3. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py) 3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc)
4. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py) 4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
5. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source 5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
### Build & Install ### Build & Install
...@@ -37,8 +38,6 @@ Development build: ...@@ -37,8 +38,6 @@ Development build:
```bash ```bash
make build make build
pip3 install dist/*whl --force-reinstall --no-deps
# Or use: make install (runs pip install -e .)
``` ```
### Testing & Benchmarking ### Testing & Benchmarking
......
...@@ -38,6 +38,7 @@ def _get_version(): ...@@ -38,6 +38,7 @@ def _get_version():
return line.split("=")[1].strip().strip('"') return line.split("=")[1].strip().strip('"')
operator_namespace = "sgl_kernels"
cutlass_default = root / "3rdparty" / "cutlass" cutlass_default = root / "3rdparty" / "cutlass"
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
flashinfer = root / "3rdparty" / "flashinfer" flashinfer = root / "3rdparty" / "flashinfer"
...@@ -45,15 +46,19 @@ turbomind = root / "3rdparty" / "turbomind" ...@@ -45,15 +46,19 @@ turbomind = root / "3rdparty" / "turbomind"
include_dirs = [ include_dirs = [
cutlass.resolve() / "include", cutlass.resolve() / "include",
cutlass.resolve() / "tools" / "util" / "include", cutlass.resolve() / "tools" / "util" / "include",
root / "src" / "sgl-kernel" / "include",
root / "src" / "sgl-kernel" / "csrc", root / "src" / "sgl-kernel" / "csrc",
flashinfer.resolve() / "include", flashinfer.resolve() / "include",
flashinfer.resolve() / "include" / "gemm", flashinfer.resolve() / "include" / "gemm",
flashinfer.resolve() / "csrc", flashinfer.resolve() / "csrc",
"cublas",
"cublasLt",
turbomind.resolve(), turbomind.resolve(),
turbomind.resolve() / "src", turbomind.resolve() / "src",
] ]
nvcc_flags = [ nvcc_flags = [
"-DNDEBUG", "-DNDEBUG",
f"-DOPERATOR_NAMESPACE={operator_namespace}",
"-O3", "-O3",
"-Xcompiler", "-Xcompiler",
"-fPIC", "-fPIC",
...@@ -72,13 +77,13 @@ nvcc_flags_fp8 = [ ...@@ -72,13 +77,13 @@ nvcc_flags_fp8 = [
] ]
sources = [ sources = [
"src/sgl-kernel/torch_extension.cc",
"src/sgl-kernel/csrc/trt_reduce_internal.cu", "src/sgl-kernel/csrc/trt_reduce_internal.cu",
"src/sgl-kernel/csrc/trt_reduce_kernel.cu", "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/moe_align_kernel.cu", "src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/int8_gemm_kernel.cu", "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu", "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu", "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu", "src/sgl-kernel/csrc/rotary_embedding.cu",
"src/sgl-kernel/csrc/fused_add_rms_norm.cu", "src/sgl-kernel/csrc/fused_add_rms_norm.cu",
"3rdparty/flashinfer/csrc/activation.cu", "3rdparty/flashinfer/csrc/activation.cu",
...@@ -125,7 +130,7 @@ for flag in [ ...@@ -125,7 +130,7 @@ for flag in [
pass pass
cxx_flags = ["-O3"] cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python", "cuda"] libraries = ["c10", "torch", "torch_python", "cuda", "cublas", "cublasLt"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"] extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
ext_modules = [ ext_modules = [
...@@ -139,6 +144,7 @@ ext_modules = [ ...@@ -139,6 +144,7 @@ ext_modules = [
}, },
libraries=libraries, libraries=libraries,
extra_link_args=extra_link_args, extra_link_args=extra_link_args,
py_limited_api=True,
), ),
] ]
...@@ -149,6 +155,7 @@ setup( ...@@ -149,6 +155,7 @@ setup(
package_dir={"": "src"}, package_dir={"": "src"},
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension}, cmdclass={"build_ext": BuildExtension},
options={"bdist_wheel": {"py_limited_api": "cp39"}},
) )
_update_wheel_platform_tag() _update_wheel_platform_tag()
#pragma once
#include <Python.h>
#include <torch/extension.h>
#include <vector> #include <vector>
#include "utils.h" #include "utils.h"
// Two-level token-pasting helpers: CONCAT(A, B) expands its arguments first,
// so CONCAT(PyInit_, NAME) works even when NAME is itself a macro argument.
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
// Two-level stringification helpers: STRINGIFY(A) expands A before quoting it.
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
// Indirection layer over TORCH_LIBRARY so the namespace argument is macro-expanded
// before registration (the build passes -DOPERATOR_NAMESPACE=... via nvcc_flags).
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// Emits the minimal PyInit_<NAME> entry point CPython needs to import the
// compiled extension as a module. The module body is intentionally empty:
// the operators themselves are registered through TORCH_LIBRARY and reached
// via torch.ops, not via module-level attributes.
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
return PyModule_Create(&module); \
}
// trt_reduce // trt_reduce
using fptr_t = int64_t; using fptr_t = int64_t;
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers, fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
...@@ -67,9 +85,18 @@ void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at: ...@@ -67,9 +85,18 @@ void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at:
int64_t cuda_stream); int64_t cuda_stream);
// top k renorm probs // top k renorm probs
// Patched here: flashinfer takes `unsigned int`, but torch extension schemas must use int64_t.
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr, void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr,
unsigned int top_k_val, int64_t cuda_stream); unsigned int top_k_val, int64_t cuda_stream);
// Binding shim: flashinfer's top_k_renorm_probs takes `unsigned int` for
// top_k_val, but TORCH_LIBRARY schemas only model 64-bit integers (`int` ->
// int64_t). This wrapper accepts the schema-compatible int64_t and narrows it
// before forwarding to the real kernel entry point.
inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs,
                                       std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
                                       int64_t cuda_stream) {
  const auto top_k_narrowed = static_cast<unsigned int>(top_k_val);
  top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, top_k_narrowed, cuda_stream);
}
// top p renorm probs // top p renorm probs
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr, void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val, int64_t cuda_stream); double top_p_val, int64_t cuda_stream);
...@@ -84,48 +111,3 @@ void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_sample ...@@ -84,48 +111,3 @@ void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_sample
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success, void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic, std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
int64_t cuda_stream); int64_t cuda_stream);
// Legacy pybind11 bindings for the sgl-kernel custom CUDA ops.
// Each m.def exposes one kernel entry point (declared above) to Python under
// TORCH_EXTENSION_NAME; the third argument becomes the Python docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // trt_reduce: custom all-reduce state management and execution
  m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
  m.def("dispose", &dispose, "dispose custom allreduce meta");
  m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "custom all reduce get graph ipc meta");
  m.def("register_graph_buffers", &register_graph_buffers, "custom all reduce register graph buffers");
  // moe_align_block_size
  m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
  // sampling_scaling_penalties
  m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
  // int8_scaled_mm
  m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
  // lightning_attention_decode
  // NOTE: docstring typo fixed ("Ddecode" -> "Decode").
  m.def("lightning_attention_decode", &lightning_attention_decode, "Lightning Attention Decode (CUDA)");
  // rotary embedding
  m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
  // rms norm
  m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)");
  // fused rms norm
  m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused Add RMSNorm (CUDA)");
  // gemma rms norm
  m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)");
  // fused gemma rms norm
  m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)");
  // activation kernels: silu / gelu-tanh / gelu, each fused with a multiply
  m.def("silu_and_mul", &silu_and_mul, "Silu and Mul (CUDA)");
  m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)");
  m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)");
  // bmm fp8
  m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)");
  // sampling / renormalization kernels (adapted from flashinfer)
  m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, "Min P Sampling From Probs (CUDA)");
  m.def("top_k_renorm_probs", &top_k_renorm_probs, "Top K Renorm Probs (CUDA)");
  m.def("top_p_renorm_probs", &top_p_renorm_probs, "Top P Renorm Probs (CUDA)");
  m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, "Top K Top P Sampling From Probs (CUDA)");
  m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, "Top P Sampling From Probs (CUDA)");
}
#pragma once #pragma once
#include <cuda_runtime.h>
#include <pytorch_extension_utils.h> #include <pytorch_extension_utils.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <sstream> #include <sstream>
#include "sgl_kernels_ops.h"
struct cuda_error : public std::runtime_error { struct cuda_error : public std::runtime_error {
/** /**
* @brief Constructs a `cuda_error` object with the given `message`. * @brief Constructs a `cuda_error` object with the given `message`.
......
import os
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import sgl_kernel.ops._kernels
import torch import torch
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
from sgl_kernel.ops._kernels import bmm_fp8 as _bmm_fp8
from sgl_kernel.ops._kernels import dispose as _dispose
from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul
from sgl_kernel.ops._kernels import gelu_tanh_and_mul as _gelu_tanh_and_mul
from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm
from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm
from sgl_kernel.ops._kernels import (
get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta,
)
from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
from sgl_kernel.ops._kernels import (
lightning_attention_decode as _lightning_attention_decode,
)
from sgl_kernel.ops._kernels import (
min_p_sampling_from_probs as _min_p_sampling_from_probs,
)
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm
from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding
from sgl_kernel.ops._kernels import (
sampling_scaling_penalties as _sampling_scaling_penalties,
)
from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul
from sgl_kernel.ops._kernels import top_k_renorm_probs as _top_k_renorm_probs
from sgl_kernel.ops._kernels import (
top_k_top_p_sampling_from_probs as _top_k_top_p_sampling_from_probs,
)
from sgl_kernel.ops._kernels import top_p_renorm_probs as _top_p_renorm_probs
from sgl_kernel.ops._kernels import (
top_p_sampling_from_probs as _top_p_sampling_from_probs,
)
from sgl_kernel.ops.utils import ( from sgl_kernel.ops.utils import (
_get_cache_buf, _get_cache_buf,
_get_cuda_stream, _get_cuda_stream,
...@@ -46,25 +13,25 @@ from sgl_kernel.ops.utils import ( ...@@ -46,25 +13,25 @@ from sgl_kernel.ops.utils import (
def init_custom_reduce( def init_custom_reduce(
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
): ):
return _init_custom_ar( return torch.ops.sgl_kernels.init_custom_ar(
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
) )
def custom_dispose(fa): def custom_dispose(fa):
_dispose(fa) torch.ops.sgl_kernels.dispose(fa)
def custom_reduce(fa, inp, out): def custom_reduce(fa, inp, out):
_all_reduce(fa, inp, out) torch.ops.sgl_kernels.all_reduce(fa, inp, out)
def get_graph_buffer_ipc_meta(fa): def get_graph_buffer_ipc_meta(fa):
return _get_graph_buffer_ipc_meta(fa) return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(fa, handles, offsets): def register_graph_buffers(fa, handles, offsets):
_register_graph_buffers(fa, handles, offsets) torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
def moe_align_block_size( def moe_align_block_size(
...@@ -77,7 +44,7 @@ def moe_align_block_size( ...@@ -77,7 +44,7 @@ def moe_align_block_size(
token_cnts_buffer, token_cnts_buffer,
cumsum_buffer, cumsum_buffer,
): ):
_moe_align_block_size( torch.ops.sgl_kernels.moe_align_block_size(
topk_ids, topk_ids,
num_experts, num_experts,
block_size, block_size,
...@@ -90,11 +57,11 @@ def moe_align_block_size( ...@@ -90,11 +57,11 @@ def moe_align_block_size(
def sampling_scaling_penalties(logits, scaling_penalties): def sampling_scaling_penalties(logits, scaling_penalties):
return _sampling_scaling_penalties(logits, scaling_penalties) return torch.ops.sgl_kernels.sampling_scaling_penalties(logits, scaling_penalties)
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
return _int8_scaled_mm( return torch.ops.sgl_kernels.int8_scaled_mm(
mat_a, mat_a,
mat_b, mat_b,
scales_a, scales_a,
...@@ -105,11 +72,15 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): ...@@ -105,11 +72,15 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv): def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
_lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv) torch.ops.sgl_kernels.lightning_attention_decode(
q, k, v, past_kv, slope, output, new_kv
)
def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox): def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox) return torch.ops.sgl_kernels.rotary_embedding(
positions, query, key, head_size, cos_sin_cache, is_neox
)
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
...@@ -123,7 +94,7 @@ def rmsnorm( ...@@ -123,7 +94,7 @@ def rmsnorm(
with input.device as device: with input.device as device:
if out is None: if out is None:
out = torch.empty_like(input) out = torch.empty_like(input)
_rmsnorm(out, input, weight, eps, _get_cuda_stream(device)) torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
return out return out
...@@ -131,7 +102,9 @@ def fused_add_rmsnorm( ...@@ -131,7 +102,9 @@ def fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None: ) -> None:
with input.device as device: with input.device as device:
_fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device)) torch.ops.sgl_kernels.fused_add_rmsnorm(
input, residual, weight, eps, _get_cuda_stream(device)
)
def gemma_rmsnorm( def gemma_rmsnorm(
...@@ -143,7 +116,9 @@ def gemma_rmsnorm( ...@@ -143,7 +116,9 @@ def gemma_rmsnorm(
with input.device as device: with input.device as device:
if out is None: if out is None:
out = torch.empty_like(input) out = torch.empty_like(input)
_gemma_rmsnorm(out, input, weight, eps, _get_cuda_stream(device)) torch.ops.sgl_kernels.gemma_rmsnorm(
out, input, weight, eps, _get_cuda_stream(device)
)
return out return out
...@@ -151,7 +126,9 @@ def gemma_fused_add_rmsnorm( ...@@ -151,7 +126,9 @@ def gemma_fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None: ) -> None:
with input.device as device: with input.device as device:
_gemma_fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device)) torch.ops.sgl_kernels.gemma_fused_add_rmsnorm(
input, residual, weight, eps, _get_cuda_stream(device)
)
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
...@@ -176,7 +153,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: ...@@ -176,7 +153,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
dtype=input.dtype, dtype=input.dtype,
) )
with input.device as device: with input.device as device:
_silu_and_mul(out, input, _get_cuda_stream(device)) torch.ops.sgl_kernels.silu_and_mul(out, input, _get_cuda_stream(device))
return out return out
...@@ -192,7 +169,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te ...@@ -192,7 +169,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
dtype=input.dtype, dtype=input.dtype,
) )
with input.device as device: with input.device as device:
_gelu_tanh_and_mul(out, input, _get_cuda_stream(device)) torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, _get_cuda_stream(device))
return out return out
...@@ -208,7 +185,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: ...@@ -208,7 +185,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
dtype=input.dtype, dtype=input.dtype,
) )
with input.device as device: with input.device as device:
_gelu_and_mul(out, input, _get_cuda_stream(device)) torch.ops.sgl_kernels.gelu_and_mul(out, input, _get_cuda_stream(device))
return out return out
...@@ -222,7 +199,7 @@ def _bmm_fp8_internal( ...@@ -222,7 +199,7 @@ def _bmm_fp8_internal(
) -> None: ) -> None:
with A.device as device: with A.device as device:
cublas_handle = torch.cuda.current_blas_handle() cublas_handle = torch.cuda.current_blas_handle()
_bmm_fp8( torch.ops.sgl_kernels.bmm_fp8(
A, A,
B, B,
D, D,
...@@ -262,7 +239,7 @@ def _top_k_renorm_probs_internal( ...@@ -262,7 +239,7 @@ def _top_k_renorm_probs_internal(
probs = probs.float() probs = probs.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
renorm_probs = torch.empty_like(probs) renorm_probs = torch.empty_like(probs)
_top_k_renorm_probs( torch.ops.sgl_kernels.top_k_renorm_probs_wrapper(
probs, probs,
renorm_probs, renorm_probs,
maybe_top_k_arr, maybe_top_k_arr,
...@@ -293,7 +270,7 @@ def _top_p_renorm_probs_internal( ...@@ -293,7 +270,7 @@ def _top_p_renorm_probs_internal(
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
) )
renorm_probs = torch.empty_like(probs) renorm_probs = torch.empty_like(probs)
_top_p_renorm_probs( torch.ops.sgl_kernels.top_p_renorm_probs(
probs, probs,
renorm_probs, renorm_probs,
maybe_top_p_arr, maybe_top_p_arr,
...@@ -328,7 +305,7 @@ def _top_p_sampling_from_probs_internal( ...@@ -328,7 +305,7 @@ def _top_p_sampling_from_probs_internal(
) )
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
success = torch.empty(probs.size(0), dtype=torch.bool, device=device) success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
_top_p_sampling_from_probs( torch.ops.sgl_kernels.top_p_sampling_from_probs(
probs, probs,
uniform_samples, uniform_samples,
samples, samples,
...@@ -374,7 +351,7 @@ def _top_k_top_p_sampling_from_probs_internal( ...@@ -374,7 +351,7 @@ def _top_k_top_p_sampling_from_probs_internal(
) )
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
success = torch.empty(probs.size(0), dtype=torch.bool, device=device) success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
_top_k_top_p_sampling_from_probs( torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs(
probs, probs,
uniform_samples, uniform_samples,
samples, samples,
...@@ -432,7 +409,7 @@ def _min_p_sampling_from_probs_internal( ...@@ -432,7 +409,7 @@ def _min_p_sampling_from_probs_internal(
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
) )
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
_min_p_sampling_from_probs( torch.ops.sgl_kernels.min_p_sampling_from_probs(
probs, probs,
uniform_samples, uniform_samples,
samples, samples,
......
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include "sgl_kernels_ops.h"
// Operator registration for the `sgl_kernels` TORCH_LIBRARY namespace.
// Each op is declared with an explicit schema string via m.def() and bound to
// its CUDA implementation via m.impl(..., torch::kCUDA, ...); Python reaches
// them as torch.ops.sgl_kernels.<name>.
// Schema conventions (PyTorch operator schema syntax):
//   - `Tensor!` marks an argument the kernel mutates in place.
//   - `Tensor?` marks an optional tensor (std::optional<at::Tensor> in C++).
//   - `int` maps to int64_t, `float` to double.
// Opaque handles (the custom-allreduce state `fa`, cublas handles, CUDA
// streams) are passed through the schema as plain `int` (int64_t) values.
TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
// trt_reduce
m.def(
"init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
"barrier_in, int[] barrier_out) -> int");
m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
// Schema for dispose is inferred from the C++ signature (no string given).
m.def("dispose", &dispose);
m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
m.impl("all_reduce", torch::kCUDA, &all_reduce);
m.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])");
m.impl("get_graph_buffer_ipc_meta", torch::kCUDA, &get_graph_buffer_ipc_meta);
m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
m.impl("register_graph_buffers", torch::kCUDA, &register_graph_buffers);
// moe_align_block_size
m.def(
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
// sampling_scaling_penalties
m.def("sampling_scaling_penalties(Tensor logits, Tensor scaling_penalties) -> Tensor");
m.impl("sampling_scaling_penalties", torch::kCUDA, &sampling_scaling_penalties);
// int8_scaled_mm
m.def(
"int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
"bias) -> Tensor");
m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
// lightning_attention_decode
m.def(
"lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
"new_kv) -> ()");
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
// rotary embedding
m.def(
"rotary_embedding(Tensor positions, Tensor! query, Tensor! key, int head_size, Tensor cos_sin_cache, bool "
"is_neox) -> ()");
m.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// rms norm
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
// fused rms norm
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
m.impl("fused_add_rmsnorm", torch::kCUDA, &fused_add_rmsnorm);
// gemma rms norm
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
// fused gemma rms norm
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
// silu and mul
m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
// gelu tanh and mul
m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
// gelu and mul
m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
// bmm fp8
m.def(
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
"cublas_handle, int cuda_stream) -> ()");
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
// min p sampling from probs
m.def(
"min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
"min_p_val, bool deterministic, int cuda_stream) -> ()");
m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
// top k renorm probs
// Bound through the int64_t wrapper because the underlying flashinfer kernel
// takes `unsigned int`, which the schema type system cannot express.
m.def(
"top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
"cuda_stream) -> ()");
m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper);
// top p renorm probs
m.def(
"top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
"cuda_stream) -> ()");
m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
// top k top p sampling from probs
m.def(
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
"maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
"cuda_stream) -> ()");
m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
// top p sampling from probs
m.def(
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
}
// Emits PyInit__kernels so the shared object can be imported from Python
// (the ops wrapper does `import sgl_kernel.ops._kernels` to trigger
// TORCH_LIBRARY registration).
REGISTER_EXTENSION(_kernels)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment