Unverified Commit 8abf74e3 authored by Lianmin Zheng, committed by GitHub

Rename files in sgl kernel to avoid nested folder structure (#4213)


Co-authored-by: zhyncs <me@zhyncs.com>
parent ee132a45
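In short, this PR flattens the Python package (sgl_kernel.ops.* becomes sgl_kernel.*), renames the native op namespace from sgl_kernels to sgl_kernel, and renames the compiled extension from _kernels to common_ops. A hedged before/after sketch of one call site, using names taken from the diffs below (illustrative only):

    # Before this PR: nested layout, plural op namespace
    #   from sgl_kernel.ops.activation import rmsnorm
    #   torch.ops.sgl_kernels.rmsnorm(out, x, w, eps, stream)

    # After this PR: flat layout, singular op namespace
    from sgl_kernel.elementwise import rmsnorm  # also re-exported as sgl_kernel.rmsnorm
    #   torch.ops.sgl_kernel.rmsnorm(out, x, w, eps, stream)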
@@ -16,33 +16,9 @@ limitations under the License.
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <torch/library.h>
 
-#include "sgl_kernels_ops.h"
+#include "sgl_kernel_ops.h"
 
-TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
-  /*
-   * From csrc/activation
-   */
-  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
-
-  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
-  m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
-
-  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
-
-  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
-
-  m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
-
-  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
-
-  m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
-
+TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/allreduce
    */
@@ -67,6 +43,30 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
    */
   m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
 
+  /*
+   * From csrc/elementwise
+   */
+  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
+
+  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
+  m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
+
+  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
+
+  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
+
+  m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
+
+  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
+
+  m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
+
   /*
    * From csrc/gemm
    */
@@ -93,6 +93,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
   m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
   m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
 
+  m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
+  m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
+
   m.def(
       "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
       " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
@@ -171,9 +174,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
       "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
       "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
   m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
-
-  m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
-  m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
 }
 
-REGISTER_EXTENSION(_kernels)
+REGISTER_EXTENSION(common_ops)
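With REGISTER_EXTENSION(common_ops), the built shared library is importable as sgl_kernel.common_ops, and the TORCH_LIBRARY_EXPAND(sgl_kernel, m) block above registers each schema under torch.ops.sgl_kernel. A minimal sketch of inspecting one registered op after this change (assumes a CUDA build of sgl-kernel is installed; this is not part of the diff):

    import torch
    import sgl_kernel  # importing the package loads sgl_kernel.common_ops

    # OpOverloadPacket created by the m.def(...) above; .default is its only overload.
    print(torch.ops.sgl_kernel.rmsnorm.default._schema)
    # expected: rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()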
@@ -16,9 +16,9 @@ limitations under the License.
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <torch/library.h>
 
-#include "sgl_kernels_ops.h"
+#include "sgl_kernel_ops.h"
 
-TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
+TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/allreduce
    */
...
@@ -36,18 +36,6 @@ limitations under the License.
 using fptr_t = int64_t;
 
-/*
- * From csrc/activation
- */
-void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
-void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
-void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
-void gemma_fused_add_rmsnorm(
-    at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
-void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-
 /*
  * From csrc/allreduce
  */
@@ -88,6 +76,30 @@ void register_graph_buffers(
     fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
 #endif
 
+/*
+ * From csrc/attention
+ */
+void lightning_attention_decode(
+    const torch::Tensor& q,
+    const torch::Tensor& k,
+    const torch::Tensor& v,
+    const torch::Tensor& past_kv,
+    const torch::Tensor& slope,
+    torch::Tensor output,
+    torch::Tensor new_kv);
+
+/*
+ * From csrc/elementwise
+ */
+void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
+void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
+void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
+void gemma_fused_add_rmsnorm(
+    at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
+void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+
 /*
  * From csrc/gemm
  */
@@ -120,6 +132,7 @@ void sgl_per_token_group_quant_fp8(
     double fp8_min,
     double fp8_max);
 void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
+void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
 void cublas_grouped_gemm(
     const std::vector<torch::Tensor>& inputs,
     const std::vector<torch::Tensor>& weights,
@@ -254,18 +267,3 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor pos_ids,
     bool interleave,
     int64_t cuda_stream);
-
-/*
- * Other
- */
-void lightning_attention_decode(
-    const torch::Tensor& q,
-    const torch::Tensor& k,
-    const torch::Tensor& v,
-    const torch::Tensor& past_kv,
-    const torch::Tensor& slope,
-    torch::Tensor output,
-    torch::Tensor new_kv);
-
-// sgl_per_token_quant_fp8
-void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
@@ -20,10 +20,6 @@ dependencies = []
 "Homepage" = "https://github.com/sgl-project/sglang/tree/main/sgl-kernel"
 "Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
 
-[tool.setuptools]
-package-dir = {"sgl_kernel" = "src/sgl-kernel"}
-packages = ["sgl_kernel", "sgl_kernel.ops", "sgl_kernel.csrc"]
-
 [tool.wheel]
 exclude = [
     "dist*",
...
@@ -9,7 +9,10 @@ if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
         mode=ctypes.RTLD_GLOBAL,
     )
 
-from sgl_kernel.ops.activation import (
+from sgl_kernel import common_ops
+from sgl_kernel.allreduce import *
+from sgl_kernel.attention import lightning_attention_decode
+from sgl_kernel.elementwise import (
     apply_rope_with_cos_sin_cache_inplace,
     fused_add_rmsnorm,
     gelu_and_mul,
@@ -19,9 +22,7 @@ from sgl_kernel.ops.activation import (
     rmsnorm,
     silu_and_mul,
 )
-from sgl_kernel.ops.allreduce import *
-from sgl_kernel.ops.attention import lightning_attention_decode
-from sgl_kernel.ops.gemm import (
+from sgl_kernel.gemm import (
     bmm_fp8,
     cublas_grouped_gemm,
     fp8_blockwise_scaled_mm,
@@ -31,15 +32,15 @@ from sgl_kernel.ops.gemm import (
     sgl_per_token_group_quant_fp8,
     sgl_per_token_quant_fp8,
 )
-from sgl_kernel.ops.moe import moe_align_block_size
-from sgl_kernel.ops.sampling import (
+from sgl_kernel.moe import moe_align_block_size
+from sgl_kernel.sampling import (
     min_p_sampling_from_probs,
     top_k_renorm_prob,
     top_k_top_p_sampling_from_probs,
     top_p_renorm_prob,
     top_p_sampling_from_probs,
 )
-from sgl_kernel.ops.speculative import (
+from sgl_kernel.speculative import (
     build_tree_kernel,
     build_tree_kernel_efficient,
     tree_speculative_sampling_target_only,
...
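With the ops sub-package gone, everything re-exported here is importable one level up. A hedged usage sketch of the new flat surface (requires a CUDA device and the compiled extension; argument order follows the wrappers shown further down):

    import torch
    from sgl_kernel import rmsnorm, silu_and_mul  # flat import paths after this PR

    hidden = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
    weight = torch.ones(1024, device="cuda", dtype=torch.float16)
    normed = rmsnorm(hidden, weight)  # allocates `out` when one is not provided

    gate_up = torch.randn(4, 2048, device="cuda", dtype=torch.float16)
    gated = silu_and_mul(gate_up)  # gated activation; output halves the last dim (assumed): [4, 1024]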
 from typing import List, Tuple
 
-import sgl_kernel.ops._kernels
 import torch
 
 if torch.version.hip is not None:
@@ -13,49 +12,49 @@ if torch.version.hip is not None:
         rank: int,
         full_nvlink: bool,
     ) -> int:
-        return torch.ops.sgl_kernels.init_custom_ar(
+        return torch.ops.sgl_kernel.init_custom_ar(
             meta, rank_data, handles, offsets, rank, full_nvlink
         )
 
     def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-        torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)
+        torch.ops.sgl_kernel.all_reduce_reg(fa, inp, out)
 
     def all_reduce_unreg(
         fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
     ) -> None:
-        torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)
+        torch.ops.sgl_kernel.all_reduce_unreg(fa, inp, reg_buffer, out)
 
     def dispose(fa: int) -> None:
-        torch.ops.sgl_kernels.dispose(fa)
+        torch.ops.sgl_kernel.dispose(fa)
 
     def meta_size() -> int:
-        return torch.ops.sgl_kernels.meta_size()
+        return torch.ops.sgl_kernel.meta_size()
 
     def register_buffer(
         fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
     ) -> None:
-        return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)
+        return torch.ops.sgl_kernel.register_buffer(fa, t, handles, offsets)
 
     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
-        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
+        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
 
     def register_graph_buffers(
         fa: int, handles: List[str], offsets: List[List[int]]
     ) -> None:
-        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
+        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
 
     def allocate_meta_buffer(size: int) -> torch.Tensor:
-        return torch.ops.sgl_kernels.allocate_meta_buffer(size)
+        return torch.ops.sgl_kernel.allocate_meta_buffer(size)
 
     def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
-        return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)
+        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle(inp)
 
 else:
     # TRTLLM custom allreduce
     def init_custom_reduce(
         rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
     ):
-        return torch.ops.sgl_kernels.init_custom_ar(
+        return torch.ops.sgl_kernel.init_custom_ar(
             rank_id,
             num_devices,
             rank_data,
@@ -66,13 +65,13 @@ else:
         )
 
     def custom_dispose(fa):
-        torch.ops.sgl_kernels.dispose(fa)
+        torch.ops.sgl_kernel.dispose(fa)
 
     def custom_reduce(fa, inp, out):
-        torch.ops.sgl_kernels.all_reduce(fa, inp, out)
+        torch.ops.sgl_kernel.all_reduce(fa, inp, out)
 
     def get_graph_buffer_ipc_meta(fa):
-        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
+        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
 
     def register_graph_buffers(fa, handles, offsets):
-        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
+        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
-import sgl_kernel.ops._kernels
 import torch
 
 
 def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
-    torch.ops.sgl_kernels.lightning_attention_decode(
+    torch.ops.sgl_kernel.lightning_attention_decode(
         q, k, v, past_kv, slope, output, new_kv
     )
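For reference, a sketch of calling the relocated wrapper. The shapes and dtypes below are assumptions made for illustration (single-token linear-attention decode with a per-head KV state); they are not specified anywhere in this diff:

    import torch
    from sgl_kernel.attention import lightning_attention_decode

    b, h, d = 1, 8, 64  # batch, heads, head_dim (assumed)
    q = torch.randn(b, h, 1, d, device="cuda", dtype=torch.bfloat16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    past_kv = torch.zeros(b, h, d, d, device="cuda", dtype=torch.float32)  # assumed KV state shape
    slope = torch.rand(h, 1, 1, device="cuda", dtype=torch.float32)        # per-head decay slope
    output = torch.empty(b, h, 1, d, device="cuda", dtype=torch.float32)   # caller-allocated outputs
    new_kv = torch.empty_like(past_kv)
    lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)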
 from typing import Optional
 
-import sgl_kernel.ops._kernels
 import torch
-from sgl_kernel.ops.utils import get_cuda_stream
+from sgl_kernel.utils import get_cuda_stream
 
 # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
@@ -15,14 +14,14 @@ def rmsnorm(
 ) -> torch.Tensor:
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.rmsnorm(out, input, weight, eps, get_cuda_stream())
     return out
 
 
 def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps)
+    torch.ops.sgl_kernel.fused_add_rmsnorm(input, residual, weight, eps)
 
 
 def gemma_rmsnorm(
@@ -33,14 +32,14 @@ def gemma_rmsnorm(
 ) -> torch.Tensor:
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernels.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
     return out
 
 
 def gemma_fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    torch.ops.sgl_kernels.gemma_fused_add_rmsnorm(
+    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm(
         input, residual, weight, eps, get_cuda_stream()
     )
@@ -66,7 +65,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
         device=input.device,
         dtype=input.dtype,
     )
-    torch.ops.sgl_kernels.silu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream())
     return out
@@ -81,7 +80,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
         device=input.device,
         dtype=input.dtype,
     )
-    torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream())
     return out
@@ -96,7 +95,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
         device=input.device,
         dtype=input.dtype,
     )
-    torch.ops.sgl_kernels.gelu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream())
     return out
@@ -141,7 +140,7 @@ def apply_rope_with_cos_sin_cache_inplace(
         raise ValueError("cos_sin_cache should be float32")
     positions = positions.int()
-    torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
+    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
         q=query.view(query.shape[0], -1, head_size),
         k=key.view(key.shape[0], -1, head_size),
         q_rope=query.view(query.shape[0], -1, head_size),
...
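The fused_add_rmsnorm wrapper above returns None and mutates its arguments, matching the Tensor! annotations in the registered schema. A hedged sketch of the in-place contract (assumes a CUDA build of sgl-kernel):

    import torch
    from sgl_kernel import fused_add_rmsnorm

    hidden = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(hidden)
    weight = torch.ones(1024, device="cuda", dtype=torch.float16)

    # Returns None; per the schema, both `hidden` and `residual` are updated in place.
    fused_add_rmsnorm(hidden, residual, weight, eps=1e-6)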
 from typing import List, Optional
 
-import sgl_kernel.ops._kernels
 import torch
-from sgl_kernel.ops.utils import _get_cache_buf, get_cuda_stream
+from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
 
 
 def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernels.int8_scaled_mm(
+    return torch.ops.sgl_kernel.int8_scaled_mm(
         mat_a,
         mat_b,
         scales_a,
@@ -17,7 +16,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
 
 def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
-    return torch.ops.sgl_kernels.fp8_blockwise_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm(
         mat_a,
         mat_b,
         scales_a,
@@ -27,7 +26,7 @@ def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
 
 def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernels.fp8_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_scaled_mm(
         mat_a,
         mat_b,
         scales_a,
@@ -46,7 +45,7 @@ def _bmm_fp8_internal(
     B_scale: torch.Tensor,
 ) -> None:
     cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernels.bmm_fp8(
+    torch.ops.sgl_kernel.bmm_fp8(
         A,
         B,
         D,
@@ -86,7 +85,7 @@ def sgl_per_token_group_quant_fp8(
     fp8_min: float,
     fp8_max: float,
 ) -> None:
-    torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8(
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8(
         input, output_q, output_s, group_size, eps, fp8_min, fp8_max
     )
@@ -97,7 +96,7 @@ def sgl_per_tensor_quant_fp8(
     output_s: torch.Tensor,
     is_static: bool,
 ) -> None:
-    torch.ops.sgl_kernels.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
+    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
 
 
 def cublas_grouped_gemm(
@@ -110,7 +109,7 @@ def cublas_grouped_gemm(
         len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
     ), "Inputs/weights/outputs should not be empty!"
     cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernels.cublas_grouped_gemm(
+    torch.ops.sgl_kernel.cublas_grouped_gemm(
         inputs,
         weights,
         outputs,
@@ -125,4 +124,4 @@ def sgl_per_token_quant_fp8(
     output_q: torch.Tensor,
     output_s: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernels.sgl_per_token_quant_fp8(input, output_q, output_s)
+    torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s)
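Like the other quantization wrappers in this file, sgl_per_token_quant_fp8 writes into caller-allocated outputs rather than returning tensors. A hedged sketch (the FP8 dtype and the one-scale-per-token output shape are assumptions for illustration, not taken from this diff):

    import torch
    from sgl_kernel import sgl_per_token_quant_fp8

    x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)

    # Caller preallocates the quantized tensor and the per-token scales.
    x_q = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    x_s = torch.empty(x.shape[0], 1, device="cuda", dtype=torch.float32)
    sgl_per_token_quant_fp8(x, x_q, x_s)  # returns None; fills x_q and x_s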
-import sgl_kernel.ops._kernels
 import torch
 
@@ -12,7 +11,7 @@ def moe_align_block_size(
     token_cnts_buffer,
     cumsum_buffer,
 ):
-    torch.ops.sgl_kernels.moe_align_block_size(
+    torch.ops.sgl_kernel.moe_align_block_size(
         topk_ids,
         num_experts,
         block_size,
...